# AC-Solver: Parallelized Value-Guided Search on A100

This notebook runs V-guided greedy, beam search, and MCTS on **all 1190 Miller-Schupp presentations** using multiprocessing + A100 GPU.

**Runtime setup**: Go to `Runtime > Change runtime type > A100 GPU`

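**Expected runtime**: depends on `MAX_NODES` (see the budget estimates in Section 2) — roughly 30-60 min for V-guided at 100K nodes and several hours at 1M nodes; running all three algorithms takes longer still.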

## 1. Setup: Clone Repo and Install Dependencies

In [10]:
# Kernel smoke test: prints a fixed string; no other cell visible in this
# notebook depends on it.
print("Hello World")

Hello World


In [11]:
import os
if not os.path.exists('/content/AC-Solver-Caltech'):
    !git clone https://github.com/Avi161/AC-Solver-Caltech.git
%cd /content/AC-Solver-Caltech
!git fetch origin
!git checkout feat/test-antigravity
!git pull origin feat/test-antigravity
!pip install -q torch numpy pyyaml

/content/AC-Solver-Caltech
Already on 'feat/test-antigravity'
Your branch is up to date with 'origin/feat/test-antigravity'.
From https://github.com/Avi161/AC-Solver-Caltech
 * branch            feat/test-antigravity -> FETCH_HEAD
Already up to date.


In [12]:
import multiprocessing

import torch

# NOTE: Only torch.cuda.is_available() is queried here. Do NOT call
# torch.cuda.get_device_name() or get_device_properties() in this cell:
# those calls initialize CUDA in the main process, which breaks
# fork-based multiprocessing later. GPU will be used inside workers.
cuda_ok = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if not cuda_ok:
    print("WARNING: No GPU detected! Go to Runtime > Change runtime type > A100 GPU")

print(f"CPU cores: {multiprocessing.cpu_count()}")

PyTorch version: 2.10.0+cu128
CUDA available: True
CPU cores: 12


In [13]:
import os

# Report, for each expected checkpoint artifact, whether it is present
# on disk and how large it is (missing files are shown as 0 KB).
checkpoint_files = (
    'value_search/checkpoints/best_mlp.pt',
    'value_search/checkpoints/best_seq.pt',
    'value_search/checkpoints/feature_stats.json',
)
for ckpt_file in checkpoint_files:
    if os.path.exists(ckpt_file):
        mark, n_bytes = '✓', os.path.getsize(ckpt_file)
    else:
        mark, n_bytes = '✗', 0
    print(f"  {mark} {ckpt_file} ({n_bytes/1024:.0f} KB)")

  ✓ value_search/checkpoints/best_mlp.pt (405 KB)
  ✓ value_search/checkpoints/best_seq.pt (159 KB)
  ✓ value_search/checkpoints/feature_stats.json (1 KB)


## 2. Configuration

Adjust these settings based on how long you want to run.

In [19]:
# ============================================================
# CONFIGURATION — adjust these as needed
# ============================================================

# Which algorithms to run. Options: 'v_guided', 'beam', 'mcts', or 'all'
ALGORITHM = 'all'

# Value network architecture: 'mlp' (faster) or 'seq' (more accurate)
ARCHITECTURE = 'mlp'

# Max nodes to explore per presentation per algorithm.
# (MCTS uses its own node budget, set in the run cell of Section 3.)
# Recommended budgets:
#   1,000   -> ~30s total, quick sanity check
#   10,000  -> ~5 min total
#   100,000 -> ~30-60 min for v_guided, much longer for MCTS
#   1,000,000 -> ~4-8h for v_guided (best results)
MAX_NODES = 1_000_000

# Number of parallel worker processes.
# A100 Colab has ~12 CPU cores. Workers run search in parallel.
# Each worker loads its own copy of the model onto the GPU.
# Recommended: 6-8 for A100 (balance CPU search + GPU inference)
NUM_WORKERS = 8

# Beam width (only for beam search)
BEAM_WIDTH = 50

# MCTS exploration constant
C_EXPLORE = 1.41

# Apply cyclic reduction after AC moves (slightly better results)
CYCLICALLY_REDUCE = True

# Fail fast on typos in the settings above, before any search work is
# launched by later cells.
_VALID_ALGORITHMS = ('v_guided', 'beam', 'mcts', 'all')
_VALID_ARCHITECTURES = ('mlp', 'seq')
if ALGORITHM not in _VALID_ALGORITHMS:
    raise ValueError(
        f"ALGORITHM must be one of {_VALID_ALGORITHMS}, got {ALGORITHM!r}")
if ARCHITECTURE not in _VALID_ARCHITECTURES:
    raise ValueError(
        f"ARCHITECTURE must be one of {_VALID_ARCHITECTURES}, got {ARCHITECTURE!r}")
if MAX_NODES < 1:
    raise ValueError(f"MAX_NODES must be >= 1, got {MAX_NODES!r}")
if NUM_WORKERS < 1:
    raise ValueError(f"NUM_WORKERS must be >= 1, got {NUM_WORKERS!r}")
if BEAM_WIDTH < 1:
    raise ValueError(f"BEAM_WIDTH must be >= 1, got {BEAM_WIDTH!r}")

print(f"Config: algorithm={ALGORITHM}, arch={ARCHITECTURE}, "
      f"max_nodes={MAX_NODES:,}, workers={NUM_WORKERS}")

Config: algorithm=all, arch=mlp, max_nodes=1,000,000, workers=8


## 3. Run Parallel Search

In [15]:
# Imports for the parallel search, analysis, saving and verification cells,
# plus one-time multiprocessing setup.
import sys
import json
import time
import datetime
import numpy as np
from ast import literal_eval

import multiprocessing
# Use the 'spawn' start method for the worker pool. force=True lets this
# cell be re-run in the same kernel even though a start method has
# already been set.
multiprocessing.set_start_method('spawn', force=True)
from multiprocessing import Pool

# Make the repository root (the current working directory) importable so
# the project packages below resolve.
sys.path.insert(0, '.')

from ac_solver.envs.ac_moves import ACMove
from ac_solver.envs.utils import is_presentation_trivial
from value_search.value_guided_search import (
    value_guided_greedy_search, beam_search, load_model,
    backfill_solution_cache,
)
from value_search.mcts import mcts_search

# Import worker function from .py file (required for spawn)
from scripts.colab_worker import search_worker

print("Imports OK")

Imports OK


In [16]:
# Load all presentations
def load_presentations(path):
    """Read presentations from a text file, one Python literal per line.

    Blank (whitespace-only) lines are skipped. Every other line is parsed
    with ``ast.literal_eval`` and converted to a ``np.int8`` array.

    Args:
        path: location of the text file to read.

    Returns:
        A list of ``np.int8`` arrays, in file order.
    """
    with open(path) as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [
            np.array(literal_eval(text), dtype=np.int8)
            for text in stripped_lines
            if text
        ]

# Load the full presentation list and the list from
# greedy_solved_presentations.txt; the latter is also kept as a set of
# tuples so later cells can test membership cheaply.
all_pres = load_presentations('ac_solver/search/miller_schupp/data/all_presentations.txt')
greedy_solved_pres = load_presentations('ac_solver/search/miller_schupp/data/greedy_solved_presentations.txt')
greedy_solved_set = {tuple(p) for p in greedy_solved_pres}

n_total = len(all_pres)
n_greedy = len(greedy_solved_set)
print(f"Total presentations: {n_total}")
print(f"Greedy solved: {n_greedy}")
print(f"Unsolved by greedy: {n_total - n_greedy}")

Total presentations: 1190
Greedy solved: 533
Unsolved by greedy: 657


In [17]:
# search_worker is imported from scripts/colab_worker.py above.
# It must live in a .py file (not a notebook cell) because the 'spawn'
# start method creates fresh Python processes that need to import
# the worker function by module path.
# Show the fully qualified name the worker function was imported under.
worker_qualname = ".".join((search_worker.__module__, search_worker.__name__))
print(f"Worker function: {worker_qualname}")
print("Worker function ready")

Worker function: scripts.colab_worker.search_worker
Worker function ready


In [18]:
def run_parallel(presentations, algorithm, config, num_workers, greedy_solved_set,
                 progress_every=1, max_listed=30):
    """Run one search algorithm over every presentation using a process pool.

    Each presentation ``i`` is sent to ``search_worker`` as the tuple
    ``(i, presentation.tolist(), algorithm, config)``. Results are consumed
    in completion order (``imap_unordered``) and sorted by ``'idx'`` before
    being returned.

    Args:
        presentations: sequence of presentation arrays (must support
            ``.tolist()``).
        algorithm: algorithm name forwarded to the worker and shown in the
            banner.
        config: dict forwarded to the worker; ``config['max_nodes']`` is
            read here for the banner.
        num_workers: number of processes in the pool.
        greedy_solved_set: set of ``tuple(presentation)``; a solved
            presentation that is not in this set is counted as "new".
        progress_every: print a progress line after every this-many
            results (and always after the last one). The default of 1
            prints after every single result.
        max_listed: maximum number of newly solved presentations listed in
            the summary.

    Returns:
        The list of worker result dicts, sorted by ``'idx'``. This function
        reads the keys ``'idx'`` and ``'solved'`` of every result, and
        ``'path_length'`` / ``'nodes_explored'`` of solved ones.
    """
    n = len(presentations)
    # Guard against 0 / negative values, which would break the modulo below.
    progress_every = max(1, int(progress_every))

    def is_new(idx):
        # "New" means the presentation is not in the greedy-solved set.
        return tuple(presentations[idx]) not in greedy_solved_set

    print(f"\n{'='*60}")
    print(f"  {algorithm.upper()} | {config['max_nodes']:,} nodes | {num_workers} workers")
    print(f"{'='*60}")

    # Plain lists are sent instead of numpy arrays.
    work_items = [
        (i, pres.tolist(), algorithm, config)
        for i, pres in enumerate(presentations)
    ]

    results = []
    solved_count = 0
    newly_solved_count = 0
    t_start = time.time()

    with Pool(processes=num_workers) as pool:
        completed = pool.imap_unordered(search_worker, work_items)
        for done, result in enumerate(completed, start=1):
            results.append(result)
            if result['solved']:
                solved_count += 1
                if is_new(result['idx']):
                    newly_solved_count += 1

            if done % progress_every == 0 or done == n:
                elapsed = time.time() - t_start
                secs_per_item = elapsed / done
                eta = secs_per_item * (n - done)
                eta_s = f"{eta/60:.1f}m" if eta < 3600 else f"{eta/3600:.1f}h"
                print(f"  {done}/{n} | solved={solved_count} | "
                      f"new={newly_solved_count} | ETA {eta_s}")

    total_time = time.time() - t_start
    # Completion order is arbitrary; restore input order for callers.
    results.sort(key=lambda r: r['idx'])

    solved_results = [r for r in results if r['solved']]
    path_lengths = [r['path_length'] for r in solved_results]
    newly = [r for r in solved_results if is_new(r['idx'])]

    print(f"\n  RESULT: {len(solved_results)}/{n} solved "
          f"({len(newly)} new beyond greedy) in {total_time/60:.1f}m")
    if path_lengths:
        print(f"  Avg path: {np.mean(path_lengths):.1f}, "
              f"Max: {max(path_lengths)}, Min: {min(path_lengths)}")

    # Show newly solved
    if newly:
        print(f"\n  *** NEWLY SOLVED ({len(newly)} presentations) ***")
        for r in newly[:max_listed]:
            pres = presentations[r['idx']]
            # Count of non-zero entries across the whole presentation array.
            tl = int(np.count_nonzero(pres))
            print(f"    idx={r['idx']:>4d}, path_len={r['path_length']:>3d}, "
                  f"word_len={tl}, nodes={r['nodes_explored']:>7d}")
        if len(newly) > max_listed:
            print(f"    ... and {len(newly)-max_listed} more")

    return results

# Build config
device = 'cuda' if torch.cuda.is_available() else 'cpu'
arch = ARCHITECTURE
ckpt = f'value_search/checkpoints/best_{arch}.pt'
stats = 'value_search/checkpoints/feature_stats.json'

config = {
    'device': device,
    'architecture': arch,
    'checkpoint': ckpt,
    'feature_stats': stats,
    'max_nodes': MAX_NODES,  # Default, used for V-guided and Beam
    'beam_width': BEAM_WIDTH,
    'c_explore': C_EXPLORE,
    'cyclically_reduce': CYCLICALLY_REDUCE,
}

# Upper bound on the MCTS node budget so it finishes reasonably fast.
MCTS_MAX_NODES = 100_000

# Determine algorithms
if ALGORITHM == 'all':
    algos = ['v_guided', 'beam', 'mcts']
else:
    algos = [ALGORITHM]

print(f"Device: {device}")
print(f"Algorithms to run: {algos}")
print(f"Starting...")

all_results = {}      # algorithm name -> list of per-presentation results
all_solved_sets = {}  # algorithm name -> set of solved presentation indices

for algo in algos:
    # Copy so the per-algorithm override does not leak into `config`.
    algo_config = dict(config)
    if algo == 'mcts':
        # Cap rather than overwrite: a small MAX_NODES (e.g. a 1,000-node
        # sanity check) must not be inflated to 100,000 nodes for MCTS.
        algo_config['max_nodes'] = min(MAX_NODES, MCTS_MAX_NODES)

    results = run_parallel(all_pres, algo, algo_config, NUM_WORKERS, greedy_solved_set)
    all_results[algo] = results
    all_solved_sets[algo] = set(r['idx'] for r in results if r['solved'])

print(f"\n{'='*60}")
print(f"  ALL DONE")
print(f"{'='*60}")

Device: cuda
Algorithms to run: ['v_guided', 'beam', 'mcts']
Starting...

  V_GUIDED | 1,000 nodes | 8 workers
  1/1190 | solved=1 | new=0 | ETA 1.3h
  2/1190 | solved=2 | new=0 | ETA 39.4m
  3/1190 | solved=3 | new=0 | ETA 26.7m
  4/1190 | solved=4 | new=0 | ETA 20.2m
  5/1190 | solved=5 | new=0 | ETA 16.3m
  6/1190 | solved=6 | new=0 | ETA 13.7m
  7/1190 | solved=7 | new=0 | ETA 11.8m
  8/1190 | solved=8 | new=0 | ETA 10.3m
  9/1190 | solved=9 | new=0 | ETA 9.2m
  10/1190 | solved=10 | new=0 | ETA 8.3m
  11/1190 | solved=11 | new=0 | ETA 7.6m
  12/1190 | solved=12 | new=0 | ETA 7.4m
  13/1190 | solved=13 | new=0 | ETA 6.9m
  14/1190 | solved=14 | new=0 | ETA 6.5m
  15/1190 | solved=15 | new=0 | ETA 6.0m
  16/1190 | solved=16 | new=0 | ETA 5.7m
  17/1190 | solved=17 | new=0 | ETA 5.3m
  18/1190 | solved=18 | new=0 | ETA 5.1m
  19/1190 | solved=19 | new=0 | ETA 4.8m
  20/1190 | solved=20 | new=0 | ETA 4.6m
  21/1190 | solved=21 | new=0 | ETA 4.4m
  22/1190 | solved=22 | new=0 | ETA 4.2

KeyboardInterrupt: 

## 4. Analysis

In [None]:
# Comparison table: per-algorithm solve counts and path statistics, then the
# union across algorithms and what each algorithm solved that no other did.
n_total = len(all_pres)  # denominator follows the loaded data, not a constant

print(f"{'Algorithm':<20} | {'Solved':>10} | {'New':>5} | {'Avg Path':>10} | {'Max Path':>10}")
print(f"{'-'*20}-+-{'-'*10}-+-{'-'*5}-+-{'-'*10}-+-{'-'*10}")

for algo, results in all_results.items():
    solved = [r for r in results if r['solved']]
    newly = [r for r in solved if tuple(all_pres[r['idx']]) not in greedy_solved_set]
    paths = [r['path_length'] for r in solved]
    avg_p = f"{np.mean(paths):.1f}" if paths else "—"
    max_p = f"{max(paths)}" if paths else "—"
    solved_s = f"{len(solved)}/{n_total}"
    print(f"{algo:<20} | {solved_s:>10} | {len(newly):>5} | {avg_p:>10} | {max_p:>10}")

# Union (only meaningful when more than one algorithm was run)
if len(all_solved_sets) > 1:
    union = set()
    for s in all_solved_sets.values():
        union |= s
    union_new = sum(1 for idx in union if tuple(all_pres[idx]) not in greedy_solved_set)
    union_s = f"{len(union)}/{n_total}"
    print(f"{'UNION':<20} | {union_s:>10} | {union_new:>5} |")

    # Uniquely solved by each
    print(f"\nUniquely solved:")
    for algo, s in all_solved_sets.items():
        others = set().union(*(v for k, v in all_solved_sets.items() if k != algo))
        unique = s - others
        if unique:
            print(f"  {algo}: {len(unique)} unique")

In [None]:
# Save all results into a fresh, timestamped directory.
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
output_dir = f'experiments/results/{timestamp}_colab'
os.makedirs(output_dir, exist_ok=True)

# Save config. 'cyclically_reduce' is recorded as well, since it is part
# of the search config passed to the workers.
with open(f'{output_dir}/config.json', 'w') as f:
    json.dump({
        'algorithms': algos,
        'max_nodes': MAX_NODES,
        'workers': NUM_WORKERS,
        'architecture': ARCHITECTURE,
        'beam_width': BEAM_WIDTH,
        'c_explore': C_EXPLORE,
        'cyclically_reduce': CYCLICALLY_REDUCE,
        'device': device,
    }, f, indent=2)

# Save per-algorithm results
for algo, results in all_results.items():
    with open(f'{output_dir}/{algo}_results.json', 'w') as f:
        json.dump(results, f, indent=2)

# Save newly solved summary: algorithm -> indices solved here whose
# presentation is not in the greedy-solved set.
all_newly = {}
for algo, results in all_results.items():
    all_newly[algo] = [
        r['idx'] for r in results
        if r['solved'] and tuple(all_pres[r['idx']]) not in greedy_solved_set
    ]

with open(f'{output_dir}/newly_solved.json', 'w') as f:
    json.dump(all_newly, f, indent=2)

print(f"Results saved to: {output_dir}/")
print(f"Files:")
for fname in sorted(os.listdir(output_dir)):
    size = os.path.getsize(f'{output_dir}/{fname}')
    print(f"  {fname} ({size/1024:.0f} KB)")

## 5. Verify Solution Paths

In [None]:
# Verify that all reported solutions are actually correct
def verify_solution(pres, path, cyclical=False):
    """Replay ``path`` from ``pres`` and report whether it ends at a trivial presentation.

    Args:
        pres: starting presentation; copied into an ``np.int8`` array.
        path: iterable of ``(action, expected_len)`` pairs. Only ``action``
            is used for the replay; ``expected_len`` is ignored.
        cyclical: forwarded to ``ACMove`` as its ``cyclical`` argument. It
            should match the setting the path was found with. Defaults to
            ``False``.

    Returns:
        The result of ``is_presentation_trivial`` on the final state.
    """
    state = np.array(pres, dtype=np.int8)
    mrl = len(state) // 2
    # Non-zero entry counts of the two halves, as passed to ACMove.
    wl = [int(np.count_nonzero(state[:mrl])), int(np.count_nonzero(state[mrl:]))]
    for action, _expected_len in path:
        state, wl = ACMove(action, state, mrl, wl, cyclical=cyclical)
    return is_presentation_trivial(state)

print("Verifying all solutions...")
errors = 0
verified = 0
for algo, results in all_results.items():
    for r in results:
        # Solved results that carry no 'path' cannot be replayed and are skipped.
        if r['solved'] and 'path' in r:
            # Replay with the same cyclic-reduction setting used for the search.
            # NOTE(review): assumes the workers forward
            # config['cyclically_reduce'] to ACMove's `cyclical` flag —
            # confirm in scripts/colab_worker.py.
            ok = verify_solution(all_pres[r['idx']], r['path'],
                                 cyclical=CYCLICALLY_REDUCE)
            if not ok:
                print(f"  ERROR: {algo} idx={r['idx']} path does NOT reach trivial!")
                errors += 1
            verified += 1

print(f"Verified {verified} solutions, {errors} errors")
if errors == 0:
    print("All solutions are verified correct! ✓")

## 6. Download Results (Optional)

In [None]:
# Bundle the results directory into a zip archive under /content.
import shutil

archive_base = f'/content/ac_solver_results_{timestamp}'
zip_path = shutil.make_archive(archive_base, 'zip', output_dir)
print(f"Results zipped to: {zip_path}")

# When the Colab helper module is importable, trigger a browser download;
# otherwise just tell the user where to find the archive.
try:
    from google.colab import files as colab_files
    colab_files.download(zip_path)
except ImportError:
    print("Not on Colab — download manually from the file browser.")