# AC-Solver Iterative Refinement — Colab Runner

**Runtime recommendation: High-RAM CPU**  
The search (V-guided greedy, beam) is CPU-only. GPU only helps training, which takes ~10 min per iteration — not the bottleneck.

**Normal workflow:**
- **First time / clean restart:** Cells 1 → 2 → 3 → 4 → 5 → 6
- **After Colab disconnect:** Cells 1 → 2 → 3 → 5 → 8 *(state persists on Drive)*
- **If state files are missing but partial results exist:** Cells 1 → 2 → 3 → 5 → 7 → 8

## Cell 1 — Mount Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')
print('Drive mounted.')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive mounted.


## Cell 2 — Clone or Update Repo on Drive

**First time only:** enter your GitHub Personal Access Token when prompted.  
Get one at: GitHub → Settings → Developer Settings → Personal access tokens → Fine-grained → repo read access.  
After first clone, this cell just does `git pull`.

In [2]:
import os
import subprocess

REPO_URL = 'https://github.com/Avi161/AC-Solver-Caltech.git'
DRIVE_DIR = '/content/drive/MyDrive/AC-Solver-Caltech'
BRANCH = 'feat/test-new-exp-claude'

if not os.path.exists(DRIVE_DIR):
    # First time: clone with token
    from getpass import getpass
    token = getpass('GitHub Personal Access Token: ')
    auth_url = REPO_URL.replace('https://', f'https://{token}@')
    subprocess.run(['git', 'clone', auth_url, DRIVE_DIR], check=True)
    subprocess.run(['git', 'checkout', BRANCH], cwd=DRIVE_DIR, check=True)
    print('Repo cloned to Drive.')
else:
    # Already cloned: checkout branch first, then pull
    subprocess.run(['git', 'checkout', BRANCH], cwd=DRIVE_DIR, check=True)
    result = subprocess.run(['git', 'pull'], cwd=DRIVE_DIR, capture_output=True, text=True)
    print(result.stdout or 'Already up to date.')

os.chdir(DRIVE_DIR)
print(f'Working directory: {os.getcwd()}')

Updating d7d1132..133042e
Fast-forward
 experiments/colab_iterative_refinement.ipynb | 239 +++++++++++++++++++--------
 experiments/run_experiments.py               |  12 +-
 2 files changed, 179 insertions(+), 72 deletions(-)

Working directory: /content/drive/MyDrive/AC-Solver-Caltech


## Cell 3 — Install Dependencies

In [3]:
import subprocess
result = subprocess.run(
    ['pip', 'install', '-q', '-r', 'requirements.txt'],
    cwd=DRIVE_DIR, capture_output=True, text=True
)
print(result.stdout[-2000:] if result.stdout else 'Done.')
if result.returncode != 0:
    print('ERRORS:', result.stderr[-1000:])

Done.


## Cell 4 — Verify Setup

In [5]:
import os, json

checks = {
    'Original model checkpoint':  'value_search/checkpoints/best_mlp.pt',
    'Feature stats':              'value_search/checkpoints/feature_stats.json',
    'Greedy solved presentations': 'ac_solver/search/miller_schupp/data/greedy_solved_presentations.txt',
    'Greedy search paths':         'ac_solver/search/miller_schupp/data/greedy_search_paths.txt',
    'Config':                      'experiments/config.yaml',
}

all_ok = True
for label, path in checks.items():
    full = os.path.join(DRIVE_DIR, path)
    exists = os.path.exists(full)
    status = '✓' if exists else '✗  MISSING'
    print(f'  {status}  {label}')
    if not exists:
        all_ok = False

# Check refinement state (if resuming)
state_file = os.path.join(DRIVE_DIR, 'experiments/refinement/refinement_state.json')
if os.path.exists(state_file):
    with open(state_file) as f:
        state = json.load(f)
    print(f'\n  Saved state found: iteration={state["iteration"]}, '
          f'total_solved_history={state["total_solved_per_iteration"]}')
else:
    print('\n  No saved state — will start fresh.')

