# koki2 — Stage 2 (L1.0 deplete/respawn + L1.1 intermittent gradient): plastic vs no-plastic

This notebook runs a Stage 2 comparison aligned with `thesis/12_IMPLEMENTATION_ENVIRONMENT_LADDER_SPEC.md` (L1.0 + L1.1) and logged in `WORK.md`.

- Environment: L1.0 **deplete/respawn** + L0.2 **harmful sources** + L1.1 **intermittent gradient** (`--grad-dropout-p 0.5`).
- Compare: **no plasticity** vs **plasticity enabled** (`--plast-enabled --plast-eta 0.05 --plast-lambda 0.9`).

Repo hygiene tip: if you opened this from GitHub, use **File → Save a copy in Drive** (instead of saving back to GitHub) to avoid committing execution outputs/metadata.

1. **Runtime → Change runtime type** → select **GPU** (L4/T4/A100, etc.) or **TPU** (v5/v6e, etc.).
2. Run the cells top-to-bottom.


## (Optional) Mount Google Drive

If you want `runs/` to persist across sessions, mount Drive here, then set `OUT_ROOT` in the run cell further below to a Drive path.

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

## Clone the repo

In [None]:
REPO_URL = "https://github.com/Krisztiaan/koki2.git"  # override if using a fork
REPO_DIR = "koki2"

import os
import pathlib
import subprocess


def _run(cmd: list[str]) -> None:
    """Run *cmd* with stdout+stderr captured.

    The captured output is printed only when the command fails, immediately
    before ``subprocess.CalledProcessError`` is raised.
    """
    proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if proc.returncode != 0:
        print(proc.stdout)
        proc.check_returncode()


_cwd = pathlib.Path.cwd()
if _cwd.name == pathlib.Path(REPO_DIR).name and (_cwd / ".git").exists():
    # Cell re-run: a previous run already chdir'd into the clone. Without this
    # guard the relative REPO_DIR below would resolve *inside* the clone and a
    # second, nested copy would be cloned and entered.
    repo_path = _cwd
else:
    repo_path = _cwd / REPO_DIR
    if not repo_path.exists():
        _run(["git", "clone", "--depth", "1", "--quiet", REPO_URL, REPO_DIR])

os.chdir(repo_path)
print("Working dir:", pathlib.Path.cwd())

## Install JAX for your accelerator

This picks **TPU** if a TPU runtime is detected, otherwise **GPU** if `nvidia-smi` is available, and otherwise falls back to **CPU**.

If you change the runtime accelerator type after running this cell, re-run this cell and then restart the runtime.

In [None]:
import glob
import os
import shutil
import subprocess
import sys


os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")


def _run(cmd: list[str], *, check: bool = True) -> None:
    """Run *cmd* with stdout+stderr captured.

    When ``check`` is true and *cmd* exits non-zero, print the captured output
    and raise ``subprocess.CalledProcessError``; otherwise stay silent.
    """
    proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if check and proc.returncode != 0:
        print(proc.stdout)
        proc.check_returncode()


def _pip(*args: str) -> None:
    """``pip install -q -U *args`` into this kernel's interpreter; raise on failure."""
    _run([sys.executable, "-m", "pip", "install", "-q", "-U", *args], check=True)


def _pip_uninstall(*args: str) -> None:
    """Best-effort ``pip uninstall -y``; packages that are not installed are not an error."""
    _run([sys.executable, "-m", "pip", "uninstall", "-q", "-y", *args], check=False)


def _is_tpu_runtime() -> bool:
    """Heuristically detect whether this runtime has a TPU attached."""
    # Legacy Colab / Cloud TPU setups export one of these variables.
    tpu_env_vars = ["COLAB_TPU_ADDR", "TPU_NAME", "XRT_TPU_CONFIG", "TPU_ACCELERATOR_TYPE"]
    if any(k in os.environ for k in tpu_env_vars):
        return True
    # NOTE(review): TPU-VM style runtimes (e.g. v5e/v6e) may not export the legacy
    # variables; /dev/accelN device nodes are an extra heuristic — confirm on the
    # target runtime.
    return bool(glob.glob("/dev/accel[0-9]*"))


def _has_nvidia_smi() -> bool:
    """Return True if an ``nvidia-smi`` executable is on PATH (our 'GPU present' signal)."""
    # shutil.which looks at this process's PATH directly instead of spawning a
    # login shell (whose profile scripts may alter PATH).
    return shutil.which("nvidia-smi") is not None


accelerator = "tpu" if _is_tpu_runtime() else ("gpu" if _has_nvidia_smi() else "cpu")
print("Detected accelerator:", accelerator)

# Colab-friendly default: avoid preallocating most GPU memory up-front.
# Only takes effect if set before JAX initialises its backend in this kernel;
# subprocesses launched later (e.g. the koki2 CLI) inherit it via os.environ.
if accelerator == "gpu":
    os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

# Keep the environment clean if you rerun this cell after switching runtime type.
# The CUDA plugin packages are removed too, so a stale plugin can't outlive jaxlib.
_pip_uninstall("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")

