# Classification Sample Size Experiments

This notebook runs training-set sample size sweeps on a 2D, 3-class Gaussian-blob classification task and compares the IT and GL uncertainty decompositions across models.

Models tested:
- MC Dropout (IT and GL)
- Deep Ensemble (IT and GL)
- BNN (IT and GL)

## Packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import sys
from pathlib import Path

# Make the project root importable so that `utils` resolves. When the notebook
# is launched from the `Experiments/` folder, the project root is its parent.
project_root = Path.cwd().parent if Path.cwd().name == 'Experiments' else Path.cwd()
# Only prepend when the root is not already first on sys.path, so re-running this
# cell does not keep growing sys.path with duplicate entries.
if sys.path[:1] != [str(project_root)]:
    sys.path.insert(0, str(project_root))

# Setup results directory (plus plots/, statistics/, outputs/ subfolders)
results_dir = project_root / "results" / "classification" / "sample_size"
results_dir.mkdir(parents=True, exist_ok=True)
plots_dir = results_dir / "plots"
plots_dir.mkdir(exist_ok=True)
stats_dir = results_dir / "statistics"
stats_dir.mkdir(exist_ok=True)
outputs_dir = results_dir / "outputs"
outputs_dir.mkdir(exist_ok=True)

print(f"Results will be saved to: {results_dir}")

# Project imports must come after the sys.path tweak above, which is why they
# are not grouped with the third-party imports at the top of the cell.
from utils.classification_data import simulate_dataset
from utils.classification_experiments import (
    run_mc_dropout_it_sample_size_experiment,
    run_mc_dropout_gl_sample_size_experiment,
    run_deep_ensemble_it_sample_size_experiment,
    run_deep_ensemble_gl_sample_size_experiment,
    run_bnn_it_sample_size_experiment,
    run_bnn_gl_sample_size_experiment,
)
from utils.device import get_device
import utils.results_save as results_save_module

# Point the results_save module at this experiment's directories by overriding
# its module-level attributes.
# NOTE(review): presumably the save helpers read these globals at call time —
# confirm in utils/results_save.py.
results_save_module.plots_dir = plots_dir
results_save_module.stats_dir = stats_dir
results_save_module.outputs_dir = outputs_dir

## Device Setup

In [None]:
# Select the compute device through the project's helper and report it.
# NOTE(review): `device` is not passed to any of the experiment runners called
# in this notebook — confirm the runners pick the device themselves.
device = get_device()
print("Using device: {}".format(device))

## Data Generation Setup

Configure the base dataset (Gaussian blobs with 3 classes) and the model hyperparameters shared by all experiments, then plot the baseline training data.

In [None]:
# Reproducibility: seed NumPy's and PyTorch's global RNGs. simulate_dataset
# additionally receives the seed through base_cfg["seed"].
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Base configuration for Gaussian blob classification
base_cfg = {
    "N_train": 1000,  # Will be varied in sample size sweep
    "N_test": 500,
    "num_classes": 3,
    "blob_sigma": 0.25,  # Controls blob overlap
    "tau": 0.2,
    "eta": 0.0,  # No label noise for sample size experiment
    "sigma_in": 0.0,  # No input noise
    "seed": seed,
    # Model hyperparameters (shared)
    "input_dim": 2,
    "epochs": 300,
    "batch_size": 32,
    "lr": 1e-3,
    "dropout_p": 0.25,
    "mc_samples": 50,
    "gl_samples": 100,
    "K": 5,  # Ensemble size
    "hidden_width": 32,
    "weight_scale": 1.0,
    # NOTE(review): presumably BNN sampler settings (warmup/draws/chains) — confirm
    "warmup": 200,
    "samples": 200,
    "chains": 1,
}

# Visualize baseline data. The class count comes from the config rather than a
# hard-coded 3, so changing "num_classes" keeps the counts and plot consistent.
X_train, y_train, X_test, y_test, meta = simulate_dataset(base_cfg)
n_classes = base_cfg["num_classes"]
print(f"Train: {len(X_train)}, Test: {len(X_test)}")
print(f"Class counts (train): {np.bincount(y_train, minlength=n_classes)}")

# tab10 palette; the first three entries match the original blue/orange/green.
# Indexing modulo its length avoids IndexError if num_classes exceeds 10.
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
          'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

fig, ax = plt.subplots(figsize=(6, 6))
for c in range(n_classes):
    mask = y_train == c
    ax.scatter(X_train[mask, 0], X_train[mask, 1], c=colors[c % len(colors)],
               alpha=0.5, label=f'Class {c}', s=10)
ax.scatter(meta['centers'][:, 0], meta['centers'][:, 1], c='black', marker='*', s=200, label='Centers')
ax.legend()
ax.set_title('Baseline Training Data (Gaussian Blobs)')
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.axis('equal')
ax.grid(True, alpha=0.3)
plt.show()

## Set Parameters

In [None]:
# Training-set sizes swept by the MC Dropout and Deep Ensemble experiments.
sample_sizes = [100, 200, 500, 1000]

# BNNs are expensive to fit, so their sweep only covers the two extreme sizes.
sample_sizes_bnn = [100, 1000]

print(f"Sample sizes: {sample_sizes}")
print(f"Sample sizes (BNN): {sample_sizes_bnn}")

## MC Dropout IT

In [None]:
# MC Dropout with the IT decomposition, swept over every size in `sample_sizes`.
results_mc_dropout_it = run_mc_dropout_it_sample_size_experiment(
    seed=seed,
    sample_sizes=sample_sizes,
    base_cfg=base_cfg,
)

## MC Dropout GL

In [None]:
# MC Dropout with the GL decomposition, swept over every size in `sample_sizes`.
results_mc_dropout_gl = run_mc_dropout_gl_sample_size_experiment(
    seed=seed,
    sample_sizes=sample_sizes,
    base_cfg=base_cfg,
)

## Deep Ensemble IT

In [None]:
# Deep Ensemble (size base_cfg["K"]) with the IT decomposition, swept over `sample_sizes`.
results_deep_ensemble_it = run_deep_ensemble_it_sample_size_experiment(
    seed=seed,
    sample_sizes=sample_sizes,
    base_cfg=base_cfg,
)

In [None]:
# --- Deep Ensemble GL ---
# Deep Ensemble (size base_cfg["K"]) with the GL decomposition, swept over `sample_sizes`.
results_deep_ensemble_gl = run_deep_ensemble_gl_sample_size_experiment(
    seed=seed,
    sample_sizes=sample_sizes,
    base_cfg=base_cfg,
)

In [None]:
# --- BNN IT ---
# BNN is computationally expensive, so this sweep uses the shorter
# `sample_sizes_bnn` list instead of `sample_sizes`. Skip this cell for a
# faster run (later cells that read `results_bnn_it` would then fail).
results_bnn_it = run_bnn_it_sample_size_experiment(
    base_cfg=base_cfg,
    sample_sizes=sample_sizes_bnn,
    seed=seed,
)

In [None]:
# --- BNN GL ---
# BNN is computationally expensive, so this sweep uses the shorter
# `sample_sizes_bnn` list instead of `sample_sizes`. Skip this cell for a
# faster run (later cells that read `results_bnn_gl` would then fail).
results_bnn_gl = run_bnn_gl_sample_size_experiment(
    base_cfg=base_cfg,
    sample_sizes=sample_sizes_bnn,
    seed=seed,
)