In [1]:
# =========================
# Classic Baselines Consistency Benchmark
# Industry-standard architectures on CIFAR-100
# Same dual-run methodology for fair comparison
# =========================

import gc
import os
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm

# -------------------------
# Config
# -------------------------
# Results CSV name carries the launch timestamp, so each launch writes its own file.
RESULTS_FILE = "classic_baselines_{}.csv".format(datetime.now().strftime("%Y%m%d_%H%M%S"))
EPOCHS = 20  # training epochs per run
BATCH_SIZE = 128  # mini-batch size for both the train and test loaders
NUM_RUNS = 2  # differently seeded runs per model
NUM_CLASSES = 100  # CIFAR-100 label count

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# Opt in to TF32 for matmul and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# -------------------------
# Data
# -------------------------
def get_cifar100_loaders():
    """Build the CIFAR-100 train/test ``DataLoader`` pair.

    The training split is augmented with a padded random crop and a random
    horizontal flip; the test split is only converted to a tensor. Both are
    normalized with the same per-channel statistics. The dataset is stored
    under ``./data`` and downloaded when missing.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)

    augmented = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    plain = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_split = datasets.CIFAR100("./data", train=True, download=True, transform=augmented)
    test_split = datasets.CIFAR100("./data", train=False, download=True, transform=plain)

    # Options shared by both loaders; only the shuffle flag differs.
    shared = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
    return (
        DataLoader(train_split, shuffle=True, **shared),
        DataLoader(test_split, shuffle=False, **shared),
    )


# -------------------------
# Models
# -------------------------
class SimpleMLP(nn.Module):
    """Fully connected baseline operating on flattened 32x32 RGB images.

    Two hidden layers (1024 and 512 units), each followed by ReLU and
    dropout (p=0.2), then a linear layer producing ``num_classes`` logits.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.flatten = nn.Flatten()

        widths = (32 * 32 * 3, 1024, 512)
        layers = []
        # Layers are created in network order so seeded initialization and the
        # Sequential indices (0, 3, 6 for the Linear layers) stay unchanged.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.2)]
        layers.append(nn.Linear(widths[-1], num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` per sample and return the class logits."""
        flat = self.flatten(x)
        return self.net(flat)


class SmallCNN(nn.Module):
    """Four-stage convolutional baseline.

    Every stage is Conv3x3 -> BatchNorm -> ReLU. The first three stages end
    with a 2x2 max-pool; the last ends with global average pooling, giving a
    512-dimensional feature that feeds a single linear classifier.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        channels = (3, 64, 128, 256, 512)
        final_stage = len(channels) - 2

        stages = []
        # Modules are created in network order so seeded initialization and
        # the Sequential indices stay unchanged.
        for idx, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
            stages.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU())
            if idx == final_stage:
                stages.append(nn.AdaptiveAvgPool2d(1))
            else:
                stages.append(nn.MaxPool2d(2))

        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(channels[-1], num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        pooled = self.features(x)
        # Collapse the trailing 1x1 spatial dimensions to (batch, 512).
        return self.classifier(pooled.view(len(pooled), -1))


def get_resnet18_cifar():
    """Create an untrained timm ResNet-18 with a 100-way head, adapted to 32x32 inputs.

    The stem convolution is replaced by a 3x3, stride-1 convolution and the
    stem max-pool by an identity.
    """
    net = timm.create_model('resnet18', pretrained=False, num_classes=100)
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    return net


def get_resnet34_cifar():
    """Create an untrained timm ResNet-34 with a 100-way head, adapted to 32x32 inputs.

    Uses the same stem changes as the ResNet-18 variant: a 3x3, stride-1
    first convolution and an identity in place of the stem max-pool.
    """
    backbone = timm.create_model('resnet34', pretrained=False, num_classes=100)
    stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    backbone.conv1 = stem_conv
    backbone.maxpool = nn.Identity()
    return backbone


def get_vit_tiny_cifar():
    """Create an untrained ViT-Tiny for CIFAR-100 (32x32 inputs, 4x4 patches).

    Returns:
        A timm vision transformer with a 100-way classification head.
    """
    # NOTE(review): timm's registry does not appear to contain a
    # 'vit_tiny_patch4_32' entry (confirm with timm.list_models('vit_tiny*')),
    # so the registered ViT-Tiny is requested and its input and patch sizes
    # are overridden, the same way get_deit_tiny_cifar configures DeiT-Tiny.
    model = timm.create_model(
        'vit_tiny_patch16_224',
        pretrained=False,
        num_classes=100,
        img_size=32,
        patch_size=4,  # 4x4 patches for 32x32 images
    )
    return model


def get_convnext_tiny_cifar():
    """Create an untrained timm ConvNeXt-Tiny with a 100-way head for 32x32 inputs.

    The first stem module is replaced by a 2x2, stride-2 convolution with
    96 output channels.
    """
    net = timm.create_model('convnext_tiny', pretrained=False, num_classes=100)
    small_image_stem = nn.Conv2d(3, 96, kernel_size=2, stride=2)
    net.stem[0] = small_image_stem
    return net


def get_efficientnet_b0_cifar():
    """Create an untrained timm EfficientNet-B0 with a 100-way head (no stem changes)."""
    return timm.create_model('efficientnet_b0', pretrained=False, num_classes=100)


def get_deit_tiny_cifar():
    """Create an untrained timm DeiT-Tiny with a 100-way head for 32x32 inputs.

    The registered ``deit_tiny_patch16_224`` architecture is requested with
    its image size overridden to 32 and its patch size to 4.
    """
    overrides = {
        'pretrained': False,
        'num_classes': 100,
        'img_size': 32,
        'patch_size': 4,
    }
    return timm.create_model('deit_tiny_patch16_224', **overrides)


# Benchmark registry: model name -> (zero-argument factory, hyper-parameters).
# train_single_run reads 'lr', 'wd' and 'scheduler' from the hyper-parameter
# dict. A 'scheduler' value of 'cosine' makes it train with SGD (momentum 0.9);
# entries without that key are trained with AdamW.
# NOTE(review): vit_tiny / convnext_tiny pair wd=0.05 with the SGD path —
# confirm that this is intended.
# NOTE(review): get_deit_tiny_cifar is defined in this file but has no entry
# here — confirm that this is intentional.
MODELS = dict(
    mlp=(SimpleMLP, dict(lr=1e-3, wd=1e-4)),
    small_cnn=(SmallCNN, dict(lr=1e-3, wd=1e-4)),
    resnet18=(get_resnet18_cifar, dict(lr=0.1, wd=5e-4, scheduler='cosine')),
    resnet34=(get_resnet34_cifar, dict(lr=0.1, wd=5e-4, scheduler='cosine')),
    vit_tiny=(get_vit_tiny_cifar, dict(lr=1e-3, wd=0.05, scheduler='cosine')),
    convnext_tiny=(get_convnext_tiny_cifar, dict(lr=1e-3, wd=0.05, scheduler='cosine')),
    efficientnet_b0=(get_efficientnet_b0_cifar, dict(lr=1e-3, wd=1e-4, scheduler='cosine')),
)


# -------------------------
# Training
# -------------------------
def _build_optimizer(params, config: Dict) -> torch.optim.Optimizer:
    """Create the optimizer described by ``config``.

    ``config['optimizer']`` may be ``'sgd'`` (momentum 0.9) or ``'adamw'``.
    When the key is absent the historical rule applies: a ``'scheduler'``
    value of ``'cosine'`` selects SGD, anything else selects AdamW.

    Raises:
        ValueError: if ``config['optimizer']`` names an unknown optimizer.
    """
    lr = config.get('lr', 1e-3)
    wd = config.get('wd', 1e-4)

    choice = config.get('optimizer')
    if choice is None:
        # Backward compatibility: the 'scheduler' key used to pick the optimizer.
        choice = 'sgd' if config.get('scheduler') == 'cosine' else 'adamw'

    if choice == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=wd)
    if choice == 'adamw':
        return torch.optim.AdamW(params, lr=lr, weight_decay=wd)
    raise ValueError(f"Unknown optimizer {choice!r}; expected 'sgd' or 'adamw'")


def _evaluate_accuracy(model, loader, run_device) -> float:
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(run_device), labels.to(run_device)
            preds = model(imgs).argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total


def train_single_run(
    model_fn,
    train_loader,
    test_loader,
    config: Dict,
    seed: int,
    *,
    epochs: Optional[int] = None,
    run_device: Optional[torch.device] = None,
) -> Tuple[float, float, int]:
    """Train one freshly built model and evaluate it after every epoch.

    Args:
        model_fn: zero-argument callable returning a new ``nn.Module``.
        train_loader: iterable of ``(images, labels)`` training batches.
        test_loader: iterable of ``(images, labels)`` evaluation batches.
        config: hyper-parameters. ``'lr'`` (default 1e-3) and ``'wd'``
            (default 1e-4) are always read; ``'optimizer'`` (``'sgd'`` or
            ``'adamw'``) is optional and, when absent, is derived from
            ``'scheduler'`` as before (``'cosine'`` -> SGD, otherwise AdamW).
        seed: seed applied to the torch CPU and CUDA generators before the
            model is built.
        epochs: number of epochs; defaults to the module-level ``EPOCHS``.
        run_device: device to train on; defaults to the module-level ``device``.

    Returns:
        ``(best_acc, final_acc, num_params)``: the best and the last per-epoch
        test accuracy in percent, and the number of trainable parameters.
        Both accuracies are 0.0 when ``epochs`` is 0.

    Raises:
        ValueError: if ``config['optimizer']`` names an unknown optimizer.
    """
    if epochs is None:
        epochs = EPOCHS
    if run_device is None:
        run_device = device

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    model = optimizer = scheduler = None
    try:
        model = model_fn().to(run_device)
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        optimizer = _build_optimizer(model.parameters(), config)
        # Every configuration anneals the learning rate over the whole run.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        criterion = nn.CrossEntropyLoss()

        best_acc = 0.0
        final_acc = 0.0
        for epoch in range(epochs):
            model.train()
            for imgs, labels in train_loader:
                imgs, labels = imgs.to(run_device), labels.to(run_device)

                optimizer.zero_grad()
                loss = criterion(model(imgs), labels)
                loss.backward()
                optimizer.step()

            scheduler.step()

            final_acc = _evaluate_accuracy(model, test_loader, run_device)
            best_acc = max(best_acc, final_acc)

            if (epoch + 1) % 5 == 0:
                print(f"      Epoch {epoch+1:2d}: acc={final_acc:.2f}%")

        return best_acc, final_acc, num_params
    finally:
        # Release the run's objects even when training fails. Collect garbage
        # first so that empty_cache() can hand back the memory just freed.
        model = optimizer = scheduler = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def benchmark_model(
    name: str,
    model_fn,
    config: Dict,
    train_loader,
    test_loader,
) -> Dict:
    """Train one model ``NUM_RUNS`` times with different seeds and summarize the runs.

    Run ``i`` (0-based) uses seed ``42 + 1000 * i``. A run that raises is
    reported on stdout and counted in ``failed_runs``; it is left out of the
    aggregate statistics so that a crash is not mistaken for 0% accuracy.

    Args:
        name: label used in the printed output and in the result row.
        model_fn: zero-argument model factory passed to ``train_single_run``.
        config: hyper-parameter dict passed to ``train_single_run``.
        train_loader: training data loader.
        test_loader: evaluation data loader.

    Returns:
        A result row with the parameter count, the mean / population std /
        min / max of the per-run best accuracies, ``consistency`` (min / max),
        the per-run best accuracies (0.0 for a failed run), the mean final
        accuracy and ``failed_runs``. All aggregates are 0 when no run succeeds.
    """
    print(f"\n{'='*50}")
    print(f"  {name.upper()}")
    print(f"{'='*50}")

    run_bests = []   # one entry per run, 0.0 placeholder for a failed run
    ok_bests = []    # successful runs only; these feed the statistics
    ok_finals = []
    num_params = 0

    for run_idx in range(NUM_RUNS):
        seed = 42 + run_idx * 1000
        print(f"  Run {run_idx+1}/{NUM_RUNS} (seed={seed})")

        try:
            best_acc, final_acc, num_params = train_single_run(
                model_fn, train_loader, test_loader, config, seed
            )
        except Exception as e:
            # Best effort: keep benchmarking the remaining runs and models.
            print(f"    ERROR: {e}")
            run_bests.append(0.0)
        else:
            run_bests.append(best_acc)
            ok_bests.append(best_acc)
            ok_finals.append(final_acc)
            print(f"    -> Best: {best_acc:.2f}%")

    failed_runs = len(run_bests) - len(ok_bests)

    # Stats over the successful runs (population standard deviation).
    if ok_bests:
        best_min, best_max = min(ok_bests), max(ok_bests)
        best_mean = sum(ok_bests) / len(ok_bests)
        best_std = (sum((x - best_mean) ** 2 for x in ok_bests) / len(ok_bests)) ** 0.5
        final_mean = sum(ok_finals) / len(ok_finals)
    else:
        best_min = best_max = best_mean = best_std = final_mean = 0.0
    consistency = best_min / best_max if best_max > 0 else 0

    print(f"  Mean: {best_mean:.2f}% ± {best_std:.2f}% | Consistency: {consistency:.3f}")
    if failed_runs:
        print(f"  WARNING: {failed_runs}/{len(run_bests)} run(s) failed and were excluded from the statistics")

    first_best = run_bests[0] if run_bests else 0.0
    return {
        'name': name,
        'params': num_params,
        'params_M': num_params / 1e6,
        'best_mean': best_mean,
        'best_std': best_std,
        'best_min': best_min,
        'best_max': best_max,
        'consistency': consistency,
        'run1_best': first_best,
        'run2_best': run_bests[1] if len(run_bests) > 1 else first_best,
        'final_mean': final_mean,
        'failed_runs': failed_runs,
    }


# -------------------------
# Main
# -------------------------
def run_baselines():
    """Benchmark every entry of ``MODELS`` on CIFAR-100 and report the results.

    The results CSV (``RESULTS_FILE``) is rewritten after each model finishes,
    so an interrupted session keeps the rows completed so far. After the last
    model a summary table sorted by mean best accuracy is printed, followed by
    fixed reference numbers for comparison.

    Returns:
        A ``pandas.DataFrame`` with one row per benchmarked model.
    """
    print("=" * 60)
    print("CLASSIC BASELINES CONSISTENCY BENCHMARK")
    print("=" * 60)
    print(f"Epochs: {EPOCHS}, Batch: {BATCH_SIZE}, Runs: {NUM_RUNS}")

    train_loader, test_loader = get_cifar100_loaders()
    print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")

    results = []

    for name, (model_fn, config) in MODELS.items():
        result = benchmark_model(name, model_fn, config, train_loader, test_loader)
        results.append(result)
        # Checkpoint: persist what has been measured so far.
        pd.DataFrame(results).to_csv(RESULTS_FILE, index=False)

    # Save
    df = pd.DataFrame(results)
    df.to_csv(RESULTS_FILE, index=False)

    print(f"\n{'='*60}")
    print("RESULTS SUMMARY")
    print("=" * 60)

    df_sorted = df.sort_values('best_mean', ascending=False)
    print(df_sorted[['name', 'params_M', 'best_mean', 'best_std', 'consistency']].to_string(index=False))

    # Compare to walker fusion (reference numbers are hard-coded, not computed here)
    print(f"\n{'='*60}")
    print("COMPARISON CONTEXT")
    print("=" * 60)
    print("Walker fusion results (from combo ablation):")
    print("  baseline_walker:      88.07% ± 0.59% | consistency: 0.987")
    print("  combo_shiva_cosine:   88.62% ± 0.05% | consistency: 0.999")
    print("  combo_shiva_learned:  88.64% ± 0.05% | consistency: 0.999")
    print("\nNote: Walker uses frozen pretrained encoders (no training)")
    print("      Baselines train from scratch")

    return df


# Script entry point: run the full benchmark and keep the resulting DataFrame
# in a module-level name so it can be inspected after the run.
if __name__ == "__main__":
    results_df = run_baselines()

Device: cuda
CLASSIC BASELINES CONSISTENCY BENCHMARK
Epochs: 20, Batch: 128, Runs: 2


  self.setter(val)
100%|██████████| 169M/169M [00:05<00:00, 29.1MB/s]


Train: 50000, Test: 10000

  MLP
  Run 1/2 (seed=42)
      Epoch  5: acc=13.08%
      Epoch 10: acc=16.91%
      Epoch 15: acc=19.62%
      Epoch 20: acc=20.82%
    -> Best: 20.83%
  Run 2/2 (seed=1042)
      Epoch  5: acc=13.73%
      Epoch 10: acc=14.98%
      Epoch 15: acc=18.24%
      Epoch 20: acc=20.71%
    -> Best: 20.71%
  Mean: 20.77% ± 0.06% | Consistency: 0.994

  SMALL_CNN
  Run 1/2 (seed=42)
      Epoch  5: acc=47.19%
      Epoch 10: acc=55.97%
      Epoch 15: acc=60.78%
      Epoch 20: acc=62.40%
    -> Best: 62.40%
  Run 2/2 (seed=1042)
      Epoch  5: acc=46.42%
      Epoch 10: acc=56.37%
      Epoch 15: acc=60.40%
      Epoch 20: acc=61.52%
    -> Best: 61.62%
  Mean: 62.01% ± 0.39% | Consistency: 0.987

  RESNET18
  Run 1/2 (seed=42)
      Epoch  5: acc=43.79%
      Epoch 10: acc=56.45%
      Epoch 15: acc=66.29%
      Epoch 20: acc=72.99%
    -> Best: 73.05%
  Run 2/2 (seed=1042)
      Epoch  5: acc=41.67%
      Epoch 10: acc=54.71%
      Epoch 15: acc=66.87%
      E