# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import GTSRB, SEMEION
from torch.utils.data import DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
import numpy as np
import math
import time
import os

# ====== Config ======
# Mini-batch size shared by every DataLoader built in this file.
BATCH_SIZE = 1024

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Running on: {DEVICE}")
print(f"Batch Size: {BATCH_SIZE}")

# ============================================================
# ====== OPTIMIZERS: CUSTOM IMPLEMENTATIONS ==================
# ============================================================

class Lion(optim.Optimizer):
    """
    Lion optimizer: sign-based update with a single momentum buffer.

    Every step moves each parameter by exactly ``lr`` along
    ``sign(beta1 * exp_avg + (1 - beta1) * grad)`` and afterwards refreshes
    the momentum buffer using ``beta2``.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Step size applied to the sign update.
        betas: ``(beta1, beta2)``; ``beta1`` mixes momentum and gradient for
            the update direction, ``beta2`` is the decay of the stored
            momentum.
        weight_decay: Decoupled weight-decay coefficient; values ``<= 0``
            disable the decay.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd here.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr, wd = group['lr'], group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']

                # Decoupled weight decay acts on the weights as they were
                # before this step: p <- p * (1 - lr * wd) - lr * sign(update).
                if wd > 0:
                    p.mul_(1 - lr * wd)

                # Direction: interpolate momentum and gradient, keep the sign.
                update = exp_avg * beta1 + grad * (1 - beta1)
                p.add_(torch.sign(update), alpha=-lr)

                # The momentum buffer is refreshed only after it was used.
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss

class MuAdam(optim.Optimizer):
    """
    Adam-style optimizer without bias correction.

    The numerator of the update re-mixes the freshly updated first moment
    with the current gradient, ``beta1 * exp_avg + (1 - beta1) * grad``, and
    divides it by ``sqrt(exp_avg_sq) + eps``.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Step size.
        betas: ``(beta1, beta2)`` decay rates of the first and second moment.
        eps: Constant added to the denominator.
        weight_decay: Decoupled weight-decay coefficient (0 disables it).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd here.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # Decoupled weight decay is applied before the gradient step.
                if group['weight_decay'] != 0:
                    p.mul_(1 - lr * group['weight_decay'])

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                # Mix the updated momentum with the raw gradient once more.
                numerator = exp_avg * beta1 + grad * (1 - beta1)
                p.add_(numerator / denom, alpha=-lr)

        return loss