if all_ok:
    print('\nAll checks passed. Ready to run.')
else:
    print('\nFix missing files before running.')

  ✓  Original model checkpoint
  ✓  Feature stats
  ✓  Greedy solved presentations
  ✓  Greedy search paths
  ✓  Config

  Saved state found: iteration=0, total_solved_history=[533]

All checks passed. Ready to run.


## Cell 5 — Configuration
Edit these before running. Then run Cell 6 (fresh start) or Cell 8 (resume).

In [7]:
# ── Tune these ────────────────────────────────────────────────────────────────
MAX_ITERATIONS  = 5       # How many search→train cycles
MAX_PATH_LENGTH = 800     # Reject training paths longer than this
ENABLE_MCTS     = False   # MCTS is slow; keep False unless you have many hours
MAX_NODES       = 100_000 # 100K nodes: ~25s per presentation, ~5hrs per iteration
# ──────────────────────────────────────────────────────────────────────────────

import sys
sys.path.insert(0, DRIVE_DIR)

print('Config set:')
print(f'  max_iterations  = {MAX_ITERATIONS}')
print(f'  max_path_length = {MAX_PATH_LENGTH}')
print(f'  enable_mcts     = {ENABLE_MCTS}')
print(f'  max_nodes       = {MAX_NODES:,}')


# ── Progress streaming helper (used by Cells 6 and 7) ─────────────────────────
def run_with_progress(cmd, cwd):
    """Run cmd as subprocess, printing headers normally and progress on one line."""
    import subprocess, sys, os

    HEADER_PATTERNS = [
        '===', 'ITERATION', '--- Step', 'complete:', 'REFINEMENT COMPLETE',
        'ERROR', 'Traceback', 'File "', 'Error:', 'Running:', 'Resuming:',
        'Total presentations', 'Greedy baseline', 'Resumed from',
        'No previous state', 'Loading greedy', 'Converged', 'Interrupted',
        'State saved', '[idx=',
    ]

    env = os.environ.copy()
    env['PYTHONUNBUFFERED'] = '1'  # force line-by-line flushing in subprocess

    proc = subprocess.Popen(
        cmd, cwd=cwd,
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        text=True, bufsize=1, env=env,
    )

    last_was_progress = False
    for raw in proc.stdout:
        line = raw.rstrip()
        if not line:
            continue
        is_header = any(p in line for p in HEADER_PATTERNS)
        if is_header:
            if last_was_progress:
                sys.stdout.write('\n')
            print(line)
            last_was_progress = False
        else:
            sys.stdout.write(f'\r  {line:<90}')
            sys.stdout.flush()
            last_was_progress = True

    if last_was_progress:
        sys.stdout.write('\n')
    proc.wait()
    return proc.returncode

Config set:
  max_iterations  = 5
  max_path_length = 800
  enable_mcts     = False
  max_nodes       = 100,000


## Cell 6 — Fresh Start
Run this to begin a new refinement run from the greedy baseline. **Skip if resuming an existing run.**

In [38]:
cmd = [
    sys.executable,
    os.path.join(DRIVE_DIR, 'experiments/iterative_refinement.py'),
    '--max-iterations', str(MAX_ITERATIONS),
    '--max-path-length', str(MAX_PATH_LENGTH),
]
if ENABLE_MCTS:
    cmd.append('--enable-mcts')
if MAX_NODES is not None:
    cmd.extend(['--max-nodes', str(MAX_NODES)])

print('Running:', ' '.join(cmd))
rc = run_with_progress(cmd, cwd=DRIVE_DIR)
print('Exit code:', rc)

