# AC-Solver Iterative Refinement — Colab Runner

**Runtime recommendation: High-RAM CPU**  
The search (V-guided greedy, beam) is CPU-only. GPU only helps training, which takes ~10 min per iteration — not the bottleneck.

**Persistence:** The repo lives on Google Drive. If Colab disconnects, re-run Cells 1–3 (fast) and Cell 5, then run Cell 7 to resume.

## Cell 1 — Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive so that the repo (cloned in Cell 2)
# and everything the runs write to it persist across Colab sessions.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)
print('Drive mounted.')

## Cell 2 — Clone or Update Repo on Drive

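**First time only:** enter your GitHub Personal Access Token when prompted.  
Create one at GitHub → Settings → Developer settings → Personal access tokens → Fine-grained tokens, with read-only **Contents** access to this repository.  
After the first clone, this cell just runs `git pull`. Git stores the token in the repo's `.git/config` on Drive (that is what lets `git pull` authenticate), so don't share that folder.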

In [None]:
import os
import subprocess

REPO_URL = 'https://github.com/Avi161/AC-Solver-Caltech.git'
DRIVE_DIR = '/content/drive/MyDrive/AC-Solver-Caltech'

if not os.path.exists(DRIVE_DIR):
    # First time: clone with a Personal Access Token embedded in the URL.
    # NOTE: git saves this URL (token included) as the `origin` remote in
    # .git/config on Drive; that is what lets later `git pull`s authenticate
    # without prompting. Treat the Drive folder as containing a secret.
    from getpass import getpass
    token = getpass('GitHub Personal Access Token: ').strip()
    # An empty answer falls back to a plain HTTPS clone (public repos only)
    # instead of producing a malformed 'https://@github.com/...' URL.
    auth_url = REPO_URL.replace('https://', f'https://{token}@') if token else REPO_URL
    clone = subprocess.run(['git', 'clone', auth_url, DRIVE_DIR],
                           capture_output=True, text=True)
    if clone.returncode != 0:
        # Deliberately not check=True: its CalledProcessError would print the
        # full command line -- token included -- into the notebook output.
        print(clone.stderr.replace(token, '***') if token else clone.stderr)
        raise subprocess.CalledProcessError(
            clone.returncode, ['git', 'clone', REPO_URL, DRIVE_DIR])
    print('Repo cloned to Drive.')
else:
    # Already cloned: just pull latest changes.
    result = subprocess.run(['git', 'pull'], cwd=DRIVE_DIR, capture_output=True, text=True)
    if result.returncode != 0:
        # A failed pull (no network, conflicting local edits, folder is not a
        # git repo, ...) is surfaced explicitly; work continues on the
        # existing checkout.
        print('WARNING: git pull failed; continuing with the existing checkout.')
        print((result.stdout + result.stderr).strip())
    else:
        print(result.stdout or 'Already up to date.')

os.chdir(DRIVE_DIR)
print(f'Working directory: {os.getcwd()}')

## Cell 3 — Install Dependencies

In [None]:
import subprocess
import sys

# Install the repo's requirements into *this* kernel's interpreter.
# `sys.executable -m pip` guarantees pip targets the Python running the
# notebook, unlike whichever bare `pip` happens to be first on PATH.
result = subprocess.run(
    [sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'],
    cwd=DRIVE_DIR, capture_output=True, text=True
)
if result.returncode == 0:
    # -q suppresses most output, so stdout is usually empty on success.
    print(result.stdout[-2000:] if result.stdout else 'Done.')
else:
    # Report the failure up front rather than printing 'Done.' first.
    print(f'pip install failed (exit code {result.returncode}).')
    if result.stdout:
        print(result.stdout[-2000:])
    print('ERRORS:', result.stderr[-1000:])

## Cell 4 — Verify Setup

In [None]:
import os, json

# Check that the files the refinement run depends on exist under DRIVE_DIR
# (set in Cell 2), and report any saved state a resume would pick up.
checks = {
    'Original model checkpoint':  'value_search/checkpoints/best_mlp.pt',
    'Feature stats':              'value_search/checkpoints/feature_stats.json',
    'Greedy solved presentations': 'ac_solver/search/miller_schupp/data/greedy_solved_presentations.txt',
    'Greedy search paths':         'ac_solver/search/miller_schupp/data/greedy_search_paths.txt',
    'Config':                      'experiments/config.yaml',
}

all_ok = True
for label, path in checks.items():
    full = os.path.join(DRIVE_DIR, path)
    exists = os.path.exists(full)
    status = '✓' if exists else '✗  MISSING'
    # For missing files, also show the expected repo-relative path.
    print(f'  {status}  {label}' + ('' if exists else f'  ({path})'))
    if not exists:
        all_ok = False

# Check refinement state (if resuming).
state_file = os.path.join(DRIVE_DIR, 'experiments/refinement/refinement_state.json')
if os.path.exists(state_file):
    try:
        with open(state_file) as f:
            state = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        # An unreadable/corrupt state file (e.g. a write cut off by a
        # disconnect) is reported instead of crashing this cell.
        print(f'\n  ✗  Saved state file could not be read ({err}); '
              f'resuming (Cell 7) will likely fail.')
    else:
        print(f'\n  Saved state found: iteration={state.get("iteration", "?")}, '
              f'total_solved_history={state.get("total_solved_per_iteration", "?")}')
else:
    print('\n  No saved state — will start fresh.')

if all_ok:
    print('\nAll checks passed. Ready to run.')
else:
    print('\nFix missing files before running.')

## Cell 5 — Configuration
Edit these before running. Then run Cell 6 (fresh start) or Cell 7 (resume).

In [None]:
# ── Tune these ────────────────────────────────────────────────────────────────
MAX_ITERATIONS  = 5       # How many search→train cycles
MAX_PATH_LENGTH = 300     # Reject training paths longer than this
ENABLE_MCTS     = False   # MCTS is slow; keep False unless you have many hours

# Set to an integer (e.g. 100_000) to cap search budget for a quick smoke test.
# Set to None for full 1M-node search.
MAX_NODES = None
# ──────────────────────────────────────────────────────────────────────────────

import sys

# Make the repo importable from this kernel. The guard keeps re-runs of this
# cell from prepending duplicate entries to sys.path.
if DRIVE_DIR not in sys.path:
    sys.path.insert(0, DRIVE_DIR)

print('Config set:')
print(f'  max_iterations  = {MAX_ITERATIONS}')
print(f'  max_path_length = {MAX_PATH_LENGTH}')
print(f'  enable_mcts     = {ENABLE_MCTS}')
# Same `is not None` test Cells 6/7 use to decide whether to pass --max-nodes,
# so e.g. MAX_NODES = 0 is displayed as the value that is actually passed.
print(f'  max_nodes       = {MAX_NODES if MAX_NODES is not None else "1,000,000 (default)"}')

## Cell 6 — Fresh Start
Run this if you have no previous state (first time ever).  
**Skip this if resuming — use Cell 7 instead.**

In [None]:
# Launch experiments/iterative_refinement.py from scratch (no --resume),
# using the settings from Cell 5. Output streams straight to the notebook.
import subprocess

script_path = os.path.join(DRIVE_DIR, 'experiments/iterative_refinement.py')
cmd = [sys.executable, script_path,
       '--max-iterations', str(MAX_ITERATIONS),
       '--max-path-length', str(MAX_PATH_LENGTH)]
# Optional flags are only added when enabled / set in Cell 5.
cmd += ['--enable-mcts'] if ENABLE_MCTS else []
cmd += [] if MAX_NODES is None else ['--max-nodes', str(MAX_NODES)]

print('Running:', ' '.join(cmd))
result = subprocess.run(cmd, cwd=DRIVE_DIR)
print('Exit code:', result.returncode)

## Cell 7 — Resume After Disconnect
After Colab disconnects: re-run Cells 1–3 (about 1 min) and Cell 5 (its settings are lost with the session), then run this cell.  
It picks up from the last completed iteration automatically.

In [None]:
# Re-launch experiments/iterative_refinement.py with --resume so it continues
# from its saved state, using the (re-run) settings from Cell 5.
import subprocess

script = os.path.join(DRIVE_DIR, 'experiments/iterative_refinement.py')

extra_flags = []
if ENABLE_MCTS:
    extra_flags.append('--enable-mcts')
if MAX_NODES is not None:
    extra_flags += ['--max-nodes', str(MAX_NODES)]

cmd = [
    sys.executable, script, '--resume',
    '--max-iterations', str(MAX_ITERATIONS),
    '--max-path-length', str(MAX_PATH_LENGTH),
    *extra_flags,
]

print('Resuming:', ' '.join(cmd))
result = subprocess.run(cmd, cwd=DRIVE_DIR)
print('Exit code:', result.returncode)

## Cell 8 — Check Progress Anytime

In [None]:
import os, json

# Print a progress summary from the state and solved-paths files found under
# experiments/refinement/ in the repo on Drive.

# NOTE(review): 1190 is the total number of presentations this notebook
# assumes — confirm it matches the dataset the refinement script uses.
TOTAL_PRESENTATIONS = 1190

refinement_dir = os.path.join(DRIVE_DIR, 'experiments/refinement')
state_file = os.path.join(refinement_dir, 'refinement_state.json')
paths_file = os.path.join(refinement_dir, 'all_solved_paths.json')

state = None
if not os.path.exists(state_file):
    print('No state file yet — run Cell 6 first.')
else:
    try:
        with open(state_file) as fh:
            state = json.load(fh)
    except (OSError, json.JSONDecodeError) as err:
        # e.g. the file was read mid-write; report instead of raising.
        print(f'Could not read {state_file}: {err}')

if state is not None:
    total_solved = 0
    if os.path.exists(paths_file):
        try:
            with open(paths_file) as fh:
                total_solved = len(json.load(fh))
        except (OSError, json.JSONDecodeError) as err:
            print(f'Could not read {paths_file}: {err}')
            total_solved = '?'

    print(f'Current iteration:  {state.get("iteration", "?")}')
    print(f'Total solved:       {total_solved}/{TOTAL_PRESENTATIONS}')
    print(f'Solved per iter:    {state.get("solved_per_iteration", "?")}')
    print(f'Total per iter:     {state.get("total_solved_per_iteration", "?")}')

    print('\nFiles in refinement dir:')
    for name in sorted(os.listdir(refinement_dir)):
        fpath = os.path.join(refinement_dir, name)
        if os.path.isfile(fpath):
            tag = f'({os.path.getsize(fpath) / 1e6:.1f} MB)'
        else:
            tag = '(dir)'
        print(f'  {name}  {tag}')