_pip("pip")

# JAX wheels/indices (adjust if Colab images change).
JAX_CPU_PKG = "jax"
# Current JAX ships CUDA 12 support from PyPI via the `cuda12` extra. The older
# `cuda12_pip` extra + jax_cuda_releases.html find-links recipe was deprecated and
# later dropped; pip merely *warns* about an unknown extra and would leave a
# CPU-only install behind.
JAX_GPU_PKG = "jax[cuda12]"
JAX_TPU_PKG = "jax[tpu]"
JAX_TPU_WHL_INDEX = "https://storage.googleapis.com/jax-releases/libtpu_releases.html"

if accelerator == "tpu":
    _pip(JAX_TPU_PKG, "-f", JAX_TPU_WHL_INDEX)
elif accelerator == "gpu":
    _pip(JAX_GPU_PKG)
else:
    _pip(JAX_CPU_PKG)

print("JAX install finished.")
print("If you see import/backend issues below, use Runtime → Restart runtime, then rerun from the top.")

## Install koki2 (editable)

In [None]:
import subprocess
import sys

# Editable install of the cloned repo (current directory) without resolving
# dependencies — presumably so pip leaves the accelerator-specific JAX from the
# previous cell alone; confirm against the project's declared deps.
_install_cmd = [sys.executable, "-m", "pip", "install", "-q", "-e", ".", "--no-deps"]
_result = subprocess.run(
    _install_cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
)
if _result.returncode:
    # Surface pip's combined output only when something went wrong.
    print(_result.stdout)
_result.check_returncode()

## Sanity check: JAX sees the accelerator

In [None]:
import jax

# Report the installed JAX version, the default backend, and every visible device.
for label, value in (("jax:", jax.__version__), ("backend:", jax.default_backend())):
    print(label, value)

print("devices:")
for dev in jax.devices():
    print(f" - {dev}")

## Run: multi-seed ES + held-out eval (plastic vs no-plastic)

Defaults match the ES30 comparison protocol in `WORK.md` (update budgets/seeds as needed).


In [None]:
import re
import subprocess
from statistics import mean, stdev

# Output location for batch runs; point it at Drive to keep results, e.g.
# "/content/drive/MyDrive/koki2_runs/stage2_es30_l10l11".
OUT_ROOT = "runs/stage2_es30_l10l11"
# Training seeds, forwarded verbatim to `--seeds` (comma-separated).
SEEDS = "0,1,2,3,4"

# ES training budget.
GENERATIONS, POP_SIZE = 30, 64
TRAIN_EPISODES, STEPS, LOG_EVERY = 4, 128, 1

# Environment: L1.0 + L0.2 (+ L1.1 dropout).
# Each (flag, value) pair is flattened into the flat argv list the CLI expects.
ENV_ARGS = ["--deplete-sources"] + [
    token
    for flag_value in (
        ("--respawn-delay", "4"),
        ("--num-sources", "4"),
        ("--num-bad-sources", "2"),
        ("--bad-source-integrity-loss", "0.25"),
        ("--grad-dropout-p", "0.5"),
    )
    for token in flag_value
]

# Plasticity config (spike modulator default).
PLAST_ETA, PLAST_LAMBDA = 0.05, 0.9

# Held-out evaluation: episodes per run, and the eval seeds to sweep.
EVAL_EPISODES = 512
EVAL_SEEDS = [424242, 0]


def _run(cmd: list[str]) -> str:
    """Execute *cmd* and return its merged stdout+stderr text.

    The output is echoed only when the command fails, right before
    ``subprocess.CalledProcessError`` is raised.
    """
    completed = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    failed = completed.returncode != 0
    if failed:
        print(completed.stdout)
    completed.check_returncode()
    return completed.stdout


def _batch(tag: str, extra: list[str]) -> list[str]:
    """Run ``koki2 batch-evo-l0`` over SEEDS for one condition and return its run dirs.

    Args:
        tag: Forwarded to ``--tag``; also used as the label in the printout.
        extra: Additional CLI flags appended after ENV_ARGS (e.g. plasticity flags).

    Returns:
        The ``out_dir=...`` paths found in the CLI output, in first-seen order,
        with duplicates removed.

    Raises:
        subprocess.CalledProcessError: if the CLI exits non-zero (raised by ``_run``).
        RuntimeError: if the CLI succeeded but reported no ``out_dir=`` at all.
    """
    cmd = [
        "koki2",
        "batch-evo-l0",
        "--seeds", SEEDS,
        "--out-root", OUT_ROOT,
        "--tag", tag,
        "--generations", str(GENERATIONS),
        "--pop-size", str(POP_SIZE),
        "--episodes", str(TRAIN_EPISODES),
        "--steps", str(STEPS),
        "--log-every", str(LOG_EVERY),
    ] + ENV_ARGS + extra

    out = _run(cmd)
    # Show only the last 20 lines of the CLI output.
    print("\n".join(out.splitlines()[-20:]))
    # dict.fromkeys de-duplicates while preserving order, so a dir that is echoed
    # more than once isn't evaluated (and counted in the summary) twice.
    out_dirs = list(dict.fromkeys(re.findall(r"out_dir=(\S+)", out)))
    print(tag, "out_dirs:", out_dirs)
    if not out_dirs:
        # Fail here with context instead of later with an opaque IndexError in
        # the eval summary.
        raise RuntimeError(
            f"koki2 batch-evo-l0 (tag={tag!r}) reported no 'out_dir=' lines; see the output above"
        )
    return out_dirs