Running: /usr/bin/python3 /content/drive/MyDrive/AC-Solver-Caltech/experiments/iterative_refinement.py --max-iterations 5 --max-path-length 300 --max-nodes 100000
    AC-Solver Iterative Refinement Pipeline                                                 
    State file:     /content/drive/MyDrive/AC-Solver-Caltech/experiments/refinement/refinement_state.json
  Total presentations: 1190
  Loading greedy-solved paths as seed...
  Greedy baseline: 533 solved
  ITERATION 0
    Model: /content/drive/MyDrive/AC-Solver-Caltech/value_search/checkpoints/best_mlp.pt (original, trained on greedy paths)
  --- Step 1: Search ---
    Searching 657/1190 unsolved presentations                                               
  Running: /usr/bin/python3 /content/drive/MyDrive/AC-Solver-Caltech/experiments/run_experiments.py --config /content/drive/MyDrive/AC-Solver-Caltech/experiments/refinement/iter_0/iter_config.yaml --indices /content/drive/MyDrive/AC-Solver-Caltech/experiments/refinement/iter_0/unso

: 

## Cell 7 — Reconstruct State (run only if state files are missing)
Run this **only** when `refinement_state.json` is missing but partial search results exist on Drive (e.g. after accidentally deleting state or after a very early crash before any state was saved).  
It auto-detects the latest partial results directory and rebuilds both state files.  
Then run Cell 8 to continue.

In [27]:
import os, sys, json

sys.path.insert(0, DRIVE_DIR)
from value_search.data_extraction import load_presentations, load_paths

DATA_DIR       = os.path.join(DRIVE_DIR, 'ac_solver/search/miller_schupp/data')
REFINEMENT_DIR = os.path.join(DRIVE_DIR, 'experiments/refinement')
RESULTS_BASE   = os.path.join(DRIVE_DIR, 'experiments/results')
os.makedirs(REFINEMENT_DIR, exist_ok=True)

# 1. Auto-detect the latest partial results directory
candidate_dirs = sorted([
    d for d in os.listdir(RESULTS_BASE)
    if os.path.isdir(os.path.join(RESULTS_BASE, d))
    and os.path.exists(os.path.join(RESULTS_BASE, d, 'v_guided_greedy_progress.jsonl'))
])
if not candidate_dirs:
    print("ERROR: No partial results found in", RESULTS_BASE)
    raise SystemExit

PARTIAL_DIR = os.path.join(RESULTS_BASE, candidate_dirs[-1])
print(f"Using partial results dir: {PARTIAL_DIR}")

# Show what's in it
jsonl = os.path.join(PARTIAL_DIR, 'v_guided_greedy_progress.jsonl')
total, solved_indices = 0, []
with open(jsonl) as f:
    for line in f:
        r = json.loads(line)
        total += 1
        if r.get('solved'):
            solved_indices.append(r['idx'])
print(f"  Contains: {total} processed, {len(solved_indices)} solved: {solved_indices}")

# 2. Reload greedy paths from source data
solved_pres = load_presentations(os.path.join(DATA_DIR, 'greedy_solved_presentations.txt'))
raw_paths   = load_paths(os.path.join(DATA_DIR, 'greedy_search_paths.txt'))
greedy_paths = {
    tuple(pres): [[a - 1, l] for a, l in raw_path[1:]]
    for pres, raw_path in zip(solved_pres, raw_paths)
}
print(f"Greedy paths loaded: {len(greedy_paths)}")

# 3. Write all_solved_paths.json
paths_file = os.path.join(REFINEMENT_DIR, 'all_solved_paths.json')
with open(paths_file, 'w') as f:
    json.dump({str(list(k)): v for k, v in greedy_paths.items()}, f)
print(f"Written: all_solved_paths.json ({len(greedy_paths)} paths)")

# 4. Write refinement_state.json pointing at partial results
state = {
    "iteration": 0,
    "solved_per_iteration": [],
    "total_solved_per_iteration": [len(greedy_paths)],
    "model_paths": [],
    "results_dirs": [PARTIAL_DIR],
}
state_file = os.path.join(REFINEMENT_DIR, 'refinement_state.json')
with open(state_file, 'w') as f:
    json.dump(state, f, indent=2)
print(f"Written: refinement_state.json")
print(f"\nReady — run Cell 8 (Resume).")

Using partial results dir: /content/drive/MyDrive/AC-Solver-Caltech/experiments/results/2026-02-26_04-06-25
  Contains: 397 processed, 6 solved: [599, 689, 698, 711, 712, 834]
