# Comparación de entrenamiento estándar vs SICReg en datasets de torchvision

Este notebook entrena una **ResNet18** en varios datasets de `torchvision` (MNIST, CIFAR10, CIFAR100) con:

1. **Entrenamiento estándar** (cross-entropy).
2. **Entrenamiento estándar + pérdida SICReg**.

Al final se genera una tabla con los resultados de generalización y gráficos con las curvas de aprendizaje para comparar ambos métodos.


## 1. Imports y configuración


In [None]:
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms

# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays
# off when input shapes stay fixed across iterations.
torch.backends.cudnn.benchmark = True

# Use the default CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)


## 2. Definición de la pérdida SICReg

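La implementación sigue el ejemplo mínimo de **SIGReg** (*Sketched Isotropic Gaussian Regularization*, propuesto en LeJEPA), pero usa el nombre **SICReg** para respetar la convención del enunciado. En cada llamada proyecta los embeddings sobre direcciones aleatorias de norma unitaria y penaliza la diferencia entre la función característica empírica de cada proyección y la de una normal estándar, $\varphi(t) = e^{-t^2/2}$; de este modo empuja la distribución de los embeddings hacia una gaussiana isotrópica.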


In [None]:
class SICReg(nn.Module):
    """Sliced characteristic-function regularizer pulling embeddings towards N(0, I).

    Each forward call draws ``num_slices`` random unit-norm directions, projects
    the embeddings onto them, and compares the empirical characteristic
    function of every 1-D projection with ``exp(-t^2 / 2)`` (the characteristic
    function of a standard normal) at ``knots`` points of ``[0, t_max]``. The
    squared error is integrated with trapezoidal weights times a Gaussian
    window.

    Args:
        knots: Number of quadrature points on ``[0, t_max]``; must be >= 2.
        num_slices: Number of random projection directions drawn per call.
        t_max: Upper end of the integration range for ``t``.

    Raises:
        ValueError: If ``knots < 2`` or ``num_slices < 1``.
    """

    def __init__(self, knots: int = 17, num_slices: int = 256, t_max: float = 3.0):
        super().__init__()
        if knots < 2:
            raise ValueError(f'knots must be >= 2, got {knots}')
        if num_slices < 1:
            raise ValueError(f'num_slices must be >= 1, got {num_slices}')
        self.num_slices = num_slices
        t = torch.linspace(0, t_max, knots, dtype=torch.float32)
        dt = t_max / (knots - 1)
        # Trapezoidal weights on [0, t_max], doubled: the integrand is even in t
        # (|ECF(-t) - phi(-t)| == |ECF(t) - phi(t)|), so this also covers [-t_max, 0].
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer('t', t)
        self.register_buffer('phi', window)
        self.register_buffer('weights', weights * window)

    def forward(self, proj: torch.Tensor) -> torch.Tensor:
        """Return the scalar SICReg statistic of ``proj``.

        Args:
            proj: Embeddings with samples on dim -2 and features on dim -1,
                e.g. ``[V, B, D]`` (views, batch, dim) or ``[B, D]``. float64
                inputs are processed in float64; any other dtype (including
                half precision) is processed in float32.

        Returns:
            0-dim tensor: the per-slice statistic (scaled by the sample count
            ``proj.size(-2)``) averaged over slices and any leading dims.
        """
        # Work in float64 when given float64, otherwise in float32, so the
        # matmul below never mixes dtypes and cos/sin are not run in fp16.
        dtype = torch.float64 if proj.dtype == torch.float64 else torch.float32
        proj = proj.to(dtype)
        t = self.t.to(dtype)
        phi = self.phi.to(dtype)
        weights = self.weights.to(dtype)
        # Random unit-norm projection directions, one per column.
        A = torch.randn(proj.size(-1), self.num_slices, device=proj.device, dtype=dtype)
        A = A / A.norm(p=2, dim=0)
        # [..., B, S, K]: every 1-D projection multiplied by every knot t.
        x_t = (proj @ A).unsqueeze(-1) * t
        # Squared distance (real + imaginary parts) between the empirical
        # characteristic function, averaged over samples, and the target -> [..., S, K].
        err = (x_t.cos().mean(-3) - phi).square() + x_t.sin().mean(-3).square()
        statistic = (err @ weights) * proj.size(-2)
        return statistic.mean()


## 3. Modelo ResNet18 con cabeza de proyección


In [None]:
class ResNet18WithHead(nn.Module):
    """torchvision ResNet-18 backbone with a linear classifier and an MLP projector.

    Args:
        num_classes: Number of logits produced by the classifier.
        proj_dim: Output width of the projection head.
        small_input_stem: If True, replace the ImageNet stem (7x7 stride-2 conv
            followed by max-pooling) with a 3x3 stride-1 conv and no pooling,
            the usual adaptation for 28x28 / 32x32 images. Defaults to False,
            which keeps the stock torchvision architecture.
    """

    def __init__(self, num_classes: int, proj_dim: int = 128, small_input_stem: bool = False):
        super().__init__()
        self.backbone = models.resnet18(weights=None)
        if small_input_stem:
            # The stock stem downsamples 4x before the first residual stage, so a
            # 32x32 input reaches layer4 as a 1x1 feature map.
            self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.backbone.maxpool = nn.Identity()
        # Read the pooled feature width from the backbone instead of hard-coding it.
        feat_dim = self.backbone.fc.in_features
        # Drop torchvision's classifier so the backbone returns pooled features.
        self.backbone.fc = nn.Identity()
        self.classifier = nn.Linear(feat_dim, num_classes)
        self.projector = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, proj_dim),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(logits, proj)`` computed from the shared backbone features.

        ``logits`` has ``num_classes`` columns and ``proj`` has ``proj_dim`` columns.
        """
        features = self.backbone(x)
        logits = self.classifier(features)
        proj = self.projector(features)
        return logits, proj


## 4. Datasets y DataLoaders

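Se construyen transformaciones específicas para cada dataset y para cada partición: aumentación aleatoria en entrenamiento y redimensionado con recorte central determinista en test. Como **MNIST** está en escala de grises, su único canal se replica tres veces para adaptarlo a la entrada RGB de ResNet. Para el entrenamiento con SICReg, `MultiViewDataset` genera dos vistas aumentadas independientes de cada imagen.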


In [None]:
@dataclass
class DatasetConfig:
    """Describe one torchvision classification dataset used in the benchmark."""

    name: str  # human-readable name; also keys the results dict and summary table
    dataset_cls: type  # torchvision dataset class, instantiated with root/train/download/transform
    num_classes: int  # number of logits the classifier head must output
    image_size: int  # side length handed to the crop/resize transforms

# Datasets benchmarked by the experiment loop; append a DatasetConfig to add one.
# image_size is each dataset's native side length (MNIST 28x28, CIFAR 32x32).
DATASETS = [
    DatasetConfig(name='MNIST', dataset_cls=datasets.MNIST, num_classes=10, image_size=28),
    DatasetConfig(name='CIFAR10', dataset_cls=datasets.CIFAR10, num_classes=10, image_size=32),
    DatasetConfig(name='CIFAR100', dataset_cls=datasets.CIFAR100, num_classes=100, image_size=32),
]

def build_transforms(image_size: int, is_train: bool, grayscale: bool = False, hflip: bool = True):
    """Build the preprocessing pipeline for one dataset split.

    Args:
        image_size: Side length of the square output image.
        is_train: If True, apply a random resized crop (scale 0.8-1.0) and,
            when ``hflip`` is set, a random horizontal flip. Otherwise apply a
            deterministic resize followed by a center crop.
        grayscale: If True, the input is a single-channel PIL image; its
            channel is replicated to 3 so the result fits ResNet's RGB stem.
        hflip: Whether training augmentation includes random horizontal flips.
            Defaults to True (the previous behaviour). Mirrored digits are not
            label-preserving, so ``False`` is worth considering for MNIST.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a float tensor of shape
        ``[C, image_size, image_size]`` with values in [0, 1] (no normalization).
    """
    if is_train:
        steps = [transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0))]
        if hflip:
            steps.append(transforms.RandomHorizontalFlip())
    else:
        steps = [transforms.Resize(image_size), transforms.CenterCrop(image_size)]

    if grayscale:
        # Replicate the channel on the PIL image (r == g == b). Unlike a
        # transforms.Lambda wrapping a lambda, this transform can be pickled, so
        # DataLoader workers also start under the 'spawn' start method
        # (the default on Windows and macOS).
        steps.append(transforms.Grayscale(num_output_channels=3))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)

class MultiViewDataset(Dataset):
    """Wrap a dataset so that every sample comes back as several augmented views.

    ``transform`` is applied ``views`` times, one call after another, to the
    raw sample of ``base_dataset``; the resulting tensors are stacked along a
    new leading dimension.

    Args:
        base_dataset: Indexable dataset yielding ``(x, y)`` pairs with
            untransformed ``x``.
        views: Number of augmented copies returned per sample.
        transform: Callable mapping a raw sample to a tensor; presumably random,
            so that the views differ.
    """

    def __init__(self, base_dataset: Dataset, views: int, transform):
        self.base_dataset, self.views, self.transform = base_dataset, views, transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``(stacked_views, y)`` where ``stacked_views`` has ``views`` rows on dim 0."""
        sample, label = self.base_dataset[idx]
        augmented = []
        for _ in range(self.views):
            augmented.append(self.transform(sample))
        return torch.stack(augmented), label


## 5. Funciones de entrenamiento y evaluación


In [None]:
def _resolve_device(model: nn.Module, device: Optional[torch.device] = None) -> torch.device:
    """Return ``device`` if given, else the device of ``model``'s first parameter (CPU if none)."""
    if device is not None:
        return device
    first_param = next(model.parameters(), None)
    return torch.device('cpu') if first_param is None else first_param.device


def evaluate(model: nn.Module, loader: DataLoader, device: Optional[torch.device] = None) -> float:
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode (and left there); it must return a
    ``(logits, proj)`` tuple, of which only the logits are used.

    Args:
        model: Classifier returning ``(logits, proj)``.
        loader: Yields ``(x, y)`` batches with integer class labels ``y``.
        device: Where each batch is moved; defaults to the model's device.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    device = _resolve_device(model, device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            logits, _ = model(x)
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)
    if total == 0:
        raise ValueError('evaluate() received a loader that yielded no samples')
    return correct / total


def train_one_epoch_baseline(model, loader, optimizer, device: Optional[torch.device] = None) -> float:
    """Run one cross-entropy training epoch and return the mean per-sample loss.

    Args:
        model: Classifier returning ``(logits, proj)``; ``proj`` is ignored.
        loader: Yields ``(x, y)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Where each batch is moved; defaults to the model's device.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    device = _resolve_device(model, device)
    model.train()
    running_loss = 0.0
    seen = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        logits, _ = model(x)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * y.size(0)
        seen += y.size(0)
    if seen == 0:
        raise ValueError('train_one_epoch_baseline() received a loader that yielded no samples')
    # Average over samples actually processed (differs from len(dataset) with drop_last / samplers).
    return running_loss / seen


def train_one_epoch_sicreg(model, loader, optimizer, sicreg: "SICReg", lamb: float,
                           device: Optional[torch.device] = None) -> float:
    """Run one epoch of cross-entropy on every view plus ``lamb`` times the SICReg loss.

    Args:
        model: Classifier returning ``(logits, proj)``.
        loader: Yields ``(views, y)`` where ``views`` carries the view index on
            dim 1 (e.g. ``[B, V, C, H, W]``); every view shares the label ``y``.
        optimizer: Optimizer over ``model``'s parameters.
        sicreg: Callable mapping the stacked projections ``[V, B, proj_dim]`` to
            a scalar loss (normally a ``SICReg`` module).
        lamb: Weight of the SICReg term.
        device: Where each batch is moved; defaults to the model's device.

    Returns:
        Mean per-sample value of the combined loss.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    device = _resolve_device(model, device)
    model.train()
    running_loss = 0.0
    seen = 0
    for views, y in loader:
        views = views.to(device)
        y = y.to(device)
        # One forward pass per view (not a single concatenated batch) so that
        # BatchNorm batch statistics are computed separately for each view.
        outputs = [model(views[:, v]) for v in range(views.size(1))]
        cls_loss = torch.stack([F.cross_entropy(logits, y) for logits, _ in outputs]).mean()
        proj = torch.stack([p for _, p in outputs], dim=0)  # [V, B, proj_dim]
        sicreg_loss = sicreg(proj)
        loss = cls_loss + lamb * sicreg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * y.size(0)
        seen += y.size(0)
    if seen == 0:
        raise ValueError('train_one_epoch_sicreg() received a loader that yielded no samples')
    return running_loss / seen


## 6. Ejecutar experimentos

Ajustá `EPOCHS`, `BATCH_SIZE`, `LR` o `LAMB` (peso de la pérdida SICReg en la pérdida total) según los recursos disponibles.


In [None]:
EPOCHS = 5
BATCH_SIZE = 128
LR = 1e-3
LAMB = 0.05
SEED = 0  # base RNG seed; the i-th entry of DATASETS uses SEED + i


def _seed_everything(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch's RNG (CPU and all CUDA devices)."""
    random.seed(seed)
    torch.manual_seed(seed)


# results[name] holds the per-epoch curves; summary_rows the final accuracies.
results: Dict[str, Dict[str, List[float]]] = {}
summary_rows = []

for dataset_idx, cfg in enumerate(DATASETS):
    print(f'==> Dataset: {cfg.name}')
    run_seed = SEED + dataset_idx
    grayscale = cfg.name == 'MNIST'
    train_tfm = build_transforms(cfg.image_size, is_train=True, grayscale=grayscale)
    test_tfm = build_transforms(cfg.image_size, is_train=False, grayscale=grayscale)

    # Augmented copy for the baseline; raw copy that MultiViewDataset augments twice.
    train_base = cfg.dataset_cls(root='data', train=True, download=True, transform=train_tfm)
    train_raw = cfg.dataset_cls(root='data', train=True, download=True, transform=None)
    test_base = cfg.dataset_cls(root='data', train=False, download=True, transform=test_tfm)

    baseline_loader = DataLoader(train_base, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_base, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    mv_dataset = MultiViewDataset(train_raw, views=2, transform=train_tfm)
    sicreg_loader = DataLoader(mv_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

    # Baseline training
    # Re-seed right before building each model so that the baseline and the
    # SICReg model of a dataset start from identical initial weights; otherwise
    # the accuracy gap also contains initialization noise.
    _seed_everything(run_seed)
    baseline_model = ResNet18WithHead(cfg.num_classes).to(device)
    baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=LR)
    baseline_losses = []
    baseline_accs = []
    for epoch in range(EPOCHS):
        loss = train_one_epoch_baseline(baseline_model, baseline_loader, baseline_opt)
        acc = evaluate(baseline_model, test_loader)
        baseline_losses.append(loss)
        baseline_accs.append(acc)
        print(f'  [Baseline] Epoch {epoch+1}/{EPOCHS} - loss: {loss:.4f} - acc: {acc:.4f}')

    # SICReg training (same seed -> same initial weights as the baseline)
    _seed_everything(run_seed)
    sicreg_model = ResNet18WithHead(cfg.num_classes).to(device)
    sicreg_opt = torch.optim.Adam(sicreg_model.parameters(), lr=LR)
    sicreg_losses = []
    sicreg_accs = []
    sicreg_loss_fn = SICReg().to(device)
    for epoch in range(EPOCHS):
        loss = train_one_epoch_sicreg(sicreg_model, sicreg_loader, sicreg_opt, sicreg_loss_fn, LAMB)
        acc = evaluate(sicreg_model, test_loader)
        sicreg_losses.append(loss)
        sicreg_accs.append(acc)
        print(f'  [SICReg]  Epoch {epoch+1}/{EPOCHS} - loss: {loss:.4f} - acc: {acc:.4f}')

    results[cfg.name] = {
        'baseline_loss': baseline_losses,
        'baseline_acc': baseline_accs,
        'sicreg_loss': sicreg_losses,
        'sicreg_acc': sicreg_accs,
    }
    summary_rows.append({
        'Dataset': cfg.name,
        'Baseline Acc (last)': baseline_accs[-1],
        'SICReg Acc (last)': sicreg_accs[-1],
    })


## 7. Tabla comparativa de generalización


In [None]:
# One row per dataset with the final-epoch test accuracy of each method.
summary_df = pd.DataFrame.from_records(summary_rows)
summary_df


## 8. Curvas de aprendizaje por dataset


In [None]:
# One figure per dataset: test accuracy after every epoch for both methods.
for dataset_name, data in results.items():
    epochs = list(range(1, len(data['baseline_acc']) + 1))
    fig, ax = plt.subplots(figsize=(6, 4))
    for key, label in (('baseline_acc', 'Baseline'), ('sicreg_acc', 'SICReg')):
        ax.plot(epochs, data[key], label=label)
    ax.set_title(f'Accuracy vs Epochs - {dataset_name}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Test Accuracy')
    ax.legend()
    ax.grid(True)
    plt.show()


---
**Notas:**
- Para resultados más robustos aumentá `EPOCHS` y considerá usar un scheduler.
- Podés agregar más datasets de `torchvision` extendiendo la lista `DATASETS`.
