# koki2 — L1.0 effect-size grid (steps × hazard persistence × success bonus)

This notebook runs a small **pre-registered grid** to strengthen the evidence for the thesis claim that L1.0 (temporal structure) amplifies survival-weighted strategies, and to check how that effect depends on:

- **Horizon length** (`--steps`)
- **Hazard persistence** (`--bad-source-respawn-delay`)
- **Success shaping** (`--success-bonus`)

The protocol is recorded in `WORK.md`. The notebook runs locally or on Colab (GPU or TPU runtime).

Repo hygiene tip: if you opened this from GitHub, use **File → Save a copy in Drive** (instead of saving back to GitHub) to avoid committing execution outputs/metadata.

To run on Colab:

1. **Runtime → Change runtime type** → select **GPU** (L4/T4/A100, etc.) or **TPU** (v5/v6e, etc.).
2. Run the cells top-to-bottom.


## (Optional) Mount Google Drive

If you want `runs/` to persist across sessions, mount Drive and set `OUT_ROOT` in the grid cell below to a Drive path.

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')
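If Drive is mounted, a small helper can route outputs there and fall back to the local `runs/` directory otherwise. This is a sketch: `pick_out_root` and the `koki2_runs/...` subfolder are hypothetical names that mirror the commented Drive example for `OUT_ROOT` in the grid cell; `/content/drive/MyDrive` is the usual Colab mount point after `drive.mount('/content/drive')`.

```python
from pathlib import Path

# Usual Colab location after drive.mount('/content/drive').
DRIVE_ROOT = Path("/content/drive/MyDrive")
# Local default, same as OUT_ROOT in the grid cell.
LOCAL_OUT = Path("runs/stage1_grid_l10_effectsize")


def pick_out_root(drive_root: Path = DRIVE_ROOT,
                  subdir: str = "koki2_runs/stage1_grid_l10_effectsize") -> Path:
    """Hypothetical helper: use Drive if the mount exists, else the local runs/ path."""
    if drive_root.is_dir():
        return drive_root / subdir
    return LOCAL_OUT


print("OUT_ROOT candidate:", pick_out_root())
```

You would then assign `OUT_ROOT = pick_out_root()` in the grid cell instead of the hard-coded path.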

## Clone the repo

In [None]:
REPO_URL = "https://github.com/Krisztiaan/koki2.git"  # override if using a fork
REPO_DIR = "koki2"

import os
import pathlib
import subprocess


def _run(cmd: list[str]) -> None:
    proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if proc.returncode != 0:
        print(proc.stdout)
        proc.check_returncode()


# Skip clone/chdir when this cell is rerun from inside the checkout
# (otherwise a rerun would clone a nested koki2/koki2).
if pathlib.Path.cwd().name != REPO_DIR:
    repo_path = pathlib.Path(REPO_DIR)
    if not repo_path.exists():
        _run(["git", "clone", "--depth", "1", "--quiet", REPO_URL, REPO_DIR])
    os.chdir(REPO_DIR)
print("Working dir:", pathlib.Path.cwd())
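Because the grid is pre-registered, it can help to log the exact commit next to the results. A minimal sketch, assuming `git` is on PATH; the helper name `git_head` is hypothetical and only wraps `git rev-parse HEAD`, which also works on the shallow (`--depth 1`) clone made above.

```python
import subprocess
from typing import Optional


def git_head(repo_dir: str = ".") -> Optional[str]:
    """Return the checked-out commit hash, or None if it cannot be determined."""
    try:
        proc = subprocess.run(
            ["git", "-C", repo_dir, "rev-parse", "HEAD"],
            text=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
        )
    except FileNotFoundError:  # git executable not installed
        return None
    return proc.stdout.strip() if proc.returncode == 0 else None


print("koki2 commit:", git_head())
```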

## Install JAX for your accelerator

This picks **TPU** if a TPU runtime is detected, otherwise **GPU** if `nvidia-smi` is available, and otherwise falls back to **CPU**.

If you manually change the runtime accelerator type after this, re-run this cell and restart the runtime.
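The detection rule is a fixed precedence order. Pulled out as a pure function (the name `pick_accelerator` is hypothetical; it uses the same environment-variable checks as the install cell), it can be exercised without switching runtimes:

```python
from collections.abc import Mapping

# Environment variables the install cell treats as evidence of a TPU runtime.
TPU_ENV_KEYS = ("COLAB_TPU_ADDR", "TPU_NAME", "XRT_TPU_CONFIG")


def pick_accelerator(environ: Mapping[str, str], has_nvidia_smi: bool) -> str:
    """Same precedence as the install cell: TPU first, then GPU, then CPU."""
    if any(k in environ for k in TPU_ENV_KEYS):
        return "tpu"
    if has_nvidia_smi:
        return "gpu"
    return "cpu"


# A TPU marker wins even if nvidia-smi were also present.
print(pick_accelerator({"TPU_NAME": "local"}, has_nvidia_smi=True))
```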

In [None]:
import os
import subprocess
import sys


os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")


def _run(cmd: list[str], *, check: bool = True) -> None:
    proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if check and proc.returncode != 0:
        print(proc.stdout)
        proc.check_returncode()


def _pip(*args: str) -> None:
    _run([sys.executable, "-m", "pip", "install", "-q", "-U", *args], check=True)


def _pip_uninstall(*args: str) -> None:
    _run([sys.executable, "-m", "pip", "uninstall", "-q", "-y", *args], check=False)


def _is_tpu_runtime() -> bool:
    return any(k in os.environ for k in ["COLAB_TPU_ADDR", "TPU_NAME", "XRT_TPU_CONFIG"])


def _has_nvidia_smi() -> bool:
    return subprocess.run(["bash", "-lc", "command -v nvidia-smi"], check=False).returncode == 0


accelerator = "tpu" if _is_tpu_runtime() else ("gpu" if _has_nvidia_smi() else "cpu")
print("Detected accelerator:", accelerator)

# Colab-friendly default: avoid preallocating most GPU memory up-front.
if accelerator == "gpu":
    os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

# Keep the environment clean if you rerun this cell after switching runtime type.
_pip_uninstall("jax", "jaxlib")

_pip("pip")

# JAX wheels/indices (adjust if Colab images change).
JAX_CPU_PKG = "jax"
JAX_GPU_PKG = "jax[cuda12]"  # older JAX releases named this extra jax[cuda12_pip]
JAX_GPU_WHL_INDEX = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
JAX_TPU_PKG = "jax[tpu]"
JAX_TPU_WHL_INDEX = "https://storage.googleapis.com/jax-releases/libtpu_releases.html"

if accelerator == "tpu":
    _pip(JAX_TPU_PKG, "-f", JAX_TPU_WHL_INDEX)
elif accelerator == "gpu":
    _pip(JAX_GPU_PKG, "-f", JAX_GPU_WHL_INDEX)
else:
    _pip(JAX_CPU_PKG)

print("JAX install finished.")
print("If you see import/backend issues below, use Runtime → Restart runtime, then rerun from the top.")

## Install koki2 (editable)

In [None]:
import subprocess
import sys

# Editable install of the repo; --no-deps leaves the JAX install from the previous cell untouched.
proc = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q", "-e", ".", "--no-deps"],
    text=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
)
if proc.returncode != 0:
    print(proc.stdout)
proc.check_returncode()
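The grid cell shells out to the `koki2` command, so it is worth confirming that the console script landed on PATH before launching a long run. A minimal sketch using only the standard library (`cli_available` is a hypothetical helper name):

```python
import shutil
import sysconfig


def cli_available(name: str = "koki2") -> bool:
    """True if an executable with this name is found on PATH."""
    return shutil.which(name) is not None


if not cli_available():
    # pip installs console scripts into the environment's scripts directory.
    print("koki2 not on PATH; expected under:", sysconfig.get_path("scripts"))
```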

## Sanity check: JAX sees the accelerator

In [None]:
import jax

print("jax:", jax.__version__)
print("backend:", jax.default_backend())
print("devices:")
for d in jax.devices():
    print(" -", d)
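A silent fallback to CPU would make the grid far slower without any error. A sketch of a fail-fast check, assuming the `accelerator` variable from the install cell is still defined (it is lost on a runtime restart) and relying on `jax.default_backend()` returning `"cpu"`, `"gpu"`, or `"tpu"`; `backend_mismatch` is a hypothetical name:

```python
from typing import Optional


def backend_mismatch(expected: str, actual: str) -> Optional[str]:
    """Return an error message if JAX's backend differs from the detected accelerator."""
    if expected == actual:
        return None
    return (f"expected JAX backend {expected!r}, got {actual!r}: "
            "restart the runtime and rerun the install cell")


# Usage in the notebook (needs `accelerator` from the install cell and `import jax`):
# msg = backend_mismatch(accelerator, jax.default_backend())
# if msg:
#     raise RuntimeError(msg)
```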

## Run: 12-condition grid (ES30 × seeds 0..4; held-out eval)

Defaults match the protocol in `WORK.md`:

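- Grid: `steps ∈ {128, 256}` × `bad_source_respawn_delay ∈ {0, 1, 4}` × `success_bonus ∈ {0, 50}` (2 × 3 × 2 = 12 conditions)
- ES per condition: 5 runs (seeds 0..4), each with `generations=30`, `pop_size=64`, `train_episodes=4` (JIT ES)
- Eval: 512 episodes per evaluation on eval seeds {424242, 0}; each run's best genome is compared against greedy and random baselines under the same environment settings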
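These settings fix the compute budget up front. A quick count, with constants copied from the bullets (`ES_SEEDS` and `BASELINES` are illustrative names; the grid cell itself keeps the seeds as the string `SEEDS`). `itertools.product` yields conditions in the same order as the cell's nested loops (steps outermost, success bonus innermost).

```python
from itertools import product

GRID_STEPS = [128, 256]
GRID_BAD_RESP = [0, 1, 4]
GRID_SUCC = [0.0, 50.0]
ES_SEEDS = [0, 1, 2, 3, 4]
EVAL_SEEDS = [424242, 0]
BASELINES = ["greedy", "random"]

conditions = list(product(GRID_STEPS, GRID_BAD_RESP, GRID_SUCC))
es_runs = len(conditions) * len(ES_SEEDS)
# Per condition and eval seed: each baseline once, plus each run's best genome.
evals = len(conditions) * len(EVAL_SEEDS) * (len(BASELINES) + len(ES_SEEDS))

print(len(conditions), es_runs, evals)  # 12 60 168
```

At 512 episodes each, that is 168 × 512 = 86,016 evaluation episodes in total.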


In [None]:
import json
import re
import subprocess
import time
from collections import defaultdict
from pathlib import Path
from statistics import mean, stdev

OUT_ROOT = Path("runs/stage1_grid_l10_effectsize")  # or e.g. Path('/content/drive/MyDrive/koki2_runs/stage1_grid_l10_effectsize')
SCAN_DIR = Path("runs/stage1_scans")
OUT_ROOT.mkdir(parents=True, exist_ok=True)
SCAN_DIR.mkdir(parents=True, exist_ok=True)

SEEDS = "0,1,2,3,4"
TRAIN_GENERATIONS = 30
TRAIN_POP = 64
TRAIN_EP = 4
LOG_EVERY = 10

EVAL_EP = 512
EVAL_SEEDS = [424242, 0]

GRID_STEPS = [128, 256]
GRID_BAD_RESP = [0, 1, 4]
GRID_SUCC = [0.0, 50.0]

# Env base: L1.0 deplete/respawn + L0.2 harmful sources.
# Keep bad_source_deplete_p=1.0 to avoid the 'camp on non-depleting hazard (gradient=0)' confound.
ENV_BASE = [
    "--deplete-sources",
    "--respawn-delay", "4",
    "--num-sources", "4",
    "--num-bad-sources", "2",
    "--bad-source-integrity-loss", "0.25",
    "--bad-source-deplete-p", "1.0",
]

pat_out_dir = re.compile(r"out_dir=(\S+)")
pat_kv = re.compile(r"(\w+)=([-0-9.]+)")
pat_seed = re.compile(r"_seed(\d+)$")


def _run(cmd: list[str]) -> str:
    proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if proc.returncode != 0:
        print(proc.stdout)
    proc.check_returncode()
    return proc.stdout


stamp = time.strftime('%Y-%m-%d_%H%M%S')
log_path = SCAN_DIR / f"grid_l10_effectsize_{stamp}.txt"
jsonl_path = SCAN_DIR / f"grid_l10_effectsize_{stamp}.jsonl"

rows = []
t0 = time.time()

with log_path.open('w', encoding='utf-8') as flog:
    cond_i = 0
    for steps in GRID_STEPS:
        for bad_resp in GRID_BAD_RESP:
            for succ in GRID_SUCC:
                cond_i += 1
                tag = (
                    f"grid_l10_steps{steps}_badresp{bad_resp}_succ{int(succ)}_"
                    f"g{TRAIN_GENERATIONS}_p{TRAIN_POP}_ep{TRAIN_EP}"
                )
                print(f"\n=== CONDITION {cond_i}/{len(GRID_STEPS) * len(GRID_BAD_RESP) * len(GRID_SUCC)}: {tag} ===")

                cmd_train = [
                    'koki2', 'batch-evo-l0',
                    '--seeds', SEEDS,
                    '--out-root', str(OUT_ROOT),
                    '--tag', tag,
                    '--generations', str(TRAIN_GENERATIONS),
                    '--pop-size', str(TRAIN_POP),
                    '--episodes', str(TRAIN_EP),
                    '--steps', str(steps),
                    '--bad-source-respawn-delay', str(bad_resp),
                    '--success-bonus', str(succ),
                    '--jit-es',
                    '--log-every', str(LOG_EVERY),
                ] + ENV_BASE
                out = _run(cmd_train)
                flog.write(out + '\n')
                flog.flush()

                out_dirs = pat_out_dir.findall(out)
                n_seeds = len(SEEDS.split(','))
                if len(out_dirs) != n_seeds:
                    raise RuntimeError(f"expected {n_seeds} out_dirs, got {len(out_dirs)}")

                for eval_seed in EVAL_SEEDS:
                    for policy in ['greedy', 'random']:
                        out_b = _run([
                            'koki2', 'baseline-l0',
                            '--seed', str(eval_seed),
                            '--policy', policy,
                            '--episodes', str(EVAL_EP),
                            '--steps', str(steps),
                            '--bad-source-respawn-delay', str(bad_resp),
                            '--success-bonus', str(succ),
                        ] + ENV_BASE).strip()
                        flog.write(out_b + '\n')
                        kv = dict(pat_kv.findall(out_b))
                        rows.append({
                            'steps': steps,
                            'bad_source_respawn_delay': bad_resp,
                            'success_bonus': succ,
                            'tag': tag,
                            'kind': f'baseline_{policy}',
                            'eval_seed': eval_seed,
                            'run_seed': None,
                            'run_dir': None,
                            'mean_fitness': float(kv['mean_fitness']),
                            'success_rate': float(kv['success_rate']),
                            'mean_t_alive': float(kv['mean_t_alive']),
                            'mean_energy_gained': float(kv['mean_energy_gained']),
                            'mean_bad_arrivals': float(kv['mean_bad_arrivals']),
                            'mean_integrity_min': float(kv['mean_integrity_min']),
                        })

                    for d in out_dirs:
                        out_e = _run([
                            'koki2', 'eval-run',
                            '--run-dir', d,
                            '--episodes', str(EVAL_EP),
                            '--seed', str(eval_seed),
                            '--baseline-policy', 'none',
                        ])
                        flog.write(out_e + '\n')
                        best_lines = [ln for ln in out_e.splitlines() if ln.startswith('best_genome')]
                        if not best_lines:
                            raise RuntimeError(f"no 'best_genome' line in eval-run output for {d}")
                        kv = dict(pat_kv.findall(best_lines[0]))
                        m = pat_seed.search(d)
                        run_seed = int(m.group(1)) if m else None
                        rows.append({
                            'steps': steps,
                            'bad_source_respawn_delay': bad_resp,
                            'success_bonus': succ,
                            'tag': tag,
                            'kind': 'best_genome',
                            'eval_seed': eval_seed,
                            'run_seed': run_seed,
                            'run_dir': d,
                            'mean_fitness': float(kv['mean_fitness']),
                            'success_rate': float(kv['success_rate']),
                            'mean_t_alive': float(kv['mean_t_alive']),
                            'mean_energy_gained': float(kv['mean_energy_gained']),
                            'mean_bad_arrivals': float(kv['mean_bad_arrivals']),
                            'mean_integrity_min': float(kv['mean_integrity_min']),
                        })

                flog.flush()

jsonl_path.write_text('\n'.join(json.dumps(r) for r in rows) + '\n', encoding='utf-8')
print('\nWrote log:', log_path)
print('Wrote results:', jsonl_path)
print('Total wall time (min):', (time.time() - t0) / 60.0)

# Quick summary (mean ± stdev across ES seeds)
best = defaultdict(list)
base = {}
for r in rows:
    k = (r['steps'], r['bad_source_respawn_delay'], r['success_bonus'], r['eval_seed'])
    if r['kind'].startswith('baseline_'):
        base[(k, r['kind'])] = r
    else:
        best[k].append(r)

for k, rs in sorted(best.items()):
    steps, bad_resp, succ, eval_seed = k
    g = base[(k, 'baseline_greedy')]
    rnd = base[(k, 'baseline_random')]
    print(f"\ncond steps={steps} bad_resp={bad_resp} succ={succ} eval_seed={eval_seed}")
    for who, vals in [('best', rs), ('greedy', [g]), ('random', [rnd])]:
        mf = [v['mean_fitness'] for v in vals]
        tb = [v['mean_t_alive'] for v in vals]
        bb = [v['mean_bad_arrivals'] for v in vals]
        im = [v['mean_integrity_min'] for v in vals]
        if len(vals) > 1:
            print(f"  {who}: mean_fitness={mean(mf):.2f}±{stdev(mf):.2f} t_alive={mean(tb):.1f}±{stdev(tb):.1f} bad={mean(bb):.2f}±{stdev(bb):.2f} imin={mean(im):.3f}±{stdev(im):.3f}")
        else:
            print(f"  {who}: mean_fitness={mf[0]:.2f} t_alive={tb[0]:.1f} bad={bb[0]:.2f} imin={im[0]:.3f}")
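The summary above prints raw means; one simple way to line conditions up is the gap between the best genomes and the greedy baseline. A sketch, assuming the JSONL schema written by the grid cell (`kind`, `mean_fitness`, and the condition keys). The "effect" here is a raw difference in mean fitness, not a standardized effect size such as Cohen's d; if `WORK.md` specifies a different statistic, use that instead. If the success bonus is added to fitness (an assumption), gaps are most comparable within a single `success_bonus` level. Both helper names are hypothetical.

```python
import json
from collections import defaultdict
from pathlib import Path
from statistics import mean


def fitness_gap_vs_greedy(rows: list[dict]) -> dict[tuple, float]:
    """Mean best-genome fitness minus greedy-baseline fitness, per condition and eval seed."""
    best = defaultdict(list)
    greedy = {}
    for r in rows:
        key = (r["steps"], r["bad_source_respawn_delay"], r["success_bonus"], r["eval_seed"])
        if r["kind"] == "best_genome":
            best[key].append(r["mean_fitness"])
        elif r["kind"] == "baseline_greedy":
            greedy[key] = r["mean_fitness"]
    # Skip any condition that is missing its greedy baseline row.
    return {k: mean(v) - greedy[k] for k, v in best.items() if k in greedy}


def load_rows(path: Path) -> list[dict]:
    """Read the one-JSON-object-per-line results file."""
    return [json.loads(ln) for ln in path.read_text(encoding="utf-8").splitlines() if ln.strip()]


# Usage after the grid cell: fitness_gap_vs_greedy(load_rows(jsonl_path))
```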