Greedy paths loaded: 533
Written: all_solved_paths.json (533 paths)
Written: refinement_state.json

Ready — run Cell 8 (Resume).


## Cell 8 — Resume After Disconnect
After Colab disconnects: re-run Cells 1 → 2 → 3 → 5, then run this cell.  
Picks up from the last completed iteration automatically — state persists on Drive.

In [8]:
cmd = [
    sys.executable,
    os.path.join(DRIVE_DIR, 'experiments/iterative_refinement.py'),
    '--resume',
    '--max-iterations', str(MAX_ITERATIONS),
    '--max-path-length', str(MAX_PATH_LENGTH),
]
if ENABLE_MCTS:
    cmd.append('--enable-mcts')
if MAX_NODES is not None:
    cmd.extend(['--max-nodes', str(MAX_NODES)])

print('Resuming:', ' '.join(cmd))
rc = run_with_progress(cmd, cwd=DRIVE_DIR)
print('Exit code:', rc)

Resuming: /usr/bin/python3 /content/drive/MyDrive/AC-Solver-Caltech/experiments/iterative_refinement.py --resume --max-iterations 5 --max-path-length 800 --max-nodes 100000
    AC-Solver Iterative Refinement Pipeline                                                 
    State file:     /content/drive/MyDrive/AC-Solver-Caltech/experiments/refinement/refinement_state.json
  Total presentations: 1190
  Resumed from iteration 0
    Previously solved: 533/1190                                                             
  ITERATION 0
    Model: /content/drive/MyDrive/AC-Solver-Caltech/value_search/checkpoints/best_mlp.pt (original, trained on greedy paths)
  --- Step 1: Search ---
    Resuming search from: /content/drive/MyDrive/AC-Solver-Caltech/experiments/results/2026-02-26_04-06-25
  Running: /usr/bin/python3 /content/drive/MyDrive/AC-Solver-Caltech/experiments/run_experiments.py --config /content/drive/MyDrive/AC-Solver-Caltech/experiments/refinement/iter_0/iter_config.yaml --indices /c

KeyboardInterrupt: 

## Cell 9 — Skip Iter 0 Search: Retrain & Advance to Iter 1
Run this **instead of re-doing the iter 0 search** when the search is already complete
(or close enough). It:
1. Loads all solved paths (greedy seed + any solutions found in iter_0 results)
2. Rebuilds training data with the correct `negative_label = 5 × MAX_PATH_LENGTH`
3. Retrains both MLP and Seq models
4. Advances state to `iteration=1` so Cell 8 (Resume) picks up from there

In [None]:
import os, sys, json
from ast import literal_eval
sys.path.insert(0, DRIVE_DIR)

from value_search.data_extraction import build_dataset_from_dict
from value_search.benchmark import load_all_presentations

REFINEMENT_DIR = os.path.join(DRIVE_DIR, 'experiments/refinement')
all_presentations = load_all_presentations()
pres_by_idx = {i: tuple(p) for i, p in enumerate(all_presentations)}

# 1. Load existing solved paths (greedy seed)
paths_file = os.path.join(REFINEMENT_DIR, 'all_solved_paths.json')
with open(paths_file) as f:
    raw = json.load(f)
solved_paths = {tuple(literal_eval(k)): v for k, v in raw.items()}
print(f'Loaded {len(solved_paths)} solved paths from all_solved_paths.json')

# 2. Scan iter_0 search results and add any newly solved presentations
iter0_dir = os.path.join(REFINEMENT_DIR, 'iter_0')
n_extra = 0
if os.path.isdir(iter0_dir):
    subdirs = sorted([
        d for d in os.listdir(iter0_dir)
        if os.path.isdir(os.path.join(iter0_dir, d)) and d[0].isdigit()
    ])
    if subdirs:
        results_dir = os.path.join(iter0_dir, subdirs[-1])
        jsonl_path = os.path.join(results_dir, 'v_guided_greedy_progress.jsonl')
        if os.path.exists(jsonl_path):
            with open(jsonl_path) as f:
                for line in f:
                    r = json.loads(line)
                    if r.get('solved') and r.get('path'):
                        pres_tuple = pres_by_idx[r['idx']]
                        if pres_tuple not in solved_paths:
                            solved_paths[pres_tuple] = r['path']
                            n_extra += 1
            print(f'Added {n_extra} new solutions from iter_0 search results')

