# EvoJAX NEAT (Colab-Friendly)

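This notebook clones the NEAT-Sakana repository (skipping the clone if it is already present), runs both modes (`direct_vs_builtin` and `selfplay_then_builtin`), and then displays the resulting plots, GIFs, and reports. Both modes are driven by the script: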
- `scripts/run_evojax_neat_both.py`


In [None]:
import os
import sys
import subprocess
from pathlib import Path

REPO_URL = 'https://github.com/aryangoyal7/NEAT-Sakana.git'
# Colab's default working directory is /content; elsewhere fall back to the CWD.
WORKDIR = Path('/content') if Path('/content').exists() else Path.cwd()
REPO_DIR = WORKDIR / 'NEAT-Sakana'

if REPO_DIR.exists():
    # An existing checkout is reused as-is (it is not pulled); delete the
    # folder to force a fresh clone.
    print('Reusing existing checkout at', REPO_DIR)
else:
    subprocess.run(['git', 'clone', REPO_URL, str(REPO_DIR)], check=True)

SCRIPT_PATH = REPO_DIR / 'scripts' / 'run_evojax_neat_both.py'
if not SCRIPT_PATH.is_file():
    # Fail here with a clear message rather than with an opaque
    # CalledProcessError when the run cell tries to execute the script.
    raise FileNotFoundError(f'Expected driver script not found: {SCRIPT_PATH}')

# All run outputs are collected under <repo>/artifacts.
OUT_ROOT = REPO_DIR / 'artifacts'
OUT_ROOT.mkdir(parents=True, exist_ok=True)

print('Kernel python:', sys.executable)
print('REPO_DIR    :', REPO_DIR)
print('SCRIPT_PATH :', SCRIPT_PATH)
print('OUT_ROOT    :', OUT_ROOT)

In [None]:
# Install into the exact interpreter backing this kernel (`sys.executable`),
# which is also the interpreter the run cell uses to launch the script.
!{sys.executable} -m pip install -U pip
# Dependencies pinned by the repository itself.
!{sys.executable} -m pip install -r "{REPO_DIR / 'requirements.txt'}"
# EvoJAX pinned to 0.2.17 plus additional packages installed alongside it.
!{sys.executable} -m pip install evojax==0.2.17 flax optax chex orbax-checkpoint tensorstore rich absl-py cma opencv-python-headless
# Editable install of the cloned repo so local source edits take effect without reinstalling.
!{sys.executable} -m pip install -e "{REPO_DIR}"

In [None]:
# Run configuration; each value is forwarded to the driver script as the
# CLI flag noted on its line.
GENERATIONS: int = 8          # --generations
POP_SIZE: int = 24            # --pop-size
MAX_STEPS: int = 800          # --max-steps
EPISODES_DIRECT: int = 2      # --episodes-direct
EPISODES_SELFPLAY: int = 1    # --episodes-selfplay

In [None]:
import shlex

# Writable Matplotlib config/cache directory for the child process.
MPL_CONFIG_DIR = Path('/tmp/mplconfig')
MPL_CONFIG_DIR.mkdir(parents=True, exist_ok=True)

env = os.environ.copy()
# Prepend the repo's src/ directory. Join only non-empty parts with the
# platform separator so an unset PYTHONPATH does not leave a trailing empty
# entry (an empty entry can resolve to the current working directory).
env['PYTHONPATH'] = os.pathsep.join(
    part for part in (str(REPO_DIR / 'src'), env.get('PYTHONPATH', '')) if part
)
env['MPLCONFIGDIR'] = str(MPL_CONFIG_DIR)

cmd = [
    sys.executable, str(SCRIPT_PATH),
    '--repo-dir', str(REPO_DIR),
    '--generations', str(GENERATIONS),
    '--pop-size', str(POP_SIZE),
    '--max-steps', str(MAX_STEPS),
    '--episodes-direct', str(EPISODES_DIRECT),
    '--episodes-selfplay', str(EPISODES_SELFPLAY),
]
# shlex.join quotes arguments (e.g. paths with spaces) so the echoed command
# can be copy-pasted into a shell verbatim.
print('Running:', shlex.join(cmd))
# check=True turns a non-zero exit of the script into a CalledProcessError.
subprocess.run(cmd, cwd=str(REPO_DIR), env=env, check=True)

In [None]:
from IPython.display import display, Markdown, Image

def latest_dir(prefix: str, root: "Path | None" = None) -> Path:
    """Return the most recently modified artifact directory named ``<prefix>_*``.

    Args:
        prefix: Run-name prefix, matched with the glob pattern ``<prefix>_*``.
        root: Directory to search. Defaults to ``OUT_ROOT`` when omitted.

    Returns:
        The matching directory with the newest modification time. Ties keep
        the first match in glob order.

    Raises:
        RuntimeError: If no matching *directory* exists under ``root``.
    """
    search_root = OUT_ROOT if root is None else Path(root)
    # Only directories qualify; a stray file matching the pattern (e.g. a log)
    # must never be returned as an artifact folder.
    candidates = [p for p in search_root.glob(f'{prefix}_*') if p.is_dir()]
    if not candidates:
        raise RuntimeError(f'No artifact folder found for {prefix}')
    return max(candidates, key=lambda p: p.stat().st_mtime)

# Newest artifact folder for each of the two run modes.
direct_dir = latest_dir('direct_vs_builtin')
selfplay_dir = latest_dir('selfplay_then_builtin')
print(f'direct_dir  = {direct_dir}\nselfplay_dir= {selfplay_dir}')

def show_img(path: Path, width=700):
    """Render an image file inline, captioned with its path.

    Prints ``Missing: <path>`` instead when the file does not exist.

    Args:
        path: Image file to display.
        width: Display width passed to ``IPython.display.Image``.
    """
    if not path.exists():
        print('Missing:', path)
        return
    display(Markdown(f'`{path}`'))
    display(Image(filename=str(path), width=width))

# Show each run's plots at the default width, then its GIFs at 520 px.
# The self-play run additionally has a champion-vs-runner-up GIF.
_PLOT_NAMES = ('fitness_complexity.png', 'species_sizes.png', 'champion_network.png')
for run_dir, gif_names in (
    (direct_dir, ('champion_vs_builtin.gif',)),
    (selfplay_dir, ('champion_vs_builtin.gif', 'champion_vs_runnerup.gif')),
):
    for plot_name in _PLOT_NAMES:
        show_img(run_dir / 'plots' / plot_name)
    for gif_name in gif_names:
        show_img(run_dir / 'gifs' / gif_name, width=520)

In [None]:
# Print each run's report.md as plain text, framed by separator rules.
_RULE = '=' * 100
for p in (direct_dir / 'report.md', selfplay_dir / 'report.md'):
    print(f'\n{_RULE}\n{p}\n{_RULE}')
    print(p.read_text(encoding='utf-8') if p.exists() else 'Missing report file')