# A) No plasticity
dirs_noplast = _batch("stage2_es30_noplast", [])

# B) Plasticity enabled. The tag is derived from PLAST_ETA so that editing the
# learning rate in the config can't silently file runs under a stale "eta0.05"
# label (with the default 0.05 the tag is unchanged: "stage2_es30_plast_eta0.05").
dirs_plast = _batch(f"stage2_es30_plast_eta{PLAST_ETA}", [
    "--plast-enabled",
    "--plast-eta", str(PLAST_ETA),
    "--plast-lambda", str(PLAST_LAMBDA),
])


def _parse_best(line: str) -> dict:
    """Parse a ``best_genome ...`` summary line from ``koki2 eval-run`` into floats.

    The line is whitespace-separated ``key=value`` tokens; tokens without ``=``
    (such as the leading ``best_genome``) are ignored. Example:

        best_genome episodes=512 mean_fitness=... success_rate=... mean_t_alive=... mean_bad_arrivals=... mean_integrity_min=... mean_abs_dw_mean=... mean_abs_dw_on_event=... event_step_frac=... mean_abs_modulator_mean=...

    Returns:
        A dict with the five required metrics followed by the four optional ones
        (``mean_abs_dw_mean``, ``mean_abs_dw_on_event``, ``event_step_frac``,
        ``mean_abs_modulator_mean``), which default to 0.0 when absent.

    Raises:
        KeyError: if a required metric is missing from the line.
        ValueError: if a metric value cannot be parsed as a float.
    """
    parts: dict[str, str] = {}
    for token in line.split():
        # Split on the first '=' only: a value that itself contains '=' used to
        # hand dict() a 3-item sequence and raise ValueError.
        key, sep, value = token.partition("=")
        if sep:
            parts[key] = value

    required = (
        "mean_fitness",
        "success_rate",
        "mean_t_alive",
        "mean_bad_arrivals",
        "mean_integrity_min",
    )
    optional = (
        "mean_abs_dw_mean",
        "mean_abs_dw_on_event",
        "event_step_frac",
        "mean_abs_modulator_mean",
    )
    row = {k: float(parts[k]) for k in required}
    row.update((k, float(parts.get(k, "0.0"))) for k in optional)
    return row


# Metrics summarised per condition (same order as the fields parsed from each
# `best_genome` line).
SUMMARY_KEYS = (
    'mean_fitness', 'success_rate', 'mean_t_alive', 'mean_bad_arrivals',
    'mean_integrity_min', 'mean_abs_dw_mean', 'mean_abs_dw_on_event',
    'event_step_frac', 'mean_abs_modulator_mean',
)

# Held-out eval: run `koki2 eval-run` on every run dir of both conditions for
# each eval seed, parse its `best_genome` line, then summarise per condition.
# NOTE(review): EVAL_SEEDS contains 0, which is also a training seed in SEEDS;
# whether that overlaps the training episodes depends on how koki2 derives its
# RNG streams — confirm the "held-out" label holds for seed 0.
for eval_seed in EVAL_SEEDS:
    print('\n=== Held-out eval_seed=', eval_seed, '===')
    rows = []
    for tag, dirs in [('noplast', dirs_noplast), ('plast', dirs_plast)]:
        for run_dir in dirs:
            out = _run([
                'koki2',
                'eval-run',
                '--run-dir', run_dir,
                '--episodes', str(EVAL_EPISODES),
                '--seed', str(eval_seed),
                '--baseline-policy', 'none',
            ])
            print(out.strip())
            best_line = next(
                (ln for ln in out.splitlines() if ln.startswith('best_genome')), None
            )
            if best_line is None:
                # Name the offending run instead of failing with a bare IndexError.
                raise RuntimeError(
                    f"koki2 eval-run printed no 'best_genome' line for run dir {run_dir!r}"
                )
            row = _parse_best(best_line)
            row['cond'] = tag
            rows.append(row)

    # Small aggregate summary (mean ± stdev across seeds)
    for cond in ['noplast', 'plast']:
        xs = [r for r in rows if r['cond'] == cond]
        print('\n--', cond, '(n=%d)' % len(xs))
        if not xs:
            # Nothing to summarise; avoid indexing into an empty list below.
            print('  (no runs)')
            continue
        for k in SUMMARY_KEYS:
            vals = [r[k] for r in xs]
            if len(vals) >= 2:
                print(f"  {k}: {mean(vals):.6f} ± {stdev(vals):.6f}")
            else:
                # stdev needs at least two samples.
                print(f"  {k}: {vals[0]:.6f}")