# 3. Save updated solved paths
with open(paths_file, 'w') as f:
    json.dump({str(list(k)): v for k, v in solved_paths.items()}, f)
print(f'Total solved paths: {len(solved_paths)}')

# 4. Build training data with corrected negative label (5x, not 2x)
negative_label = float(MAX_PATH_LENGTH * 5)
data_path = os.path.join(REFINEMENT_DIR, 'training_data_iter_0.pkl')
build_dataset_from_dict(
    solved_paths=solved_paths,
    all_presentations=all_presentations,
    output_path=data_path,
    negative_label=negative_label,
    max_path_length=MAX_PATH_LENGTH,
)
print(f'Training data built with negative_label={negative_label}')

# 5. Retrain both MLP and Seq
checkpoint_dir = os.path.join(REFINEMENT_DIR, 'checkpoints_iter_0')
cmd = [
    sys.executable,
    os.path.join(DRIVE_DIR, 'value_search/train_value_net.py'),
    '--data-path', data_path,
    '--save-dir', checkpoint_dir,
    '--architecture', 'both',
    '--epochs', '100',
]
print('Training...')
rc = run_with_progress(cmd, cwd=DRIVE_DIR)
if rc != 0:
    print(f'ERROR: training failed (exit code {rc})')
    raise SystemExit

# 6. Verify checkpoint
ckpt = os.path.join(checkpoint_dir, 'best_mlp.pt')
if not os.path.exists(ckpt):
    print('ERROR: checkpoint not found!')
    raise SystemExit

# 7. Advance state to iteration=1
state_file = os.path.join(REFINEMENT_DIR, 'refinement_state.json')
with open(state_file) as f:
    state = json.load(f)

prev_total = state['total_solved_per_iteration'][-1] if state['total_solved_per_iteration'] else 0
total_now = len(solved_paths)
new_count = total_now - prev_total
state['model_paths'].append(ckpt)
state['iteration'] = 1
state['solved_per_iteration'].append(new_count)
state['total_solved_per_iteration'].append(total_now)

with open(state_file, 'w') as f:
    json.dump(state, f, indent=2)

print(f'\nDone! State → iteration=1, total_solved={total_now} (+{new_count} new)')
print('Run Cell 8 (Resume) to continue from iteration 1.')


## Cell 10 — Check Progress Anytime

In [None]:
import os, json

state_file = os.path.join(DRIVE_DIR, 'experiments/refinement/refinement_state.json')
if not os.path.exists(state_file):
    print('No state file — run Cell 6 (fresh start) or Cell 7 (reconstruct state) first.')
else:
    with open(state_file) as f:
        state = json.load(f)

    paths_file = os.path.join(DRIVE_DIR, 'experiments/refinement/all_solved_paths.json')
    total_solved = 0
    if os.path.exists(paths_file):
        with open(paths_file) as f:
            total_solved = len(json.load(f))

    print(f'Current iteration:  {state["iteration"]}')
    print(f'Total solved:       {total_solved}/1190')
    print(f'Solved per iter:    {state["solved_per_iteration"]}')
    print(f'Total per iter:     {state["total_solved_per_iteration"]}')

    refinement_dir = os.path.join(DRIVE_DIR, 'experiments/refinement')
    print(f'\nFiles in refinement dir:')
    for f in sorted(os.listdir(refinement_dir)):
        fpath = os.path.join(refinement_dir, f)
        size_mb = os.path.getsize(fpath) / 1e6 if os.path.isfile(fpath) else 0
        tag = f'({size_mb:.1f} MB)' if os.path.isfile(fpath) else '(dir)'
        print(f'  {f}  {tag}')