# AC-Solver: Parallelized Value-Guided Search on A100

This notebook runs V-guided greedy, beam search, and MCTS on **all 1190 Miller-Schupp presentations** using multiprocessing + A100 GPU.

**Runtime setup**: Go to `Runtime > Change runtime type > A100 GPU`

**Expected runtime**: ~30-60 min for V-guided at 100K nodes, ~2-4h for all three algorithms. At the default 1M-node budget, V-guided alone takes ~4-8h (see the budget table in Section 2).

## 1. Setup: Clone Repo and Install Dependencies

In [1]:
# Kernel smoke test: prints a fixed string and defines no names used by later cells.
print("Hello World")

Hello World


In [2]:
# Clone the repo (change to your fork URL if needed)
!git clone https://github.com/Avi161/AC-Solver-Caltech.git
%cd AC-Solver-Caltech
!git checkout feat/test-antigravity
!pip install -q torch numpy pyyaml

Cloning into 'AC-Solver-Caltech'...
remote: Enumerating objects: 1167, done.[K
remote: Counting objects: 100% (180/180), done.[K
remote: Compressing objects: 100% (50/50), done.[K
remote: Total 1167 (delta 147), reused 138 (delta 130), pack-reused 987 (from 1)[K
Receiving objects: 100% (1167/1167), 9.17 MiB | 10.48 MiB/s, done.
Resolving deltas: 100% (598/598), done.
/content/AC-Solver-Caltech
error: pathspec 'feat/test-antigravity' did not match any file(s) known to git


In [4]:
import multiprocessing

import torch

# Report the runtime's hardware up front so a wrong Colab runtime type is noticed
# before the long-running search cells are started.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("WARNING: No GPU detected! Go to Runtime > Change runtime type > A100 GPU")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print(f"CPU cores: {multiprocessing.cpu_count()}")

PyTorch version: 2.10.0+cu128
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB
GPU Memory: 85.1 GB
CPU cores: 12


In [5]:
import os

# Report, for each artifact the search cells read, whether it is present in the
# checkout and how large it is (in KB; missing files are shown with size 0).
required_files = (
    'value_search/checkpoints/best_mlp.pt',
    'value_search/checkpoints/best_seq.pt',
    'value_search/checkpoints/feature_stats.json',
)
for ckpt_path in required_files:
    if os.path.exists(ckpt_path):
        mark, n_bytes = '✓', os.path.getsize(ckpt_path)
    else:
        mark, n_bytes = '✗', 0
    print(f"  {mark} {ckpt_path} ({n_bytes/1024:.0f} KB)")

  ✓ value_search/checkpoints/best_mlp.pt (405 KB)
  ✓ value_search/checkpoints/best_seq.pt (159 KB)
  ✓ value_search/checkpoints/feature_stats.json (1 KB)


## 2. Configuration

Adjust these settings based on how long you want to run.

In [6]:
# ============================================================
# CONFIGURATION — adjust these as needed
# ============================================================

# Which algorithms to run. Options: 'v_guided', 'beam', 'mcts', or 'all'
# ('all' runs the three algorithms one after another).
ALGORITHM = 'all'

# Value network architecture: 'mlp' (faster) or 'seq' (more accurate).
# This also selects the checkpoint file: value_search/checkpoints/best_<ARCHITECTURE>.pt
ARCHITECTURE = 'mlp'

# Max nodes to explore per presentation per algorithm.
# Recommended budgets:
#   1,000   -> ~30s total, quick sanity check
#   10,000  -> ~5 min total
#   100,000 -> ~30-60 min for v_guided, much longer for MCTS
#   1,000,000 -> ~4-8h for v_guided (best results)
MAX_NODES = 1_000_000

# Number of parallel worker processes (the size of the multiprocessing Pool).
# A100 Colab has ~12 CPU cores. Workers run search in parallel.
# Each worker loads its own copy of the model onto the GPU.
# Recommended: 6-8 for A100 (balance CPU search + GPU inference)
NUM_WORKERS = 8

# Beam width (only for beam search)
BEAM_WIDTH = 50

# MCTS exploration constant (only for MCTS)
C_EXPLORE = 1.41

# Apply cyclic reduction after AC moves (slightly better results).
# NOTE: this flag is only forwarded to the 'v_guided' search; the 'beam' and 'mcts'
# calls in search_worker do not receive it.
CYCLICALLY_REDUCE = True

# Echo the effective settings so they are recorded in the notebook output.
print(f"Config: algorithm={ALGORITHM}, arch={ARCHITECTURE}, "
      f"max_nodes={MAX_NODES:,}, workers={NUM_WORKERS}")

Config: algorithm=all, arch=mlp, max_nodes=1,000,000, workers=8


## 3. Run Parallel Search

In [11]:
import sys
import json
import time
import datetime
import numpy as np
from ast import literal_eval
import multiprocessing
# Use 'spawn' so every worker starts in a fresh interpreter rather than a fork of this
# kernel, which has already called into torch.cuda in the GPU check cell.
# force=True lets this cell be re-run after the start method has already been set.
multiprocessing.set_start_method('spawn', force=True)
from multiprocessing import Pool


# Make the cloned repository's packages (ac_solver, value_search) importable from the
# current working directory.
sys.path.insert(0, '.')

from ac_solver.envs.ac_moves import ACMove
from ac_solver.envs.utils import is_presentation_trivial
from value_search.value_guided_search import (
    value_guided_greedy_search, beam_search, load_model,
    backfill_solution_cache,
)
from value_search.mcts import mcts_search

print("Imports OK")

Imports OK


In [8]:
# Load all presentations
def load_presentations(path):
    """Read presentations from a text file holding one Python-literal list per line.

    Lines that are empty or whitespace-only are skipped. Each remaining line is parsed
    with ``literal_eval`` and converted to an ``np.int8`` array.

    Returns:
        list of ``np.int8`` arrays, in file order.
    """
    with open(path) as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [
            np.array(literal_eval(text), dtype=np.int8)
            for text in stripped_lines
            if text
        ]

# Full benchmark, plus the presentations listed in greedy_solved_presentations.txt.
all_pres = load_presentations('ac_solver/search/miller_schupp/data/all_presentations.txt')
greedy_solved_pres = load_presentations('ac_solver/search/miller_schupp/data/greedy_solved_presentations.txt')
# Arrays are not hashable, so store each presentation as a tuple for set membership tests.
greedy_solved_set = {tuple(p) for p in greedy_solved_pres}

print(f"Total presentations: {len(all_pres)}")
print(f"Greedy solved: {len(greedy_solved_set)}")
print(f"Unsolved by greedy: {len(all_pres) - len(greedy_solved_set)}")

Total presentations: 1190
Greedy solved: 533
Unsolved by greedy: 657


In [9]:
import importlib
import inspect
import os
import sys


def search_worker(args):
    """Run one search algorithm on one presentation; executed inside a Pool worker.

    Args:
        args: tuple ``(idx, pres_list, algorithm, config)``.
            ``idx`` is the presentation's index and is echoed back in the result.
            ``pres_list`` is the presentation as a plain list of ints.
            ``algorithm`` is ``'v_guided'``, ``'beam'`` or ``'mcts'``.
            ``config`` is a dict with keys ``'device'``, ``'checkpoint'``,
            ``'architecture'``, ``'feature_stats'``, ``'max_nodes'`` and, optionally,
            ``'cyclically_reduce'`` (v_guided only), ``'beam_width'`` (beam only) and
            ``'c_explore'`` (mcts only).

    Returns:
        dict with keys ``'idx'``, ``'solved'``, ``'path_length'`` (0 when unsolved),
        ``'nodes_explored'``, ``'time'`` (seconds spent in the search call; model
        loading is not included) and, for solved presentations with a non-empty path,
        ``'path'`` as a list of ``[int, int]`` pairs.

    Raises:
        ValueError: if ``algorithm`` is not one of the three supported names.
    """
    # Imports are function-local on purpose: this function's source is exported to a
    # standalone module (see _export_worker_module), so it must not depend on names that
    # only exist in the notebook's namespace.
    import time

    import numpy as np

    from value_search.mcts import mcts_search
    from value_search.value_guided_search import (
        beam_search, load_model, value_guided_greedy_search,
    )

    idx, pres_list, algorithm, config = args
    pres = np.array(pres_list, dtype=np.int8)

    # Each worker loads model independently (avoids CUDA fork issues)
    device = config['device']
    model, feat_mean, feat_std = load_model(
        config['checkpoint'], config['architecture'],
        config['feature_stats'], device
    )

    t0 = time.time()
    if algorithm == 'v_guided':
        solved, path, stats = value_guided_greedy_search(
            pres, model=model, architecture=config['architecture'],
            feat_mean=feat_mean, feat_std=feat_std,
            max_nodes_to_explore=config['max_nodes'], device=device,
            cyclically_reduce_after_moves=config.get('cyclically_reduce', False),
        )
    elif algorithm == 'beam':
        solved, path, stats = beam_search(
            pres, model=model, architecture=config['architecture'],
            feat_mean=feat_mean, feat_std=feat_std,
            beam_width=config.get('beam_width', 50),
            max_nodes_to_explore=config['max_nodes'], device=device,
        )
    elif algorithm == 'mcts':
        solved, path, stats = mcts_search(
            pres, model=model, architecture=config['architecture'],
            feat_mean=feat_mean, feat_std=feat_std,
            max_nodes_to_explore=config['max_nodes'],
            c_explore=config.get('c_explore', 1.41), device=device,
        )
    else:
        raise ValueError(f"Unknown algorithm: {algorithm}")

    elapsed = time.time() - t0
    result = {
        'idx': idx, 'solved': solved,
        'path_length': len(path) if solved else 0,
        'nodes_explored': stats.get('nodes_explored', 0),
        'time': elapsed,
    }
    if solved and path:
        # Cast to built-in ints so the result is JSON-serializable.
        result['path'] = [[int(a), int(l)] for a, l in path]
    return result


def _export_worker_module(fn, module_name='_colab_search_worker'):
    """Write ``fn``'s source to ``<cwd>/<module_name>.py`` and return it re-imported.

    With the 'spawn' start method, Pool sends the worker function to child processes by
    module and name. A function defined in a notebook cell belongs to ``__main__``, which
    a spawned interpreter cannot re-import, so the children fail to unpickle every task
    and the pool never yields a result. Importing the function from a real module file
    makes it resolvable in the children.
    """
    module_dir = os.getcwd()
    with open(os.path.join(module_dir, module_name + '.py'), 'w') as fh:
        fh.write(inspect.getsource(fn))
    # Spawned children inherit the parent's sys.path, so make the module findable there.
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)
    importlib.invalidate_caches()
    if module_name in sys.modules:
        # Cell was re-run: pick up the freshly written source.
        module = importlib.reload(sys.modules[module_name])
    else:
        module = importlib.import_module(module_name)
    return getattr(module, fn.__name__)


# Rebind the name so Pool pickles the module-level copy rather than the notebook one.
search_worker = _export_worker_module(search_worker)

print("Worker function defined")

Worker function defined


In [None]:
def run_parallel(presentations, algorithm, config, num_workers, greedy_solved_set,
                 progress_every=1):
    """Run one search algorithm over all presentations using a process pool.

    Args:
        presentations: sequence of presentation arrays; each is sent to a worker as a
            plain list via ``.tolist()``.
        algorithm: algorithm name passed through to ``search_worker``.
        config: settings dict passed through to ``search_worker``; ``config['max_nodes']``
            is also shown in the banner.
        num_workers: number of processes in the pool.
        greedy_solved_set: set of presentations (as tuples); a solved presentation that
            is not in this set is counted as "new".
        progress_every: print a progress line after every this-many completed
            presentations (and always after the last one). The default of 1 prints a
            line per result; values below 1 are treated as 1.

    Returns:
        list of result dicts from ``search_worker``, sorted by ``'idx'``.
    """
    n = len(presentations)
    progress_every = max(1, int(progress_every))

    def _beyond_greedy(result):
        """True when the result's presentation is not in greedy_solved_set."""
        return tuple(presentations[result['idx']]) not in greedy_solved_set

    print(f"\n{'='*60}")
    print(f"  {algorithm.upper()} | {config['max_nodes']:,} nodes | {num_workers} workers")
    print(f"{'='*60}")

    work_items = [
        (i, pres.tolist(), algorithm, config)
        for i, pres in enumerate(presentations)
    ]

    results = []
    solved_count = 0
    newly_solved_count = 0
    t_start = time.time()

    with Pool(processes=num_workers) as pool:
        # Results arrive in completion order; they are re-sorted by index below.
        for result in pool.imap_unordered(search_worker, work_items):
            results.append(result)
            if result['solved']:
                solved_count += 1
                if _beyond_greedy(result):
                    newly_solved_count += 1

            done = len(results)
            if done % progress_every == 0 or done == n:
                elapsed = time.time() - t_start
                # ETA extrapolates the average wall-clock time per finished presentation.
                rate = elapsed / done
                eta = rate * (n - done)
                eta_s = f"{eta/60:.1f}m" if eta < 3600 else f"{eta/3600:.1f}h"
                print(f"  {done}/{n} | solved={solved_count} | "
                      f"new={newly_solved_count} | ETA {eta_s}")

    total_time = time.time() - t_start
    results.sort(key=lambda r: r['idx'])

    solved_results = [r for r in results if r['solved']]
    path_lengths = [r['path_length'] for r in solved_results]

    print(f"\n  RESULT: {len(solved_results)}/{n} solved "
          f"({newly_solved_count} new beyond greedy) in {total_time/60:.1f}m")
    if path_lengths:
        print(f"  Avg path: {np.mean(path_lengths):.1f}, "
              f"Max: {max(path_lengths)}, Min: {min(path_lengths)}")

    # Show newly solved (at most the first 30, in index order)
    newly = [r for r in solved_results if _beyond_greedy(r)]
    if newly:
        print(f"\n  *** NEWLY SOLVED ({len(newly)} presentations) ***")
        for r in newly[:30]:
            pres = presentations[r['idx']]
            # Each half of the array holds one relator; non-zero entries are its letters.
            mrl = len(pres) // 2
            tl = int(np.count_nonzero(pres[:mrl]) + np.count_nonzero(pres[mrl:]))
            print(f"    idx={r['idx']:>4d}, path_len={r['path_length']:>3d}, "
                  f"word_len={tl}, nodes={r['nodes_explored']:>7d}")
        if len(newly) > 30:
            print(f"    ... and {len(newly)-30} more")

    return results

# Assemble the settings dict that is sent to every worker task.
arch = ARCHITECTURE
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
ckpt = f'value_search/checkpoints/best_{arch}.pt'
stats = 'value_search/checkpoints/feature_stats.json'

config = {
    'device': device, 'architecture': arch,
    'checkpoint': ckpt, 'feature_stats': stats,
    'max_nodes': MAX_NODES, 'beam_width': BEAM_WIDTH,
    'c_explore': C_EXPLORE, 'cyclically_reduce': CYCLICALLY_REDUCE,
}

# 'all' expands to the three algorithms; any other value is run on its own.
algos = ['v_guided', 'beam', 'mcts'] if ALGORITHM == 'all' else [ALGORITHM]

print(f"Device: {device}")
print(f"Algorithms to run: {algos}")
print("Starting...")

# Per-algorithm result lists, and the indices each algorithm solved.
all_results, all_solved_sets = {}, {}

for algo in algos:
    results = run_parallel(all_pres, algo, config, NUM_WORKERS, greedy_solved_set)
    all_results[algo] = results
    all_solved_sets[algo] = {r['idx'] for r in results if r['solved']}

print(f"\n{'='*60}")
print("  ALL DONE")
print(f"{'='*60}")

Device: cuda
Algorithms to run: ['v_guided', 'beam', 'mcts']
Starting...

  V_GUIDED | 1,000,000 nodes | 8 workers


## 4. Analysis

In [None]:
# Comparison table: one row per algorithm, plus the union across algorithms.
# The denominator comes from the loaded dataset instead of a hard-coded 1190, so the
# table stays correct if a different presentation file (or a subset) is used.
n_total = len(all_pres)

print(f"{'Algorithm':<20} | {'Solved':>10} | {'New':>5} | {'Avg Path':>10} | {'Max Path':>10}")
print(f"{'-'*20}-+-{'-'*10}-+-{'-'*5}-+-{'-'*10}-+-{'-'*10}")

for algo, results in all_results.items():
    solved = [r for r in results if r['solved']]
    # "New" = solved here but not present in greedy_solved_set.
    newly = [r for r in solved if tuple(all_pres[r['idx']]) not in greedy_solved_set]
    paths = [r['path_length'] for r in solved]
    avg_p = f"{np.mean(paths):.1f}" if paths else "—"
    max_p = f"{max(paths)}" if paths else "—"
    solved_frac = f"{len(solved)}/{n_total}"
    print(f"{algo:<20} | {solved_frac:>10} | {len(newly):>5} | {avg_p:>10} | {max_p:>10}")

# Union (only meaningful when more than one algorithm was run)
if len(all_solved_sets) > 1:
    union = set()
    for s in all_solved_sets.values():
        union |= s
    union_new = sum(1 for idx in union if tuple(all_pres[idx]) not in greedy_solved_set)
    union_frac = f"{len(union)}/{n_total}"
    print(f"{'UNION':<20} | {union_frac:>10} | {union_new:>5} |")

    # Uniquely solved by each: indices no other algorithm solved
    print(f"\nUniquely solved:")
    for algo, s in all_solved_sets.items():
        others = set().union(*(v for k, v in all_solved_sets.items() if k != algo))
        unique = s - others
        if unique:
            print(f"  {algo}: {len(unique)} unique")

In [None]:
# Save all results into a fresh, timestamped directory under experiments/results/.
def _json_default(obj):
    """Fallback for json.dump: convert NumPy scalars/arrays to plain Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
output_dir = f'experiments/results/{timestamp}_colab'
os.makedirs(output_dir, exist_ok=True)

# Save config. 'cyclically_reduce' is recorded too: it is part of the config passed to
# the workers, so the run cannot be reproduced from config.json without it.
with open(f'{output_dir}/config.json', 'w') as f:
    json.dump({
        'algorithms': algos,
        'max_nodes': MAX_NODES,
        'workers': NUM_WORKERS,
        'architecture': ARCHITECTURE,
        'beam_width': BEAM_WIDTH,
        'c_explore': C_EXPLORE,
        'cyclically_reduce': CYCLICALLY_REDUCE,
        'device': device,
    }, f, indent=2)

# Save per-algorithm results. `default` is only consulted for values json cannot encode
# natively, so a NumPy scalar in a result cannot abort the save after a long run.
for algo, results in all_results.items():
    with open(f'{output_dir}/{algo}_results.json', 'w') as f:
        json.dump(results, f, indent=2, default=_json_default)

# Save newly solved summary: per algorithm, the indices solved but not in greedy_solved_set
all_newly = {}
for algo, results in all_results.items():
    newly = [r for r in results if r['solved']
             and tuple(all_pres[r['idx']]) not in greedy_solved_set]
    all_newly[algo] = [r['idx'] for r in newly]

with open(f'{output_dir}/newly_solved.json', 'w') as f:
    json.dump(all_newly, f, indent=2, default=_json_default)

print(f"Results saved to: {output_dir}/")
print(f"Files:")
for fname in sorted(os.listdir(output_dir)):
    size = os.path.getsize(f'{output_dir}/{fname}')
    print(f"  {fname} ({size/1024:.0f} KB)")

## 5. Verify Newly Solved Paths

In [None]:
# Verify that all reported solutions are actually correct
def verify_solution(pres, path, cyclical=False):
    """Replay ``path`` from ``pres`` with ACMove and report whether the end state is trivial.

    Args:
        pres: the starting presentation; it is copied into a fresh np.int8 array.
        path: iterable of ``(action, expected_len)`` pairs. Only ``action`` is replayed;
            the second element is not checked.
        cyclical: value passed as ``ACMove(..., cyclical=...)`` for every move.
            Defaults to False.

    Returns:
        The result of ``is_presentation_trivial`` on the final state.
    """
    state = np.array(pres, dtype=np.int8)
    mrl = len(state) // 2
    # Word lengths = number of non-zero entries in each half of the array.
    wl = [int(np.count_nonzero(state[:mrl])), int(np.count_nonzero(state[mrl:]))]
    for action, _expected_len in path:
        state, wl = ACMove(action, state, mrl, wl, cyclical=cyclical)
    return is_presentation_trivial(state)

print("Verifying all solutions...")
errors = 0
verified = 0
for algo, results in all_results.items():
    for r in results:
        if r['solved'] and 'path' in r:
            path = [(a, l) for a, l in r['path']]
            # The v_guided search is run with cyclically_reduce=CYCLICALLY_REDUCE, whereas
            # the replay used to be hard-wired to cyclical=False. Presumably the search
            # applies its moves with the matching flag (TODO confirm in
            # value_guided_greedy_search), in which case a non-cyclical replay can follow a
            # different trajectory and report a false error. Try the original mode first,
            # then the cyclical one.
            # NOTE(review): accepting either mode assumes cyclic reduction is itself a
            # valid AC-equivalence (a conjugation) — confirm.
            ok = any(
                verify_solution(all_pres[r['idx']], path, cyclical=mode)
                for mode in (False, True)
            )
            if not ok:
                print(f"  ERROR: {algo} idx={r['idx']} path does NOT reach trivial!")
                errors += 1
            verified += 1

print(f"Verified {verified} solutions, {errors} errors")
if errors == 0:
    print("All solutions are verified correct! ✓")

## 6. Download Results (Optional)

In [None]:
# Zip results for download
import os
import shutil

# On Colab the archive goes to /content as before. Elsewhere /content normally does not
# exist, so fall back to the current directory; otherwise make_archive fails before the
# "Not on Colab" message below can ever be reached.
archive_dir = '/content' if os.path.isdir('/content') else os.getcwd()
zip_path = shutil.make_archive(
    os.path.join(archive_dir, f'ac_solver_results_{timestamp}'), 'zip', output_dir
)
print(f"Results zipped to: {zip_path}")

# If running on Colab, offer download
try:
    from google.colab import files
    files.download(zip_path)
except ImportError:
    print("Not on Colab — download manually from the file browser.")