# Random Model Search: Generalization Distribution

**Goal**: Measure how generalization is distributed across randomly initialized (untrained) CNNs.

**Key question**: Among models that reach at least 20% train accuracy (twice the 10% chance level), how many also beat chance on the held-out validation set?

Runs indefinitely until stopped. Saves periodically (overwriting old saves).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
import time
import math
from pathlib import Path
from collections import deque
from dataclasses import dataclass, field
import matplotlib.pyplot as plt

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    # Report which accelerator was picked up.
    print(f"GPU: {torch.cuda.get_device_name()}")

## Configuration

In [None]:
CONFIG = {
    # Data splits: disjoint slices of one seeded permutation of the MNIST training set
    'train_size': 256,
    'val_size': 5000,
    'test_size': 5000,
    'data_seed': 42,
    
    # Search parameters
    # batch_size: models drawn per outer-loop pass. Each model is evaluated on its
    # own, so this does not affect memory; it only sets how often the log/save
    # checks run (they are checked once per pass).
    'batch_size': 256,
    'screen_threshold': 0.20,  # Models with train accuracy below this count as failed and are never scored on val
    
    # Weight initialization: std of the zero-mean normal used to resample weights / biases
    'weight_scale': 1.5,
    'bias_scale': 0.5,
    
    # Saving
    'save_every': 50000,  # Models between checkpoints (checked once per batch)
    'log_every': 10000,  # Print progress every N models
    'max_best_weights': 25,  # Keep weights of only the most recent N new-best models
    'results_file': 'results/search_results.json',
    'weights_file': 'results/best_weights.pt',
}

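## Expected Best Calculation

Chance-level baseline for the best of $n$ random models: treat each model's accuracy on $N$ samples as approximately $\mathcal{N}(p, \sigma^2)$ with $\sigma = \sqrt{p(1-p)/N}$ (here $p = 0.1$), and estimate the expected maximum with the Gaussian extreme-value approximation $p + \sigma\sqrt{2\ln n}$, capped at 1.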

In [None]:
def expected_best_random(n_samples: int, n_trials: int, p: float = 0.1) -> float:
    """Approximate the best accuracy expected from ``n_trials`` chance-level models.

    Each model's accuracy on ``n_samples`` examples is treated as normal with
    mean ``p`` and standard deviation ``sqrt(p * (1 - p) / n_samples)``; the
    expected maximum of ``n_trials`` such draws is estimated as
    ``p + std * sqrt(2 * ln(n_trials))`` and capped at 1.0.

    For ``n_trials <= 1`` there is nothing to maximize over, so ``p`` is
    returned unchanged.
    """
    if n_trials > 1:
        spread = math.sqrt(p * (1 - p) / n_samples)
        boost = math.sqrt(2 * math.log(n_trials))
        return min(p + spread * boost, 1.0)
    return p

## Load Data

In [None]:
def load_mnist(device):
    """Load the MNIST training set as 1x28x28 float images in [0, 1].

    Args:
        device: torch device the returned tensors are moved to.

    Returns:
        images: float32 tensor of shape (N, 1, 28, 28) on ``device``.
        labels: integer class-label tensor of shape (N,) on ``device``
            (torchvision stores MNIST targets as a LongTensor).
    """
    from torchvision import datasets

    train_data = datasets.MNIST(root='./data', train=True, download=True)

    # Use the raw uint8 tensors directly instead of iterating the dataset,
    # which would decode every image through PIL + ToTensor one at a time.
    # Converting to float and dividing by 255 yields exactly what
    # transforms.ToTensor() produces for these 8-bit grayscale images.
    images = train_data.data.unsqueeze(1).float().div(255).to(device)  # (N, 1, 28, 28)
    labels = train_data.targets.to(device)

    return images, labels


def create_fixed_splits(images, labels, train_size, val_size, test_size, seed):
    """Create fixed, non-overlapping train/val/test splits.

    A single permutation of all indices, drawn from ``np.random.RandomState(seed)``,
    is cut into three consecutive disjoint slices, so the same ``seed`` always
    reproduces the same splits.

    Args:
        images, labels: index-aligned collections of equal length (tensors here).
        train_size, val_size, test_size: number of examples in each split.
        seed: seed for the permutation.

    Returns:
        ``((train_images, train_labels), (val_images, val_labels),
        (test_images, test_labels))``.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length, any size is
            negative, or the three sizes together exceed the number of examples
            (which would otherwise silently shrink the later splits).
    """
    n_total = len(images)
    if len(labels) != n_total:
        raise ValueError(
            f"images and labels differ in length ({n_total} vs {len(labels)})"
        )
    sizes = (train_size, val_size, test_size)
    if min(sizes) < 0:
        raise ValueError(f"split sizes must be non-negative, got {sizes}")
    if sum(sizes) > n_total:
        raise ValueError(
            f"requested {sum(sizes)} examples across splits but only {n_total} available"
        )

    rng = np.random.RandomState(seed)
    indices = rng.permutation(n_total)

    train_end = train_size
    val_end = train_end + val_size
    test_end = val_end + test_size

    train_idx = indices[:train_end]
    val_idx = indices[train_end:val_end]
    test_idx = indices[val_end:test_end]

    return (
        (images[train_idx], labels[train_idx]),
        (images[val_idx], labels[val_idx]),
        (images[test_idx], labels[test_idx]),
    )

In [None]:
print("Loading MNIST...")
all_images, all_labels = load_mnist(device)
print(f"Loaded {len(all_images)} images")

# Carve the fixed, seeded train/val/test partitions out of the full set.
(
    (train_images, train_labels),
    (val_images, val_labels),
    (test_images, test_labels),
) = create_fixed_splits(
    all_images,
    all_labels,
    train_size=CONFIG['train_size'],
    val_size=CONFIG['val_size'],
    test_size=CONFIG['test_size'],
    seed=CONFIG['data_seed'],
)

print(f"Train: {len(train_images)}, Val: {len(val_images)}, Test: {len(test_images)}")

## Small CNN Model

Tiny CNN for fast random search:
- Conv1: 1 → 4 channels, 5x5, stride 2 → 12x12
- Conv2: 4 → 8 channels, 5x5, stride 2 → 4x4  
- FC: 128 → 10

In [None]:
class TinyCNN(nn.Module):
    """Two strided 5x5 convolutions with ReLU, followed by a linear classifier.

    For a (B, 1, 28, 28) input: conv1 -> (B, 4, 12, 12), conv2 -> (B, 8, 4, 4),
    flattened to 128 features -> (B, 10) logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept so the parameter names
        # (conv1.weight, ...) and default-init RNG usage stay the same.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=5, stride=2)
        self.fc = nn.Linear(in_features=8 * 4 * 4, out_features=10)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        flat = hidden.view(hidden.shape[0], -1)
        return self.fc(flat)


def randomize_model(model, weight_scale, bias_scale):
    """Resample every parameter of ``model`` in place from a zero-mean normal.

    Parameters whose name contains ``'bias'`` use standard deviation
    ``bias_scale``; all others use ``weight_scale``. Parameters are visited in
    ``named_parameters()`` order.
    """
    with torch.no_grad():
        for pname, tensor in model.named_parameters():
            std = bias_scale if 'bias' in pname else weight_scale
            tensor.normal_(0, std)


def get_model_params(model):
    """Snapshot model parameters as a dict of independent CPU tensors.

    Keys are the names from ``model.named_parameters()``. Each value is
    detached before copying, so the snapshot carries no autograd history
    (``requires_grad`` is False) and is unaffected by later in-place updates
    to the model.
    """
    return {
        name: param.detach().cpu().clone()
        for name, param in model.named_parameters()
    }


def evaluate_model(model, images, labels):
    """Return classification accuracy as a Python float.

    The prediction for each example is the argmax of ``model(images)`` along
    dim 1; accuracy is the fraction of predictions equal to ``labels``.
    Runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        predicted = torch.argmax(model(images), dim=1)
        correct = predicted.eq(labels)
        return correct.float().mean().item()

In [None]:
# Test model
# Sanity check: model size and the train accuracy of a single random draw.
test_model = TinyCNN().to(device)
print(f"Model parameters: {sum(t.numel() for t in test_model.parameters()):,}")

randomize_model(test_model, weight_scale=CONFIG['weight_scale'], bias_scale=CONFIG['bias_scale'])
test_acc = evaluate_model(test_model, images=train_images, labels=train_labels)
print(f"Random model train accuracy: {test_acc:.4f}")

## Results Tracking

In [None]:
@dataclass
class SearchResults:
    """Accumulated state of a random model search, persisted as JSON.

    Fields:
        config: the config dict the search was started with.
        n_failed_screen: models whose train accuracy fell below the screen threshold.
        passing_results: one ``[train_acc, val_acc]`` pair per model that passed.
        best_results: one ``[train, val, test, exp_best, model_idx]`` row per new best.
        best_val_acc, best_model_idx: current best val accuracy and its model index.
        total_evaluated: number of random models drawn so far.
        elapsed_seconds: search wall-clock time accumulated across sessions.
    """
    config: dict = field(default_factory=dict)
    n_failed_screen: int = 0
    passing_results: list = field(default_factory=list)  # [train_acc, val_acc]
    best_results: list = field(default_factory=list)  # [train, val, test, exp_best, idx]
    best_val_acc: float = 0.0
    best_model_idx: int = -1
    total_evaluated: int = 0
    elapsed_seconds: float = 0.0

    def to_dict(self):
        """Return a JSON-serializable dict of all fields."""
        return {
            'config': self.config,
            'n_failed_screen': self.n_failed_screen,
            'passing_results': self.passing_results,
            'best_results': self.best_results,
            'best_val_acc': self.best_val_acc,
            'best_model_idx': self.best_model_idx,
            'total_evaluated': self.total_evaluated,
            'elapsed_seconds': self.elapsed_seconds,
        }

    @classmethod
    def from_dict(cls, d):
        """Build an instance from ``to_dict`` output; missing keys take defaults."""
        return cls(
            config=d.get('config', {}),
            n_failed_screen=d.get('n_failed_screen', 0),
            passing_results=d.get('passing_results', []),
            best_results=d.get('best_results', []),
            best_val_acc=d.get('best_val_acc', 0.0),
            best_model_idx=d.get('best_model_idx', -1),
            total_evaluated=d.get('total_evaluated', 0),
            elapsed_seconds=d.get('elapsed_seconds', 0.0),
        )

    def save(self, filepath):
        """Write the results as JSON to ``filepath`` (str or Path), atomically.

        The JSON goes to ``<filepath>.tmp`` first, is flushed and fsynced, and
        is then moved over ``filepath`` with ``Path.replace``. That call
        overwrites an existing file atomically on POSIX and Windows alike, so
        readers see either the old file or the new one, never a partial write.
        Missing parent directories are created.
        """
        import os  # needed only for fsync

        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        temp_path = path.with_name(path.name + '.tmp')
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f)
            f.flush()
            os.fsync(f.fileno())
        temp_path.replace(path)

    @classmethod
    def load(cls, filepath):
        """Read a JSON file written by ``save`` and rebuild the instance."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return cls.from_dict(json.load(f))

## Main Search Loop (Runs Indefinitely)

In [None]:
def _atomic_torch_save(obj, filepath):
    """Save ``obj`` with ``torch.save`` via a temp file and an atomic replace.

    Writing to ``<file>.tmp`` and then ``Path.replace``-ing it over the target
    means an interrupted write never leaves a truncated file behind.
    """
    path = Path(filepath)
    path.parent.mkdir(parents=True, exist_ok=True)
    temp_path = path.with_name(path.name + '.tmp')
    torch.save(obj, temp_path)
    temp_path.replace(path)


def _save_checkpoint(results, best_weights, results_file, weights_file):
    """Persist the results JSON and the retained best-model weights."""
    results.save(results_file)
    _atomic_torch_save(list(best_weights), weights_file)


def run_search(
    train_data: tuple,
    val_data: tuple,
    test_data: tuple,
    config: dict,
    resume_from: str = None,
):
    """
    Run random model search indefinitely.

    Each pass of the outer loop draws ``config['batch_size']`` fresh random
    weight sets for a single reused ``TinyCNN``, one after another. A model
    whose train accuracy is below ``config['screen_threshold']`` only
    increments ``n_failed_screen``. Otherwise its ``[train_acc, val_acc]``
    pair is recorded. Whenever val accuracy beats the best so far, the test
    accuracy and the chance-level expected best are computed and the weights
    are stored, keeping at most ``config['max_best_weights']`` of the most
    recent bests.

    Progress is printed every ``log_every`` models and a checkpoint is written
    every ``save_every`` models (both checked after each pass). Stop with
    Ctrl+C or a runtime disconnect; a final checkpoint is written on the way
    out, including when another exception propagates.

    Args:
        train_data, val_data, test_data: ``(images, labels)`` tuples.
        config: search, initialization and file settings (see CONFIG).
        resume_from: path of a results JSON to continue from. Saved best
            weights are only restored together with those results, so a fresh
            search never inherits weights from an earlier run.

    Returns:
        The accumulated ``SearchResults``.
    """
    train_images, train_labels = train_data
    val_images, val_labels = val_data
    test_images, test_labels = test_data

    # Initialize or load
    results_file = config['results_file']
    weights_file = config['weights_file']

    resumed = bool(resume_from) and Path(resume_from).exists()
    if resumed:
        print(f"Resuming from {resume_from}")
        results = SearchResults.load(resume_from)
    else:
        results = SearchResults(config=config)

    # Saved weights belong to the run recorded in the results file, so they are
    # only restored alongside those results.
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
    if resumed and Path(weights_file).exists():
        best_weights_list = torch.load(weights_file, weights_only=False)
        best_weights = deque(best_weights_list, maxlen=config['max_best_weights'])
        print(f"Loaded {len(best_weights)} best model weights")
    else:
        best_weights = deque(maxlen=config['max_best_weights'])

    batch_size = config['batch_size']
    save_every = config['save_every']
    log_every = config['log_every']
    threshold = config['screen_threshold']
    weight_scale = config['weight_scale']
    bias_scale = config['bias_scale']

    # One model instance is reused; its parameters are resampled in place.
    model = TinyCNN().to(device)

    start_time = time.time()  # start of the current (not yet banked) timing interval
    last_save_count = results.total_evaluated
    last_log_count = results.total_evaluated

    print(f"\nStarting search (runs indefinitely, Ctrl+C to stop)")
    print(f"Total evaluated so far: {results.total_evaluated:,}")
    print(f"Passing models so far: {len(results.passing_results):,}")
    print(f"Current best val acc: {results.best_val_acc:.4f}")
    print(f"{'='*60}\n")

    try:
        while True:
            # Evaluate batch of random models
            for _ in range(batch_size):
                randomize_model(model, weight_scale, bias_scale)

                # Screen on train
                train_acc = evaluate_model(model, train_images, train_labels)

                if train_acc < threshold:
                    results.n_failed_screen += 1
                else:
                    # Passed screening - evaluate on val
                    val_acc = evaluate_model(model, val_images, val_labels)
                    results.passing_results.append([train_acc, val_acc])

                    # Check if new best
                    if val_acc > results.best_val_acc:
                        test_acc = evaluate_model(model, test_images, test_labels)
                        # n_trials counts every model drawn so far, including this one.
                        exp_best = expected_best_random(len(val_images), results.total_evaluated + 1)

                        results.best_val_acc = val_acc
                        results.best_model_idx = results.total_evaluated
                        results.best_results.append([
                            train_acc, val_acc, test_acc, exp_best, results.total_evaluated
                        ])

                        # Save weights
                        best_weights.append({
                            'params': get_model_params(model),
                            'model_idx': results.total_evaluated,
                            'val_acc': val_acc,
                        })

                        # Log new best
                        print(f"[NEW BEST #{len(results.best_results)}] "
                              f"model {results.total_evaluated:,} | "
                              f"train: {train_acc:.3f} | "
                              f"val: {val_acc:.3f} | "
                              f"test: {test_acc:.3f} | "
                              f"exp_best: {exp_best:.3f}")

                results.total_evaluated += 1

            # Periodic logging
            if results.total_evaluated - last_log_count >= log_every:
                elapsed = time.time() - start_time + results.elapsed_seconds
                rate = results.total_evaluated / elapsed if elapsed > 0 else 0
                pass_rate = len(results.passing_results) / results.total_evaluated
                print(f"[PROGRESS] {results.total_evaluated:,} models | "
                      f"{rate:,.0f}/sec | "
                      f"passing: {len(results.passing_results):,} ({pass_rate:.3%}) | "
                      f"best: {results.best_val_acc:.3f}")
                last_log_count = results.total_evaluated

            # Periodic save: bank the elapsed interval, then start a new one.
            if results.total_evaluated - last_save_count >= save_every:
                now = time.time()
                results.elapsed_seconds += now - start_time
                start_time = now
                _save_checkpoint(results, best_weights, results_file, weights_file)
                print(f"[SAVED] {results_file}")
                last_save_count = results.total_evaluated

    except KeyboardInterrupt:
        print(f"\n\n{'='*60}")
        print("Search interrupted. Saving final state...")

    finally:
        # Final save (atomic, same as the periodic checkpoints)
        results.elapsed_seconds += time.time() - start_time
        _save_checkpoint(results, best_weights, results_file, weights_file)
        print(f"Saved to {results_file}")

        print(f"\n{'='*60}")
        print(f"FINAL RESULTS")
        print(f"{'='*60}")
        print(f"Total evaluated: {results.total_evaluated:,}")
        print(f"Passed screening: {len(results.passing_results):,} ({len(results.passing_results)/max(1,results.total_evaluated):.3%})")
        print(f"Best val accuracy: {results.best_val_acc:.4f}")
        print(f"Best models found: {len(results.best_results)}")

    return results

## Run Search

In [None]:
# Launch (or resume from the saved results file) the indefinite search.
results = run_search(
    train_data=(train_images, train_labels),
    val_data=(val_images, val_labels),
    test_data=(test_images, test_labels),
    config=CONFIG,
    resume_from=CONFIG['results_file'],
)

## Analyze Results

In [None]:
def plot_search_results(results: SearchResults):
    """Plot a 2x2 summary of a search and print headline statistics.

    Panels: (1) histogram of the train - val gap, (2) train vs val scatter
    with new-best models highlighted, (3) val/test accuracy of successive new
    bests against the chance-level expected-best curve, (4) histogram of val
    accuracy. If no model passed screening, only a message is printed.
    """
    if not results.passing_results:
        print("No passing models to analyze")
        return

    passing = np.array(results.passing_results)
    train_acc = passing[:, 0]
    val_acc = passing[:, 1]
    # Rows: [train, val, test, exp_best, model_idx]
    best = np.array(results.best_results) if results.best_results else None

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # 1. Generalization gap
    ax = axes[0, 0]
    gap = train_acc - val_acc
    ax.hist(gap, bins=50, alpha=0.7, edgecolor='black')
    ax.axvline(0, color='red', linestyle='--', label='No gap')
    ax.axvline(gap.mean(), color='green', linestyle='--', label=f'Mean: {gap.mean():.3f}')
    ax.set_xlabel('Generalization Gap (Train - Val)')
    ax.set_ylabel('Count')
    ax.set_title(f'Generalization Gap (n={len(gap):,})')
    ax.legend()

    # 2. Train vs Val scatter
    ax = axes[0, 1]
    ax.scatter(train_acc, val_acc, alpha=0.3, s=5)
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
    if best is not None:
        ax.scatter(best[:, 0], best[:, 1], color='red', s=30,
                   label='Best models', zorder=5, edgecolor='black')
        # Only labelled artist on this panel; skip the legend when absent.
        ax.legend()
    ax.set_xlabel('Train Accuracy')
    ax.set_ylabel('Val Accuracy')
    ax.set_title('Train vs Val Accuracy')

    # 3. Best model progress
    ax = axes[1, 0]
    if best is not None:
        # model_idx is 0-based; +1 gives the number of models evaluated so far
        # (the n_trials used for exp_best) and keeps index 0 visible on a log axis.
        n_evaluated = best[:, 4] + 1
        ax.plot(n_evaluated, best[:, 1], 'b-o', markersize=4, label='Val Acc')
        ax.plot(n_evaluated, best[:, 2], 'g-o', markersize=4, label='Test Acc')
        ax.plot(n_evaluated, best[:, 3], 'r--', linewidth=2, label='Expected Best (random)')
        ax.axhline(0.1, color='gray', linestyle=':', alpha=0.5)
        ax.set_xlabel('Models Evaluated')
        ax.set_ylabel('Accuracy')
        ax.set_title('Best Model Progress vs Random Baseline')
        ax.legend()
        ax.set_xscale('log')

    # 4. Val accuracy distribution
    ax = axes[1, 1]
    ax.hist(val_acc, bins=50, alpha=0.7, edgecolor='black')
    ax.axvline(0.1, color='gray', linestyle=':', label='Chance (10%)')
    ax.axvline(val_acc.mean(), color='green', linestyle='--', label=f'Mean: {val_acc.mean():.3f}')
    ax.set_xlabel('Validation Accuracy')
    ax.set_ylabel('Count')
    ax.set_title('Val Accuracy Distribution')
    ax.legend()

    plt.tight_layout()
    plt.show()

    # Stats
    print(f"\nTotal: {results.total_evaluated:,} | Passing: {len(results.passing_results):,}")
    print(f"Val acc: mean={val_acc.mean():.4f}, max={val_acc.max():.4f}")
    print(f"Generalizing (val>10%): {(val_acc > 0.1).sum():,} ({(val_acc > 0.1).mean():.2%})")
    if best is not None:
        print(f"Best: train={best[-1,0]:.3f}, val={best[-1,1]:.3f}, test={best[-1,2]:.3f}")

In [None]:
# Reload the latest checkpoint from disk and chart it; this cell can be run
# on its own once the search has been stopped.
saved_results_path = Path(CONFIG['results_file'])
if saved_results_path.exists():
    results = SearchResults.load(saved_results_path)
    plot_search_results(results)

## Download

In [None]:
# Download the results JSON through Colab's browser bridge. Outside Colab
# (google.colab not importable) report where the file lives instead of
# failing the cell.
try:
    from google.colab import files
except ImportError:
    files = None

if files is None:
    print(f"google.colab is unavailable; results are saved at {CONFIG['results_file']}")
else:
    files.download(CONFIG['results_file'])