class ThermoLion(optim.Optimizer):
    """
    Renamed from OmniThermoLion.
    Thermodynamic-enhanced Lion optimization.

    Per element, a gate ``tanh(|m| / (sqrt(v) + 1e-8))`` blends two updates:
    a sign step on the momentum ``m`` (weight ``1 - gate``) and a step on
    ``m / sqrt(v)`` (weight ``gate``). While the per-parameter temperature is
    above 0.01, Gaussian noise scaled by the temperature and the mean of
    ``v`` is added to the sign-dominated part.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Step size.
        betas: ``(beta1, beta2)`` decay rates for ``m`` and ``v``.
        temp_decay: Factor multiplied into the temperature on every step
            (the temperature starts at 1.0).
        weight_decay: Decoupled weight-decay coefficient (0 disables it).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), temp_decay=0.99, weight_decay=0.01):
        defaults = dict(lr=lr, betas=betas, temp_decay=temp_decay, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd here.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            wd = group['weight_decay']
            tdec = group['temp_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p)
                    state['v'] = torch.zeros_like(p)
                    state['temp'] = 1.0
                m, v = state['m'], state['v']
                state['step'] += 1

                # Cool the temperature first; the decayed value drives the noise.
                state['temp'] *= tdec
                temp = state['temp']

                m.mul_(beta1).add_(g, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(g, g, value=1 - beta2)

                # Signal-to-noise ratio squashed into [0, 1) acts as the gate.
                snr = torch.abs(m) / (torch.sqrt(v) + 1e-8)
                gate = torch.tanh(snr)

                # Sign step, boosted by 50% where momentum and gradient agree.
                agreement = torch.clamp(torch.sign(m) * torch.sign(g), 0, 1)
                sign_term = (1 - gate) * torch.sign(m) * (1 + 0.5 * agreement) * lr
                # Normalized-momentum step, weighted by the gate.
                adaptive_term = gate * (m / (torch.sqrt(v) + 1e-8)) * lr * 2.0
                update = sign_term + adaptive_term

                if temp > 0.01:
                    # Exploration noise, only while the temperature is high.
                    noise_scale = math.sqrt(temp * v.mean().item() + 1e-10)
                    update += torch.randn_like(p) * noise_scale * lr * (1 - gate)

                if wd:
                    p.mul_(1 - lr * wd)
                p.add_(update, alpha=-1)

        return loss

# --- Lookahead Optimizer Wrapper (Fixed) ---
class Lookahead(optim.Optimizer):
    """
    Lookahead wrapper around another optimizer.

    The wrapped ("fast") optimizer runs ``k`` ordinary steps. The slow weights
    are then moved a fraction ``alpha`` of the way towards the fast weights,
    and the fast weights are reset to the new slow weights.

    ``Optimizer.__init__`` is deliberately not called: ``param_groups`` is
    shared with the wrapped optimizer so both always see the same groups, and
    ``state`` maps each parameter to its slow-weight tensor.

    Args:
        optimizer: The inner optimizer performing the fast steps.
        k: Number of fast steps between two synchronizations (>= 1).
        alpha: Interpolation factor of the slow-weight update, in [0, 1].

    Raises:
        ValueError: If ``k`` or ``alpha`` is out of range.
    """

    def __init__(self, optimizer, k=5, alpha=0.5):
        if k < 1:
            raise ValueError(f"Invalid lookahead steps: {k}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid slow update rate: {alpha}")
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.param_groups = self.optimizer.param_groups
        self.state = {}
        # Mirror the inner optimizer's defaults so code that inspects
        # ``optimizer.defaults`` keeps working on the wrapper.
        self.defaults = getattr(optimizer, 'defaults', {})
        for group in self.param_groups:
            group['counter'] = 0

    def step(self, closure=None):
        """
        Run one fast step and, every ``k``-th call, synchronize slow weights.

        Args:
            closure: Optional loss closure forwarded to the inner optimizer.

        Returns:
            Whatever the inner optimizer's ``step`` returns.
        """
        # Slow weights must start at the values the parameters had BEFORE the
        # first fast step, so snapshot any parameter not seen yet up front.
        for group in self.param_groups:
            for p in group['params']:
                if p not in self.state:
                    self.state[p] = p.detach().clone()

        loss = self.optimizer.step(closure)

        with torch.no_grad():
            for group in self.param_groups:
                group['counter'] = group.get('counter', 0) + 1
                if group['counter'] < self.k:
                    continue
                group['counter'] = 0
                for p in group['params']:
                    slow = self.state[p]
                    # slow <- slow + alpha * (fast - slow); fast <- slow
                    slow.add_(p.detach() - slow, alpha=self.alpha)
                    p.copy_(slow)
        return loss

    def zero_grad(self, *args, **kwargs):
        """Clear gradients via the inner optimizer (arguments are forwarded)."""
        self.optimizer.zero_grad(*args, **kwargs)

# --- SWATS (Simplified) ---
class SWATS(optim.Optimizer):
    """
    Simplified SWATS: starts as Adam, switches to plain SGD after a fixed
    number of steps.

    The switch is a simple step-count heuristic rather than an adaptive
    criterion. Once any parameter of a group has taken more than
    ``switch_step`` steps, the whole group stays in the SGD phase. The SGD
    phase reuses the group's ``lr`` unchanged.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Step size used in both phases.
        betas: ``(beta1, beta2)`` decay rates of the Adam moments.
        eps: Constant added to the Adam denominator.
        weight_decay: Decoupled weight-decay coefficient, applied after the
            update when positive.
        switch_step: Number of Adam steps before switching to SGD.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
                 switch_step=300):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        phase='ADAM', switch_step=switch_step)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd here.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            # Groups restored from an older state dict may lack the key.
            switch_step = group.get('switch_step', 300)
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1

                # One-way switch: the phase is never set back to 'ADAM'.
                if state['step'] > switch_step:
                    group['phase'] = 'SGD'

                if group['phase'] == 'ADAM':
                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    # Bias correction of both moments folded into the step size.
                    step_size = lr * math.sqrt(1 - beta2 ** state['step']) / (1 - beta1 ** state['step'])
                    p.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p.add_(grad, alpha=-lr)

                if group['weight_decay'] > 0:
                    p.mul_(1 - lr * group['weight_decay'])

        return loss

