# koki2 — Stage 1 (L1.0 deplete/respawn + L0.2 harmful sources)\n\nThis notebook runs the L1.0 **deplete/respawn** variant of the L0.2 **harmful sources** experiment described in `thesis/12_IMPLEMENTATION_ENVIRONMENT_LADDER_SPEC.md`.\n\nRepo hygiene tip: if you opened this from GitHub, use **File → Save a copy in Drive** (instead of saving back to GitHub) to avoid committing execution outputs/metadata.\n\n1. **Runtime → Change runtime type** → select **GPU** (L4/T4/A100, etc.) or **TPU** (v5/v6e, etc.).\n2. Run the cells top-to-bottom.\n

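## (Optional) Mount Google Drive\n\nThe Colab VM disk is wiped when the runtime is recycled. If you want `runs/` to persist across sessions, uncomment and run the next cell to mount Drive, then set `OUT_ROOT` in the **Run** cell below to a Drive path (e.g. `/content/drive/MyDrive/koki2_runs/...`).\n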

In [None]:
# Optional: uncomment both lines below to mount Google Drive at /content/drive.
# Afterwards, point OUT_ROOT in the run cell at a folder under /content/drive/MyDrive/
# so run outputs survive the runtime being recycled.
# from google.colab import drive
# drive.mount('/content/drive')

## Clone the repo\n

In [None]:
REPO_URL = "https://github.com/Krisztiaan/koki2.git"  # override if using a fork
REPO_DIR = "koki2"

import os
import pathlib
import subprocess


def _run(cmd: list[str], *, check: bool = True) -> None:
    """Run ``cmd`` (no shell) with stdout and stderr captured together.

    If ``check`` is true and the command exits non-zero, print the captured
    output and raise ``subprocess.CalledProcessError``. Signature and behavior
    match the ``_run`` helper of the JAX-install cell, which redefines it.
    """
    proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if check and proc.returncode != 0:
        print(proc.stdout)
        proc.check_returncode()


def _resolve_repo_path(repo_dir: str) -> pathlib.Path:
    """Return where the clone lives, tolerating re-runs of this cell.

    This cell ``os.chdir``s into the clone, so on a re-run the working directory
    is already the checkout. Resolving ``repo_dir`` relative to it again would
    either clone a second, nested copy or chdir into a same-named subdirectory.
    """
    cwd = pathlib.Path.cwd()
    if cwd.name == repo_dir and (cwd / ".git").exists():
        return cwd  # already inside the clone (cell re-run)
    return cwd / repo_dir


repo_path = _resolve_repo_path(REPO_DIR)
if not repo_path.exists():
    # Shallow, quiet clone: only the latest commit is needed to run the experiments.
    _run(["git", "clone", "--depth", "1", "--quiet", REPO_URL, str(repo_path)])

os.chdir(repo_path)
print("Working dir:", pathlib.Path.cwd())

## Install JAX for your accelerator\n\nThis picks **TPU** if a TPU runtime is detected, otherwise **GPU** if `nvidia-smi` is available, and otherwise falls back to **CPU**.\n\nIf you change the runtime type after this, Colab gives you a new runtime (the clone and installs are lost), so rerun the notebook from the top. If JAX was already imported in this session before this cell ran, restart the runtime afterwards.\n

In [None]:
import glob
import os
import shutil
import subprocess
import sys


os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")


def _run(cmd: list[str], *, check: bool = True) -> None:
    """Run ``cmd`` (no shell) with stdout and stderr captured together.

    If ``check`` is true and the command exits non-zero, print the captured
    output and raise ``subprocess.CalledProcessError``. With ``check=False``,
    failures are ignored silently.
    """
    proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if check and proc.returncode != 0:
        print(proc.stdout)
        proc.check_returncode()


def _pip(*args: str) -> None:
    """``pip install -q -U <args>`` into this kernel's interpreter; raise on failure."""
    _run([sys.executable, "-m", "pip", "install", "-q", "-U", *args], check=True)


def _pip_uninstall(*args: str) -> None:
    """Best-effort ``pip uninstall -q -y <args>``; failures are ignored."""
    _run([sys.executable, "-m", "pip", "uninstall", "-q", "-y", *args], check=False)


# Env vars that indicate a TPU runtime. The first three come from legacy TPU-node
# runtimes. NOTE(review): confirm which of these are set on current Colab TPU VMs.
_TPU_ENV_VARS = ("COLAB_TPU_ADDR", "TPU_NAME", "XRT_TPU_CONFIG", "TPU_ACCELERATOR_TYPE")


def _is_tpu_runtime() -> bool:
    """Heuristically detect a TPU runtime from env vars or TPU device nodes."""
    if any(k in os.environ for k in _TPU_ENV_VARS):
        return True
    # NOTE(review): TPU VMs typically expose /dev/accel* device nodes (newer
    # generations may use /dev/vfio instead). Confirm on current Colab TPU runtimes.
    return bool(glob.glob("/dev/accel*"))


def _has_nvidia_smi() -> bool:
    """True if ``nvidia-smi`` is on this kernel's PATH (proxy for an NVIDIA GPU runtime)."""
    # shutil.which avoids spawning a login shell and does not echo the path into the output.
    return shutil.which("nvidia-smi") is not None


accelerator = "tpu" if _is_tpu_runtime() else ("gpu" if _has_nvidia_smi() else "cpu")
print("Detected accelerator:", accelerator)

# Colab-friendly default: avoid preallocating most GPU memory up-front.
# (Only effective for processes that import JAX after this point; subprocesses inherit it.)
if accelerator == "gpu":
    os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

# Keep the environment clean if you rerun this cell after switching runtime type.
# The accelerator plugins are removed too: a leftover CUDA/TPU plugin would otherwise
# stay installed alongside a freshly installed (possibly mismatched) jax/jaxlib.
_pip_uninstall(
    "jax",
    "jaxlib",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
    "libtpu",
    "libtpu-nightly",
)

_pip("pip")

# JAX package specs (adjust if Colab images change).
# Current JAX ships CUDA 12 wheels on PyPI as `jax[cuda12]`. The legacy
# `cuda12_pip`/`cuda11_pip` extras plus the jax_cuda_releases.html index are
# obsolete: pip merely warns on an unknown extra and installs CPU-only JAX.
# CUDA 11 is no longer supported.
JAX_CPU_PKG = "jax"
JAX_GPU_PKG = "jax[cuda12]"
JAX_TPU_PKG = "jax[tpu]"
JAX_TPU_WHL_INDEX = "https://storage.googleapis.com/jax-releases/libtpu_releases.html"

if accelerator == "tpu":
    _pip(JAX_TPU_PKG, "-f", JAX_TPU_WHL_INDEX)
elif accelerator == "gpu":
    _pip(JAX_GPU_PKG)
else:
    _pip(JAX_CPU_PKG)

print("JAX install finished.")
print("If you see import/backend issues below, use Runtime → Restart runtime, then rerun from the top.")

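## Install koki2 (editable)\n\nInstalls the cloned repo in editable mode with `--no-deps`, so pip does not touch the JAX build installed above; install any other missing dependencies manually.\n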

In [None]:
import subprocess
import sys

# Editable install of the checked-out repo into this kernel's interpreter.
# --no-deps: pip resolves nothing, so previously installed packages are left as-is.
pip_cmd = [sys.executable, "-m", "pip", "install", "-q", "-e", ".", "--no-deps"]
result = subprocess.run(pip_cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

# Only surface pip's merged stdout/stderr when the install failed, then raise.
if result.returncode:
    print(result.stdout)
result.check_returncode()

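## Sanity check: JAX sees the accelerator\n\nOn a GPU/TPU runtime, `backend` should read `gpu`/`tpu`. If it shows `cpu`, restart the runtime and rerun the notebook from the top.\n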

In [None]:
import jax

# Report which JAX build, default backend, and devices this kernel picked up.
print(f"jax: {jax.__version__}")
print(f"backend: {jax.default_backend()}")
device_lines = [f" - {device}" for device in jax.devices()]
print("devices:", *device_lines, sep="\n")

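## Run: multi-seed ES + held-out eval\n\nTrains `SEED_COUNT` seeds with `koki2 batch-evo-l0`. It then runs `koki2 eval-run` on each resulting run directory with a fixed evaluation seed (`EVAL_SEED`), once per baseline policy (greedy, random). Defaults match the larger-budget L1.0 deplete/respawn run in `WORK.md` (update as needed).\n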

In [None]:
import re
import subprocess

# --- Training budget ---
OUT_ROOT = "runs/stage1_es_big_l10"  # or e.g. "/content/drive/MyDrive/koki2_runs/stage1_es_big_l10"
TAG = "stage1_l10_deplete_badsrc_g200_p128_ep8"
SEED_START = 0
SEED_COUNT = 5
GENERATIONS = 200
POP_SIZE = 128
TRAIN_EPISODES = 8
STEPS = 128
LOG_EVERY = 10

# --- Environment: L1.0 deplete/respawn + L0.2 harmful sources (previously inline literals) ---
RESPAWN_DELAY = 4
NUM_SOURCES = 4
NUM_BAD_SOURCES = 2
BAD_SOURCE_INTEGRITY_LOSS = 0.25

# --- Evaluation ---
EVAL_EPISODES = 512
EVAL_SEED = 424242
BASELINE_POLICIES = ("greedy", "random")
TAIL_LINES = 20  # how many trailing lines of the training log to show


def _run_cli(cmd: list[str]) -> subprocess.CompletedProcess:
    """Run a CLI command with stdout and stderr merged into ``.stdout``.

    On a non-zero exit, print the full captured output, then raise
    ``subprocess.CalledProcessError``, so failures are diagnosable.
    """
    completed = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if completed.returncode != 0:
        print(completed.stdout)
    completed.check_returncode()
    return completed


def _parse_out_dirs(log: str) -> list[str]:
    """Extract ``out_dir=<path>`` values from a batch log, de-duplicated in first-seen order."""
    # NOTE(review): \S+ stops at whitespace, so a run dir under a path containing
    # spaces (e.g. a Drive folder "My Drive") would be truncated.
    return list(dict.fromkeys(re.findall(r"out_dir=(\S+)", log)))


cmd = [
    "koki2",
    "batch-evo-l0",
    "--seed-start", str(SEED_START),
    "--seed-count", str(SEED_COUNT),
    "--out-root", OUT_ROOT,
    "--tag", TAG,
    "--generations", str(GENERATIONS),
    "--pop-size", str(POP_SIZE),
    "--episodes", str(TRAIN_EPISODES),
    "--steps", str(STEPS),
    "--deplete-sources",
    "--respawn-delay", str(RESPAWN_DELAY),
    "--num-sources", str(NUM_SOURCES),
    "--num-bad-sources", str(NUM_BAD_SOURCES),
    "--bad-source-integrity-loss", str(BAD_SOURCE_INTEGRITY_LOSS),
    "--log-every", str(LOG_EVERY),
]

proc = _run_cli(cmd)
print("\n".join(proc.stdout.splitlines()[-TAIL_LINES:]))

out_dirs = _parse_out_dirs(proc.stdout)
print("out_dirs:", out_dirs)
if not out_dirs:
    # Fail loudly instead of silently skipping evaluation (e.g. if the log format changes).
    raise RuntimeError("batch-evo-l0 succeeded but printed no `out_dir=...` lines; cannot run eval-run.")

for r in out_dirs:
    for policy in BASELINE_POLICIES:
        cmd_eval = [
            "koki2",
            "eval-run",
            "--run-dir", r,
            "--episodes", str(EVAL_EPISODES),
            "--seed", str(EVAL_SEED),
            "--baseline-policy", policy,
        ]
        proc2 = _run_cli(cmd_eval)
        print(proc2.stdout.strip())
        print("-")