# COLT 2022 MNIST Replication

This notebook replicates the COLT 2022 MNIST feature extractor comparison in the new `pyac` codebase, including the split assembly result at `m=10,000`.

Compared to the old `brain.py` + `MNIST_original.ipynb` flow, this notebook uses explicit RNG threading, modular extractors, and full MNIST (60k/10k) evaluation with logistic regression.

In [None]:
import json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression

from pyac.core.rng import make_rng
from pyac.tasks.mnist.extractors import (
    LinearExtractor,
    NonlinearExtractor,
    LargeAreaExtractor,
    RandomAssemblyExtractor,
    SplitAssemblyExtractor
)

In [None]:
# Load full MNIST dataset.
# OpenML is tried first; if that fails for any reason (network, parser
# requirements, ...), fall back to the Keras-hosted mnist.npz. Either way X
# ends up as float64 pixels scaled to [0, 1] with 784 columns, and y as int64
# labels. Rows [:60000] become the training split and rows [60000:] the test
# split.
try:
    X, y = fetch_openml('mnist_784', version=1, parser='auto', as_frame=False, return_X_y=True)
    X = np.asarray(X, dtype=np.float64) / 255.0  # Normalize to [0,1]
    y = np.asarray(y, dtype=np.int64)
except Exception as exc:
    # Fallback when OpenML endpoint or parser requirements are unavailable.
    # Report the reason so a silently switched data source is visible in the
    # notebook output instead of being swallowed.
    print(f"OpenML fetch failed ({type(exc).__name__}: {exc}); falling back to Keras mnist.npz")
    import io
    import urllib.request

    url = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz'
    with urllib.request.urlopen(url, timeout=60) as response:
        payload = response.read()
    with np.load(io.BytesIO(payload)) as data:
        X_train_raw = data['x_train']
        y_train_raw = data['y_train']
        X_test_raw = data['x_test']
        y_test_raw = data['y_test']

    # Train images first, then test images, so the [:60000]/[60000:] slicing
    # below recovers the original split. `-1` keeps the row count data-driven.
    X = np.concatenate([X_train_raw, X_test_raw], axis=0).reshape(-1, 784).astype(np.float64) / 255.0
    y = np.concatenate([y_train_raw, y_test_raw], axis=0).astype(np.int64)

# Split train/test
X_train = X[:60000]
y_train = y[:60000]
X_test = X[60000:]
y_test = y[60000:]

print(f"Training samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")

In [None]:
# Feature counts (m) to sweep, and how many seeds to average per configuration.
m_values = [100, 500, 1000, 5000, 10000]
n_seeds = 3  # Use 3 seeds for faster execution (paper used 5)

# Hyperparameters shared by every assembly-based extractor below.
assembly_kwargs = {'beta': 1.0, 't_internal': 5, 'n_examples_per_class': 5}


def _random_assembly_or_none(m):
    """Build a RandomAssemblyExtractor for ``m`` features.

    Returns None when ``m`` is not a multiple of 100; the run loop treats a
    None factory result as "skip this configuration".
    """
    if m % 100 != 0:
        return None
    return RandomAssemblyExtractor(m=m, **assembly_kwargs)


# Name -> factory taking the feature count m and returning an (unfitted) extractor.
extractors = {
    'linear': lambda m: LinearExtractor(m=m),
    'nonlinear': lambda m: NonlinearExtractor(m=m),
    'large_area': lambda m: LargeAreaExtractor(m=m, **assembly_kwargs),
    'random_assembly': _random_assembly_or_none,
    'split_assembly': lambda m: SplitAssemblyExtractor(m=m, **assembly_kwargs),
}

In [None]:
# One dict per (extractor, m, seed) run is appended here by the training loop.
results = []
# The extractors consume a Python list of per-image rows (views into X_*).
images_train = list(X_train)
images_test = list(X_test)

In [None]:
# For every extractor and feature count m: fit on the training images, transform
# train and test images, train a logistic-regression readout and record its
# test accuracy. This is repeated for each seed.
for extractor_name, extractor_fn in extractors.items():
    print(f"\n=== {extractor_name.upper()} ===")
    for m in m_values:
        # Factories return None for unsupported m (random_assembly when m is not
        # a multiple of 100). Probe once and skip such configurations.
        probe = extractor_fn(m)
        if probe is None:
            print(f"  m={m:5d}: SKIPPED (not multiple of 100)")
            continue

        print(f"  m={m:5d}: ", end='', flush=True)

        for seed in range(n_seeds):
            # Use a freshly constructed extractor for every seed so no fitted
            # state from a previous seed can leak into this replicate. The probe
            # instance is reused only for the first seed.
            ext = probe if seed == 0 else extractor_fn(m)
            rng = make_rng(seed)

            # Fit extractor
            ext.fit(images_train, y_train, rng)

            # Transform (the same rng is threaded through both calls)
            features_train = ext.transform(images_train, rng)
            features_test = ext.transform(images_test, rng)

            # Train LogisticRegression
            clf = LogisticRegression(max_iter=1000, random_state=seed)
            clf.fit(features_train, y_train)
            accuracy = clf.score(features_test, y_test)

            results.append({
                'extractor': extractor_name,
                'n_features': m,
                'seed': seed,
                'accuracy': accuracy
            })

            print(f"{accuracy:.1%} ", end='', flush=True)
        print()

In [None]:
# Placeholders; `agg` is filled in by the aggregation cell (df is unused).
df, agg = None, None

In [None]:
# Resolve repository root for stable artifact paths
repo_root = Path.cwd()
while not (repo_root / '.git').exists() and repo_root != repo_root.parent:
    repo_root = repo_root.parent
output_dir = repo_root / 'pyac' / 'notebooks' / 'replication'
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
run_count = len(results)  # one entry per (extractor, m, seed) that was not skipped
print(f"Completed {run_count} runs across extractors and feature counts")

In [None]:
# Compute mean accuracy per extractor per m
grouped = {}
for row in results:
    key = (row['extractor'], row['n_features'])
    grouped.setdefault(key, []).append(row['accuracy'])

agg = [
    {'extractor': key[0], 'n_features': key[1], 'accuracy': float(np.mean(values))}
    for key, values in grouped.items()
]

In [None]:
# Plot mean test accuracy (averaged over seeds in `agg`) against feature count,
# one line per extractor, then save the figure next to the other artifacts.
plt.figure(figsize=(12, 7))
# Iterate the configured extractors so the plot always matches what was run.
for extractor_name in extractors:
    subset = sorted(
        (row for row in agg if row['extractor'] == extractor_name),
        key=lambda row: row['n_features'],
    )
    if subset:
        # Avoid the names x/y: `y` holds the MNIST label array loaded earlier in
        # this notebook and must not be overwritten by plotting data.
        feature_counts = [row['n_features'] for row in subset]
        mean_accuracies = [row['accuracy'] for row in subset]
        plt.plot(feature_counts, mean_accuracies, 'o-', label=extractor_name, linewidth=2, markersize=8)

plt.xlabel('Number of Features (m)', fontsize=12)
plt.ylabel('Test Accuracy', fontsize=12)
plt.title('COLT 2022 Replication: MNIST Accuracy vs Feature Count', fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)
plt.xscale('log')

# Save plot
plot_path = output_dir / 'accuracy_vs_features.png'
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved plot: {plot_path}")

## Results

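Split assembly at m=10,000: XX.X% (placeholder; fill in with the seed-averaged accuracy after running the notebook)
- Target: ≥ 90% (the paper reports ~96%)
- Status: PASS/FAIL (PASS if the seed-averaged accuracy meets the target)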

Top 3 extractors at m=10,000 (expected ranking; confirm against the saved `metrics.json`):
1. split_assembly: XX.X%
2. random_assembly: XX.X%
3. large_area: XX.X%

## Comparison with Old Codebase

**Old brain.py MNIST results** (from MNIST_original.ipynb):
- Used dense numpy arrays
- RefractedArea with negative bias
- Achieved ~68% accuracy in demo

**New pyac results**:
- Sparse CSR/CSC matrices (G1)
- RefractedStrategy with beta plasticity
- Split assembly reaches ~XX.X% at m=10,000 features (placeholder; fill in after running)

**Key improvements**:
- Sparse matrices, expected to give roughly 10x–100x memory savings over dense arrays (not benchmarked in this notebook)
- Explicit RNG threading for reproducibility
- Pluggable feature extractors (5 variants)
- Full MNIST (60k train) vs subset in old demo

In [None]:
# Persist two JSON artifacts into output_dir:
#   metrics.json            -- every per-seed run record plus the sweep parameters
#   replication_config.json -- the experiment configuration used for this sweep
run_parameters = dict(
    m_values=m_values,
    n_seeds=n_seeds,
    n_train=len(X_train),
    n_test=len(X_test),
)
metrics = {'results': results, 'parameters': run_parameters}

metrics_path = output_dir / 'metrics.json'
with open(metrics_path, 'w', encoding='utf-8') as fh:
    json.dump(metrics, fh, indent=2)
print(f"Saved metrics.json: {metrics_path}")

config = dict(
    extractors=list(extractors),
    feature_counts=m_values,
    seeds=[*range(n_seeds)],
    dataset='mnist_784',
    classifier='LogisticRegression',
    max_iter=1000,
)

config_path = output_dir / 'replication_config.json'
with open(config_path, 'w', encoding='utf-8') as fh:
    json.dump(config, fh, indent=2)
print(f"Saved replication_config.json: {config_path}")