# ==============================================
# ====== DATASET & LOADING =====================
# ==============================================

def get_image_dataset(name, limit=5000):
    """
    Returns (loader, model)

    Builds the training split of the requested dataset (downloading it into
    ``./data`` if needed), resizes every image to 32x32, and pairs it with a
    freshly initialised small CNN whose input channels and output classes
    match the dataset.

    Args:
        name: Dataset key, e.g. ``"MNIST"``, ``"CIFAR10"``, ``"GTSRB"``.
        limit: Maximum number of samples to keep; ``None`` keeps everything.

    Returns:
        ``(loader, model)``: a shuffling ``DataLoader`` with batch size
        ``BATCH_SIZE`` and the CNN already moved to ``DEVICE``.

    Raises:
        ValueError: If ``name`` is not a known dataset key.
    """
    root = './data'

    # --- Transforms ---
    transform_std = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    transform_gray = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    tv = torchvision.datasets

    # name -> (zero-argument factory, input channels, number of classes).
    # Factories are lazy so only the requested dataset is downloaded.
    registry = {
        "GTSRB": (lambda: GTSRB(root=root, split='train', download=True, transform=transform_std), 3, 43),
        "SEMEION": (lambda: SEMEION(root=root, download=True, transform=transform_gray), 1, 10),
        "MNIST": (lambda: tv.MNIST(root=root, train=True, download=True, transform=transform_gray), 1, 10),
        "FashionMNIST": (lambda: tv.FashionMNIST(root=root, train=True, download=True, transform=transform_gray), 1, 10),
        "KMNIST": (lambda: tv.KMNIST(root=root, train=True, download=True, transform=transform_gray), 1, 10),
        "USPS": (lambda: tv.USPS(root=root, train=True, download=True, transform=transform_gray), 1, 10),
        "QMNIST": (lambda: tv.QMNIST(root=root, what='train', download=True, transform=transform_gray), 1, 10),
        "EMNIST": (lambda: tv.EMNIST(root=root, split='mnist', train=True, download=True, transform=transform_gray), 1, 10),
        "CIFAR10": (lambda: tv.CIFAR10(root=root, train=True, download=True, transform=transform_std), 3, 10),
        "CIFAR100": (lambda: tv.CIFAR100(root=root, train=True, download=True, transform=transform_std), 3, 100),
        "SVHN": (lambda: tv.SVHN(root=root, split='train', download=True, transform=transform_std), 3, 10),
        "STL10": (lambda: tv.STL10(root=root, split='train', download=True, transform=transform_std), 3, 10),
    }

    if name not in registry:
        raise ValueError(f"Unknown dataset: {name}")
    build, channels, classes = registry[name]
    ds = build()

    # Limit samples for speed. The subset is drawn at random with a fixed
    # seed: simply taking the first `limit` samples would cover only a few
    # classes for any dataset that is stored in class order. A private
    # generator keeps the selection reproducible without touching the
    # global RNG.
    if limit is not None and len(ds) > limit:
        picker = torch.Generator().manual_seed(0)
        indices = torch.randperm(len(ds), generator=picker)[:limit].tolist()
        subset = torch.utils.data.Subset(ds, indices)
    else:
        subset = ds

    loader = DataLoader(subset, batch_size=BATCH_SIZE, shuffle=True)

    # --- Models ---
    # Two 2x2 poolings shrink the 32x32 input to 8x8, hence 64*8*8 features.
    model = nn.Sequential(
        nn.Conv2d(channels, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Flatten(),
        nn.Linear(64*8*8, 256), nn.ReLU(),
        nn.Linear(256, classes)
    ).to(DEVICE)

    return loader, model

def _build_optimizer(model, opt_name, lr):
    """Create the optimizer for ``opt_name``; also returns SWA helpers (``None`` unless SWA)."""
    swa_model = None
    swa_scheduler = None

    if opt_name == "Adam":
        opt = optim.Adam(model.parameters(), lr=lr)
    elif opt_name == "AdamW":
        opt = optim.AdamW(model.parameters(), lr=lr)
    elif opt_name == "RMSprop":
        opt = optim.RMSprop(model.parameters(), lr=lr)
    elif opt_name == "Lion":
        # Lion is run at a third of the base learning rate.
        opt = Lion(model.parameters(), lr=lr / 3)
    elif opt_name == "MuAdam":
        opt = MuAdam(model.parameters(), lr=lr)
    elif opt_name == "ThermoLion":
        opt = ThermoLion(model.parameters(), lr=lr)
    elif opt_name == "Lookahead":
        opt = Lookahead(optim.Adam(model.parameters(), lr=lr))
    elif opt_name == "SWATS":
        opt = SWATS(model.parameters(), lr=lr)
    elif opt_name == "SWA":
        # SWA runs on SGD with a 10x higher learning rate plus the SWA utilities.
        opt = optim.SGD(model.parameters(), lr=lr * 10)
        swa_model = AveragedModel(model)
        swa_scheduler = SWALR(opt, swa_lr=0.01)
    else:
        # Any other name falls back to plain SGD.
        opt = optim.SGD(model.parameters(), lr=lr)

    return opt, swa_model, swa_scheduler


def _accuracy_percent(model, loader):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``; 0.0 if it is empty."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            predicted = model(x).argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return 100 * correct / total if total else 0.0


def train_and_evaluate(dataset_name, opt_name, epochs=12, lr=1e-3):
    """
    Train a fresh model on ``dataset_name`` with the optimizer ``opt_name``.

    Accuracy is measured on the same loader that was used for training, so
    it is a training-set accuracy.

    Args:
        dataset_name: Key understood by ``get_image_dataset``.
        opt_name: Optimizer key; unknown names fall back to plain SGD.
        epochs: Number of passes over the loader.
        lr: Base learning rate (some optimizers rescale it).

    Returns:
        ``(losses, duration, acc, final_loss)``: per-epoch mean loss, elapsed
        training time in seconds, accuracy in percent, and the last epoch's
        mean loss. If the dataset cannot be prepared, the sentinel
        ``([], 0, 0.0, 999.9)`` is returned instead.
    """
    try:
        loader, model = get_image_dataset(dataset_name)
    except Exception as e:
        # Best effort: a dataset that cannot be built is reported and skipped.
        print(f"Skipping {dataset_name}: {e}")
        return [], 0, 0.0, 999.9

    criterion = nn.CrossEntropyLoss()
    opt, swa_model, swa_scheduler = _build_optimizer(model, opt_name, lr)

    # Weight averaging starts in the second half of training
    # (epoch index 6 onwards for the default 12 epochs).
    swa_start = epochs // 2
    num_batches = len(loader)
    losses = []

    model.train()
    start_time = time.perf_counter()
    for ep in range(epochs):
        running_loss = 0.0
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            opt.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            opt.step()
            running_loss += loss.item()

        # SWA update step at the end of the epoch.
        if swa_model is not None and ep >= swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()

        # An empty loader has no mean loss; record NaN instead of dividing by zero.
        losses.append(running_loss / num_batches if num_batches else float('nan'))

    # With SWA, refresh normalization statistics and evaluate the averaged model.
    if swa_model is not None:
        torch.optim.swa_utils.update_bn(loader, swa_model, device=DEVICE)
        eval_model = swa_model
    else:
        eval_model = model

    duration = time.perf_counter() - start_time
    final_loss = losses[-1] if losses else 0.0

    acc = _accuracy_percent(eval_model, loader)
    return losses, duration, acc, final_loss

# ==============================================
# ====== BENCHMARK EXECUTION ===================
# ==============================================

# New datasets first
# Benchmark matrix: every dataset below is trained once per optimizer.
datasets = [
    "GTSRB",
    "SEMEION",
    "MNIST",
    "FashionMNIST",
    "KMNIST",
    "EMNIST",
    "QMNIST",
    "USPS",
    "CIFAR10",
    "CIFAR100",
    "SVHN",
    "STL10",
]

optimizers = [
    "Adam",
    "AdamW",
    "RMSprop",
    "Lion",
    "MuAdam",
    "ThermoLion",
    "Lookahead",
    "SWATS",
    "SWA",
]

# results[dataset][optimizer] starts with placeholders so the reporting code
# can still read an entry for runs that failed.
results = {}
for ds in datasets:
    results[ds] = {}
    for opt in optimizers:
        results[ds][opt] = {'loss': [], 'metric': 0.0, 'final_loss': 0.0}

print("Starting Ultimate Image Benchmark (12 Datasets)...")
print(f"Optimizers: {', '.join(optimizers)}")
print("-" * 110)

for ds in datasets:
    print(f"Processing {ds}...")
    for opt in optimizers:
        # One failing combination must not abort the remaining runs.
        try:
            outcome = train_and_evaluate(ds, opt)
            loss_curve, duration, metric, final_loss = outcome
            entry = results[ds][opt]
            entry['loss'] = loss_curve
            entry['metric'] = metric
            entry['final_loss'] = final_loss
            print(f"  > {opt:<12} | Loss: {final_loss:.4f} | Acc: {metric:.2f}% | Time: {duration:.2f}s")
        except Exception as e:
            print(f"  > {opt} Failed on {ds}: {e}")

# ==============================================
# ====== PLOTTING & REPORTING ==================
# ==============================================

# --- Accuracy table: one row per dataset, one column per optimizer ---
print("\n" + "="*140)
print(f"{'FINAL ACCURACY SUMMARY':^140}")
print("="*140)
# Optimizer names are cut to 8 characters so every column stays 10 wide.
header = f"{'Dataset':<15} | " + "".join(f"{opt[:8]:<10} | " for opt in optimizers)
print(header)
print("-" * len(header))

for ds in datasets:
    cells = "".join(f"{results[ds][opt]['metric']:5.1f}      | " for opt in optimizers)
    print(f"{ds:<15} | " + cells)
print("="*140 + "\n")

# --- Loss curves: one subplot per dataset, three per row ---
plt.rcParams.update({'font.size': 14})  # global font-size increase

cols = 3
rows = max(1, (len(datasets) + cols - 1) // cols)
# squeeze=False always yields a 2-D array of axes, so flatten() is safe
# even when the grid has a single row or a single cell.
fig, axes = plt.subplots(rows, cols, figsize=(24, 6.0 * rows), squeeze=False)
axes = axes.flatten()
palette = plt.get_cmap('tab10')
# Wrap around the palette so more optimizers than colors still get a color.
colors = {opt: palette(i % palette.N) for i, opt in enumerate(optimizers)}

legend_drawn = False
for i, ds in enumerate(datasets):
    ax = axes[i]
    ax.set_title(f"{ds}", fontsize=16, fontweight='bold')

    has_curves = False
    for opt in optimizers:
        curve = results[ds][opt]['loss']
        # A curve needs at least two points to be drawn as a line.
        if len(curve) > 1:
            ax.plot(curve, label=opt, color=colors.get(opt, 'black'), linewidth=2.0)
            has_curves = True

    # Larger tick labels
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.grid(True, alpha=0.3)

    # A single legend avoids clutter; it goes on the first subplot that
    # actually has curves, so it is not lost when the first dataset failed.
    if has_curves and not legend_drawn:
        ax.legend(fontsize=14, ncol=2)
        legend_drawn = True

    try:
        ax.set_yscale('log')
    except ValueError as err:
        # Keep the linear axis but report why the log scale was rejected.
        print(f"Could not use a log y-axis for {ds}: {err}")

# Hide grid cells that have no dataset assigned.
for ax in axes[len(datasets):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()