# koki2 on Google Colab\n\nThis notebook bootstraps `koki2` on **CPU/GPU/TPU** Colab runtimes.\n\n1. **Runtime → Change runtime type** → select **GPU** (L4/T4/A100, etc.) or **TPU** (v5e/v6e, etc.).\n2. Run the cells top to bottom.\n3. If you change the `jax`/`jaxlib` installs, **restart the runtime** when prompted.\n\nRepo hygiene tip: if you opened this from GitHub, use **File → Save a copy in Drive** (instead of saving back to GitHub) to avoid committing execution outputs/metadata.\n

## (Optional) Mount Google Drive\n\nIf you want `runs/` to persist across sessions, mount Drive and write outputs there (e.g. set `--out-dir /content/drive/MyDrive/koki2_runs/...`).\n

In [None]:
# from google.colab import drive\n# drive.mount('/content/drive')\n
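As a minimal sketch of the persistence idea above, the snippet below picks an output directory under Drive when the mount point exists and falls back to a local `runs/` directory otherwise. The helper name `pick_out_dir` and the run name are hypothetical; the `koki2_runs` folder and the `--out-dir` flag are taken from the text above.

```python
import pathlib


def pick_out_dir(
    run_name: str,
    drive_root: str = "/content/drive/MyDrive",
    local_root: str = "runs",
) -> pathlib.Path:
    """Return a Drive-backed run directory if Drive is mounted, else a local one.

    Hypothetical helper: it only builds the path and does not create it.
    """
    drive = pathlib.Path(drive_root)
    # The mount point only exists as a directory after drive.mount(...) has run.
    base = drive / "koki2_runs" if drive.is_dir() else pathlib.Path(local_root)
    return base / run_name


out_dir = pick_out_dir("smoke-seed0")
print("Pass to the CLI as: --out-dir", out_dir)
```

Without a mounted Drive the outputs stay on the Colab VM and are lost when the session ends.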

## Clone the repo\n

In [None]:
REPO_URL = "https://github.com/Krisztiaan/koki2.git"  # override if using a fork\nREPO_DIR = "koki2"\n\nimport os\nimport pathlib\nimport subprocess\n\n\ndef _run(cmd: list[str]) -> None:\n    """Run a command quietly; print its combined output only if it fails."""\n    proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\n    if proc.returncode != 0:\n        print(proc.stdout)\n        proc.check_returncode()\n\n\n# Skip clone/chdir when this cell is re-run from inside the clone; otherwise a\n# second copy would be cloned into koki2/koki2.\nif pathlib.Path.cwd().name != REPO_DIR:\n    if not pathlib.Path(REPO_DIR).exists():\n        _run(["git", "clone", "--depth", "1", "--quiet", REPO_URL, REPO_DIR])\n    os.chdir(REPO_DIR)\n\nprint("Working dir:", pathlib.Path.cwd())\n

## Install JAX for your accelerator\n\nThis picks **TPU** if a TPU runtime is detected, otherwise picks **GPU** if `nvidia-smi` is available, else falls back to **CPU**.\n\nIf you manually change the runtime accelerator type after this, re-run this cell and restart the runtime.\n

In [None]:
import os\nimport subprocess\nimport sys\n\n\nos.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")\n\n\ndef _run(cmd: list[str], *, check: bool = True) -> None:\n    proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\n    if check and proc.returncode != 0:\n        print(proc.stdout)\n        proc.check_returncode()\n\n\ndef _pip(*args: str) -> None:\n    _run([sys.executable, "-m", "pip", "install", "-q", "-U", *args], check=True)\n\n\ndef _pip_uninstall(*args: str) -> None:\n    _run([sys.executable, "-m", "pip", "uninstall", "-q", "-y", *args], check=False)\n\n\ndef _is_tpu_runtime() -> bool:\n    return any(k in os.environ for k in ["COLAB_TPU_ADDR", "TPU_NAME", "XRT_TPU_CONFIG"])\n\n\ndef _has_nvidia_smi() -> bool:\n    return subprocess.run(["bash", "-lc", "command -v nvidia-smi"], check=False).returncode == 0\n\n\naccelerator = "tpu" if _is_tpu_runtime() else ("gpu" if _has_nvidia_smi() else "cpu")\nprint("Detected accelerator:", accelerator)\n\n# Colab-friendly default: avoid preallocating most GPU memory up-front.\nif accelerator == "gpu":\n    os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")\n\n# Keep the environment clean if you rerun this cell after switching runtime type.\n_pip_uninstall("jax", "jaxlib")\n\n# pip itself can be stale on some runtimes.\n_pip("pip")\n\n# JAX wheels/indices (adjust if Colab images change).\nJAX_CPU_PKG = "jax"\nJAX_GPU_PKG = "jax[cuda12]"  # older JAX releases named this extra cuda12_pip\nJAX_GPU_WHL_INDEX = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"\nJAX_TPU_PKG = "jax[tpu]"\nJAX_TPU_WHL_INDEX = "https://storage.googleapis.com/jax-releases/libtpu_releases.html"\n\nif accelerator == "tpu":\n    # NOTE: for TPU runtimes, the JAX project publishes wheels via the libtpu index.\n    _pip(JAX_TPU_PKG, "-f", JAX_TPU_WHL_INDEX)\nelif accelerator == "gpu":\n    # NOTE: for GPU runtimes, pick the CUDA wheel index that matches your Colab image.\n    _pip(JAX_GPU_PKG, "-f", JAX_GPU_WHL_INDEX)\nelse:\n    _pip(JAX_CPU_PKG)\n\nprint("JAX install finished.")\nprint("If you see import/backend issues below, use Runtime → Restart runtime, then rerun from the top.")\n
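To decide whether a restart is actually needed after reinstalling, one option is to compare the version recorded on disk with the version of the module already loaded in the running kernel. The sketch below uses only the standard library and does not import `jax` itself; the helper names `installed_version` and `needs_restart` are hypothetical.

```python
import importlib.metadata
import sys
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Version of the distribution currently installed on disk, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def needs_restart(module_name: str = "jax") -> bool:
    """True if the module is already imported with a version that differs from disk.

    Assumption: the module exposes __version__ and shares its name with the
    distribution, which holds for jax and jaxlib.
    """
    module = sys.modules.get(module_name)
    if module is None:
        # Not imported yet, so the next import picks up the fresh install.
        return False
    return getattr(module, "__version__", None) != installed_version(module_name)


for name in ("jax", "jaxlib"):
    print(name, "installed:", installed_version(name))
print("restart needed:", needs_restart("jax"))
```

This check only covers version mismatches; a backend change at the same version (for example CPU to CUDA wheels) still calls for a restart.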

## Install koki2 (editable)\n

In [None]:
import subprocess\nimport sys\n\nproc = subprocess.run(\n    [sys.executable, "-m", "pip", "install", "-q", "-e", ".", "--no-deps"],\n    text=True,\n    stdout=subprocess.PIPE,\n    stderr=subprocess.STDOUT,\n)\nif proc.returncode != 0:\n    print(proc.stdout)\nproc.check_returncode()\n

## Sanity check: JAX sees the accelerator\n

In [None]:
import jax\n\nprint("jax:", jax.__version__)\nprint("backend:", jax.default_backend())\nprint("devices:")\nfor d in jax.devices():\n    print(" -", d)\n

## Smoke run\n\nThis is intentionally tiny to keep compile time reasonable on Colab.\n

In [None]:
import subprocess\n\nproc = subprocess.run([\n    "koki2",\n    "evo-l0",\n    "--seed", "0",\n    "--generations", "2",\n    "--pop-size", "16",\n    "--episodes", "2",\n    "--steps", "64",\n    "--jit-es",\n    "--log-every", "1",\n], text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\nif proc.returncode != 0:\n    print(proc.stdout)\nproc.check_returncode()\nprint("\n".join(proc.stdout.splitlines()[-20:]))\n
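Once the smoke run passes, the same command can be scaled up by overriding a few values. The sketch below builds the argument list from the smoke-run settings shown above; the helper `build_evo_l0_cmd` is hypothetical, and passing `--out-dir` to `evo-l0` is an assumption based on the Drive section rather than something verified against the CLI.

```python
from typing import Optional

# Values copied from the smoke run above; keys are the CLI flag names without "--".
SMOKE_DEFAULTS = {
    "seed": 0,
    "generations": 2,
    "pop-size": 16,
    "episodes": 2,
    "steps": 64,
    "log-every": 1,
}


def build_evo_l0_cmd(out_dir: Optional[str] = None, **overrides: int) -> list:
    """Hypothetical helper: build the `koki2 evo-l0` argv with selected overrides."""
    params = dict(SMOKE_DEFAULTS)
    for key, value in overrides.items():
        flag = key.replace("_", "-")  # pop_size -> pop-size
        if flag not in params:
            raise ValueError(f"unknown smoke-run parameter: {key}")
        params[flag] = value
    cmd = ["koki2", "evo-l0"]
    for flag, value in params.items():
        cmd += [f"--{flag}", str(value)]
    cmd.append("--jit-es")
    if out_dir is not None:
        # Assumption: evo-l0 accepts --out-dir, as the Drive section suggests.
        cmd += ["--out-dir", out_dir]
    return cmd


print(" ".join(build_evo_l0_cmd(generations=4, pop_size=32)))
```

The resulting list can be passed to `subprocess.run` exactly as in the cell above. Larger populations and step counts increase compile time, so it is worth growing them gradually.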