# koki2 — Stage 2 modulator grid (L1.0 deplete/respawn + L1.1 intermittent gradient)

This notebook runs a **Stage 2** comparison to test whether consequence-aligned neuromodulation (drive/event) can yield useful within-life plasticity without the hazard-contact regression observed with spike-modulated plasticity.

Protocol (aligned with `PLAN.md`/`WORK.md`):
- Environment: L1.0 **deplete/respawn** + L0.2 **harmful sources** + L1.1 **intermittent gradient** (`--grad-dropout-p 0.5`).
- Stronger effect size: `--steps 256`, hazard persistence pressure `--bad-source-respawn-delay 0`.
- Grid: no-plastic vs plastic with `--modulator-kind {spike,drive,event}` and `--plast-eta ∈ {0.01, 0.05}` (`--plast-lambda 0.9`).

Repo hygiene tip: if you opened this from GitHub, use **File → Save a copy in Drive** (instead of saving back to GitHub) to avoid committing execution outputs/metadata.

1. **Runtime → Change runtime type** → select **GPU** (L4/T4/A100, etc.) or **TPU** (v5/v6e, etc.).
2. Run the cells top-to-bottom.


## (Optional) Mount Google Drive

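If you want training run directories to persist across sessions, uncomment the cell below to mount Drive, then set `OUT_ROOT` (defined in the final "Run" cell) to a Drive path. Note that the scan log and JSONL summary are always written under the local `runs/stage2_scans/`, regardless of `OUT_ROOT`.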


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')


## Clone the repo


In [None]:
REPO_URL = "https://github.com/Krisztiaan/koki2.git"  # override if using a fork
REPO_DIR = "koki2"  # clone target and later working directory, relative to the current directory

import os
import pathlib
import subprocess


def _run(cmd: list[str]) -> None:
    """Run *cmd* quietly; if it fails, show its output and raise.

    Args:
        cmd: Argument vector passed to ``subprocess.run`` (no shell).

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    result = subprocess.run(
        cmd,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    if result.returncode == 0:
        return
    # stderr was merged into stdout, so this prints everything the command emitted.
    print(result.stdout)
    result.check_returncode()


# Clone on first use, then make the checkout the working directory.
# The check below keeps the cell idempotent: after the chdir, a re-run would
# otherwise resolve REPO_DIR relative to the checkout itself, clone a nested
# copy, and descend into it.
repo_path = pathlib.Path(REPO_DIR)
already_inside_checkout = (
    pathlib.Path.cwd().name == repo_path.name and pathlib.Path(".git").exists()
)
if not already_inside_checkout:
    if not repo_path.exists():
        _run(["git", "clone", "--depth", "1", "--quiet", REPO_URL, REPO_DIR])
    os.chdir(REPO_DIR)

print("Working dir:", pathlib.Path.cwd())


## Install JAX for your accelerator

This picks **TPU** if a TPU runtime is detected, otherwise picks **GPU** if `nvidia-smi` is available, else falls back to **CPU**.

If you manually change the runtime accelerator type after this, re-run this cell and restart the runtime.


In [None]:
import os
import subprocess
import sys


# Turn on pip's "disable version check" environment switch unless the user already set it;
# the pip subprocesses launched below inherit this process's environment.
os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")


def _run(cmd: list[str], *, check: bool = True) -> None:
    """Run *cmd* with its output captured, optionally failing loudly.

    Args:
        cmd: Argument vector passed to ``subprocess.run`` (no shell).
        check: When true, a non-zero exit prints the captured output and raises.
            When false, failures are ignored silently (best-effort commands).

    Raises:
        subprocess.CalledProcessError: If ``check`` is true and the command fails.
    """
    completed = subprocess.run(
        cmd,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    failed = completed.returncode != 0
    if failed and check:
        # stderr is folded into stdout, so this is the command's full output.
        print(completed.stdout)
        completed.check_returncode()


def _pip(*args: str) -> None:
    """Quietly ``pip install -U`` the given arguments with the current interpreter.

    Args:
        *args: Requirement specifiers and/or extra pip options, appended verbatim.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    install_cmd = [sys.executable, "-m", "pip", "install", "-q", "-U"]
    install_cmd.extend(args)
    _run(install_cmd, check=True)


def _pip_uninstall(*args: str) -> None:
    """Best-effort, non-interactive ``pip uninstall`` of the given packages.

    Failures are ignored (``check=False``), so this does not raise on a
    non-zero pip exit status.

    Args:
        *args: Package names (or extra pip options), appended verbatim.
    """
    uninstall_cmd = [sys.executable, "-m", "pip", "uninstall", "-q", "-y"]
    uninstall_cmd.extend(args)
    _run(uninstall_cmd, check=False)


def _is_tpu_runtime() -> bool:
    """Return True when any of the TPU marker environment variables is set."""
    tpu_env_markers = ("COLAB_TPU_ADDR", "TPU_NAME", "XRT_TPU_CONFIG")
    for marker in tpu_env_markers:
        if marker in os.environ:
            return True
    return False


def _has_nvidia_smi() -> bool:
    """Report whether a login ``bash`` shell can resolve ``nvidia-smi``.

    The lookup command's exit status is the only signal used; the tool itself
    is never executed.
    """
    lookup = subprocess.run(
        ["bash", "-lc", "command -v nvidia-smi"],
        check=False,
    )
    found = lookup.returncode == 0
    return found


# Backend choice: TPU environment markers win, then a resolvable `nvidia-smi`, else CPU.
accelerator = "tpu" if _is_tpu_runtime() else ("gpu" if _has_nvidia_smi() else "cpu")
print("Detected accelerator:", accelerator)

# Colab-friendly default: avoid preallocating most GPU memory up-front.
if accelerator == "gpu":
    os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

# Keep the environment clean if you rerun this cell after switching runtime type.
_pip_uninstall("jax", "jaxlib")

_pip("pip")

# JAX wheels/indices (adjust if Colab images change).
JAX_CPU_PKG = "jax"
# "cuda12" is the extra named in the JAX installation guide; "cuda12_pip" is its
# older spelling. If this fails on an old image, try: "jax[cuda12_pip]".
JAX_GPU_PKG = "jax[cuda12]"
JAX_GPU_WHL_INDEX = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
JAX_TPU_PKG = "jax[tpu]"
JAX_TPU_WHL_INDEX = "https://storage.googleapis.com/jax-releases/libtpu_releases.html"

# The -f pages are passed to pip as additional find-links locations.
if accelerator == "tpu":
    _pip(JAX_TPU_PKG, "-f", JAX_TPU_WHL_INDEX)
elif accelerator == "gpu":
    _pip(JAX_GPU_PKG, "-f", JAX_GPU_WHL_INDEX)
else:
    _pip(JAX_CPU_PKG)

print("JAX install finished.")
print("If you see import/backend issues below, use Runtime → Restart runtime, then rerun from the top.")


## Install koki2 (editable)


In [None]:
import subprocess
import sys

# Editable install of the current directory (the cloned repo) without resolving
# dependencies (--no-deps); presumably so the JAX build installed earlier is
# left untouched — confirm against the project's packaging.
install_cmd = [sys.executable, "-m", "pip", "install", "-q", "-e", ".", "--no-deps"]
proc = subprocess.run(install_cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if proc.returncode != 0:
    # Only surface pip's (merged stdout+stderr) output when something went wrong.
    print(proc.stdout)
    proc.check_returncode()


## Sanity check: JAX sees the accelerator


In [None]:
import jax

# Report the installed JAX version, the backend JAX selected, and one line per device.
print("jax:", jax.__version__)
print("backend:", jax.default_backend())
print("devices:", *(f" - {device}" for device in jax.devices()), sep="\n")


## Run: Stage 2 hazard persistence + eta=0 controls (ES + held-out eval)

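This runs training for each condition, then evaluates each saved `best_genome.npz` on held-out episodes — once as trained and, for plastic conditions, once more with `--override-plast-eta 0.0` (learning disabled) as a causality probe. Greedy and random baselines are evaluated on the same held-out seeds. A raw text log of all command output and a JSONL summary (one row per evaluation) are written under `runs/stage2_scans/`.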


In [None]:
import json
import re
import subprocess
import time
from collections import defaultdict
from pathlib import Path
from statistics import mean, stdev

# --- Where training runs are written, and which ES seeds to train -----------
OUT_ROOT = "runs/stage2_hazp_l10l11"  # or e.g. "/content/drive/MyDrive/koki2_runs/stage2_hazp_l10l11"
SEEDS = "0,1,2,3,4"

# --- Training budget ---------------------------------------------------------
GENERATIONS = 30
POP_SIZE = 64
TRAIN_EPISODES = 4
STEPS = 256
LOG_EVERY = 10

# --- Held-out evaluation -----------------------------------------------------
EVAL_EP = 512
EVAL_SEEDS = [424242, 0]

# Env: L1.0 deplete/respawn + harmful sources + intermittent gradient.
# This notebook focuses on **hazard persistence** via bad_source_deplete_p < 1.0,
# which should increase the value of within-episode consequence learning.
_ENV_OPTIONS = [
    ('--respawn-delay', '4'),
    ('--bad-source-respawn-delay', '0'),
    ('--bad-source-deplete-p', '0.25'),
    ('--num-sources', '4'),
    ('--num-bad-sources', '2'),
    ('--bad-source-integrity-loss', '0.25'),
    ('--grad-dropout-p', '0.5'),
    ('--success-bonus', '50'),
]
# Flat argv fragment: the boolean flag first, then each option followed by its value.
ENV_BASE = ['--deplete-sources'] + [token for option in _ENV_OPTIONS for token in option]

PLAST_LAMBDA = 0.9
ETA_GRID = [0.01, 0.05]


def _plast_flags(kind: str, eta: float) -> list[str]:
    """Return the CLI flags enabling plasticity with modulator *kind* at learning rate *eta*."""
    return [
        '--plast-enabled',
        '--plast-eta', str(eta),
        '--plast-lambda', str(PLAST_LAMBDA),
        '--modulator-kind', kind,
    ]


# Each condition is (cond_id, variant, eta, extra CLI args); A0 is the no-plasticity control.
CONDITIONS = [('A0', 'noplast', None, [])]
for eta in ETA_GRID:
    CONDITIONS.extend([
        ('A1', 'spike', eta, _plast_flags('spike', eta)),
        ('A2', 'drive', eta, _plast_flags('drive', eta) + ['--mod-drive-scale', '1.0']),
        ('A3', 'event', eta, _plast_flags('event', eta)),
    ])

# Parsers for CLI output: reported run directories, the seed suffix of a run
# directory, and numeric key=value pairs on a result line.
pat_out_dir = re.compile(r'out_dir=(\S+)')
pat_seed = re.compile(r'_seed(\d+)$')
pat_kv = re.compile(r'(\w+)=([-+0-9.eE]+)')

def _run(cmd: list[str]) -> str:
    """Run *cmd* and return its combined stdout+stderr text.

    Args:
        cmd: Argument vector passed to ``subprocess.run`` (no shell).

    Returns:
        Everything the command wrote to stdout and stderr, as one string.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status; its output is printed first.
    """
    completed = subprocess.run(
        cmd,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    if completed.returncode != 0:
        print(completed.stdout)
        completed.check_returncode()
    return completed.stdout

# Metrics every parsed result line must provide (a missing one raises KeyError).
_REQUIRED_METRICS = (
    'mean_fitness',
    'success_rate',
    'mean_t_alive',
    'mean_energy_gained',
    'mean_bad_arrivals',
    'mean_integrity_min',
)
# Metrics that may be absent from a result line; they default to 0.0.
_OPTIONAL_METRICS = (
    'mean_abs_dw_mean',
    'mean_abs_modulator_mean',
    'mean_abs_dw_on_event',
    'event_step_frac',
)


def _metrics_row(kv, *, tag, cond_id, variant, eta, kind, override_plast_eta, eval_seed, run_seed, run_dir):
    """Build one result row: identifying columns first, then the parsed metrics as floats.

    Args:
        kv: Mapping of metric name -> numeric string, as parsed by ``pat_kv``.
        Remaining keyword arguments are stored verbatim as identifying columns.

    Raises:
        KeyError: If one of the required metrics is missing from *kv*.
    """
    row = {
        'tag': tag,
        'cond': cond_id,
        'variant': variant,
        'plast_eta': eta,
        'kind': kind,
        'override_plast_eta': override_plast_eta,
        'eval_seed': eval_seed,
        'run_seed': run_seed,
        'run_dir': run_dir,
    }
    for name in _REQUIRED_METRICS:
        row[name] = float(kv[name])
    for name in _OPTIONAL_METRICS:
        row[name] = float(kv.get(name, '0.0'))
    return row


def _last_line_kv(output, what):
    """Parse key=value pairs from the last non-blank-trimmed line of *output*.

    Raises:
        RuntimeError: If *output* is empty after stripping.
    """
    lines = output.strip().splitlines()
    if not lines:
        raise RuntimeError('empty output for ' + what)
    return dict(pat_kv.findall(lines[-1]))


def _best_genome_kv(output, what):
    """Parse key=value pairs from the first line of *output* starting with 'best_genome'.

    Raises:
        RuntimeError: If no such line exists.
    """
    line = next((ln for ln in output.splitlines() if ln.startswith('best_genome')), None)
    if line is None:
        raise RuntimeError('no best_genome line in eval-run output for ' + what)
    return dict(pat_kv.findall(line))


def _record_row(fout, all_rows, row):
    """Append *row* to *all_rows* and persist it to *fout* as one JSON line right away."""
    all_rows.append(row)
    fout.write(json.dumps(row) + '\n')
    fout.flush()


stamp = time.strftime('%Y-%m-%d_%H%M%S', time.gmtime())
scan_stem = f'{stamp}_stage2_hazp_steps{STEPS}_g{GENERATIONS}_p{POP_SIZE}_ep{TRAIN_EPISODES}_eval{EVAL_EP}'
log_path = Path('runs/stage2_scans') / f'{scan_stem}.txt'
jsonl_path = Path('runs/stage2_scans') / f'{scan_stem}.jsonl'
log_path.parent.mkdir(parents=True, exist_ok=True)

rows = []
# (eval_seed, policy) -> parsed baseline metrics. The baseline command does not
# depend on the training condition, so each baseline is evaluated only once and
# its metrics are reused for every condition's rows.
baseline_kv = {}
t0 = time.time()

# Rows are streamed to the JSONL file as they are produced, so a failure late in
# the (long) scan does not discard the results gathered so far.
with log_path.open('w', encoding='utf-8') as flog, jsonl_path.open('w', encoding='utf-8') as fjsonl:
    for cond_id, variant, eta, extra in CONDITIONS:
        tag = f'stage2_hazp_{cond_id}_{variant}' + ('' if eta is None else f'_eta{eta:.2f}')
        print('\n=== TRAIN', tag, '===')
        out = _run([
            'koki2', 'batch-evo-l0',
            '--seeds', SEEDS,
            '--out-root', OUT_ROOT,
            '--tag', tag,
            '--generations', str(GENERATIONS),
            '--pop-size', str(POP_SIZE),
            '--episodes', str(TRAIN_EPISODES),
            '--steps', str(STEPS),
            '--jit-es',
            '--log-every', str(LOG_EVERY),
        ] + ENV_BASE + extra)
        flog.write(out + '\n')
        flog.flush()

        # The training command reports one `out_dir=...` per ES seed.
        out_dirs = pat_out_dir.findall(out)
        if not out_dirs:
            raise RuntimeError('no out_dirs parsed for ' + tag)

        # Only plastic conditions get the "learning disabled" re-evaluation.
        probe_eta0 = eta is not None and float(eta) > 0.0 and variant != 'noplast'

        for eval_seed in EVAL_SEEDS:
            # Baselines once per eval seed (same env across conditions).
            for policy in ['greedy', 'random']:
                cache_key = (eval_seed, policy)
                if cache_key not in baseline_kv:
                    # ENV_BASE already carries --bad-source-respawn-delay and --success-bonus.
                    out_b = _run([
                        'koki2', 'baseline-l0',
                        '--seed', str(eval_seed),
                        '--policy', policy,
                        '--episodes', str(EVAL_EP),
                        '--steps', str(STEPS),
                    ] + ENV_BASE)
                    flog.write(out_b + '\n')
                    flog.flush()
                    baseline_kv[cache_key] = _last_line_kv(out_b, f'baseline {policy} seed {eval_seed}')
                # A baseline row is still emitted per condition so each summary group is self-contained.
                _record_row(fjsonl, rows, _metrics_row(
                    baseline_kv[cache_key],
                    tag=tag,
                    cond_id=cond_id,
                    variant=variant,
                    eta=eta,
                    kind=f'baseline_{policy}',
                    override_plast_eta=None,
                    eval_seed=eval_seed,
                    run_seed=None,
                    run_dir=None,
                ))

            for d in out_dirs:
                eval_cmd = [
                    'koki2', 'eval-run',
                    '--run-dir', d,
                    '--episodes', str(EVAL_EP),
                    '--seed', str(eval_seed),
                    '--baseline-policy', 'none',
                ]
                out_e = _run(eval_cmd)
                flog.write(out_e + '\n')
                flog.flush()
                kv = _best_genome_kv(out_e, d)
                # The ES seed is recovered from the run directory's `_seed<N>` suffix, if present.
                m = pat_seed.search(d)
                run_seed = int(m.group(1)) if m else None
                _record_row(fjsonl, rows, _metrics_row(
                    kv,
                    tag=tag,
                    cond_id=cond_id,
                    variant=variant,
                    eta=eta,
                    kind='best_genome',
                    override_plast_eta=None,
                    eval_seed=eval_seed,
                    run_seed=run_seed,
                    run_dir=d,
                ))

                # Causality probe: evaluate the same genome with learning disabled.
                if probe_eta0:
                    out_e0 = _run(eval_cmd + ['--override-plast-eta', '0.0'])
                    flog.write(out_e0 + '\n')
                    flog.flush()
                    kv0 = _best_genome_kv(out_e0, d + ' (eta=0 probe)')
                    _record_row(fjsonl, rows, _metrics_row(
                        kv0,
                        tag=tag,
                        cond_id=cond_id,
                        variant=variant,
                        eta=eta,
                        kind='best_genome_eta0',
                        override_plast_eta=0.0,
                        eval_seed=eval_seed,
                        run_seed=run_seed,
                        run_dir=d,
                    ))

print('\nWrote log:', log_path)
print('Wrote results:', jsonl_path)
print('Total wall time (min):', (time.time() - t0) / 60.0)

# Quick summary (mean ± stdev across ES seeds)
def _fmt_stat(values, precision):
    """Format *values* as 'mean±stdev' (two or more values) or as the single value.

    ``statistics.stdev`` needs at least two data points, hence the split.
    *values* must be non-empty.
    """
    if len(values) > 1:
        return f"{mean(values):.{precision}f}±{stdev(values):.{precision}f}"
    return f"{values[0]:.{precision}f}"


# Group rows by (cond, variant, eta, eval_seed): one baseline row per policy,
# and a list of per-ES-seed rows for each genome evaluation kind.
best = defaultdict(list)
base = {}
for r in rows:
    k = (r['cond'], r['variant'], r['plast_eta'], r['eval_seed'])
    if r['kind'].startswith('baseline_'):
        base[(k, r['kind'])] = r
    else:
        best[(k, r['kind'])].append(r)

keys = sorted({k for (k, _kind) in best})
for k in keys:
    cond, variant, eta, eval_seed = k
    g = base[(k, 'baseline_greedy')]
    rnd = base[(k, 'baseline_random')]
    rs = best.get((k, 'best_genome'), [])
    rs0 = best.get((k, 'best_genome_eta0'), [])
    print(f"\ncond={cond} variant={variant} eta={eta} eval_seed={eval_seed}")
    for who, vals in [('best', rs), ('best_eta0', rs0), ('greedy', [g]), ('random', [rnd])]:
        if not vals:
            # The no-plasticity condition has no eta=0 probe rows; indexing the
            # empty list here used to abort the summary with an IndexError.
            continue
        mf = [v['mean_fitness'] for v in vals]
        bb = [v['mean_bad_arrivals'] for v in vals]
        im = [v['mean_integrity_min'] for v in vals]
        dw = [v['mean_abs_dw_mean'] for v in vals]
        dwe = [v.get('mean_abs_dw_on_event', 0.0) for v in vals]
        evf = [v.get('event_step_frac', 0.0) for v in vals]
        mm = [v.get('mean_abs_modulator_mean', 0.0) for v in vals]
        print(
            f"  {who}: fitness={_fmt_stat(mf, 4)} bad={_fmt_stat(bb, 4)} imin={_fmt_stat(im, 4)}"
            f" dw={_fmt_stat(dw, 6)} dwe={_fmt_stat(dwe, 6)} evf={_fmt_stat(evf, 4)} mod={_fmt_stat(mm, 6)}"
        )
