# Comparación de entrenamiento estándar vs SICReg en datasets de torchvision

Este notebook entrena una **ResNet18** en varios datasets de `torchvision` (MNIST, CIFAR10, CIFAR100) con:

1. **Entrenamiento estándar** (cross-entropy).
2. **Entrenamiento estándar + pérdida SICReg**.

Al final se genera una tabla con los resultados de generalización y gráficos con las curvas de aprendizaje para comparar ambos métodos.


## 1. Imports y configuración


In [1]:
import json
import math
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models

import pandas as pd
import matplotlib.pyplot as plt

# Enable cuDNN benchmarking for now; seed_everything() in the next cell
# switches it off again in favour of deterministic kernels.
torch.backends.cudnn.benchmark = True

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)


Device: cuda


## 1.1 Reproducibilidad y registro de experimentos


In [None]:
SEED = 42


def seed_everything(seed: int) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU generator and
    all CUDA devices), then puts cuDNN in deterministic mode with
    benchmarking disabled.

    Args:
        seed: Value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything(SEED)

# Everything produced by the experiments lives under ./experiments:
# one JSON log with per-run state/metrics and one folder of checkpoints.
RUN_DIR = Path('experiments')
CHECKPOINT_DIR = RUN_DIR / 'checkpoints'
LOG_PATH = RUN_DIR / 'sicreg_runs.json'
RUN_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)


def load_runs() -> Dict[str, dict]:
    """Return the run log stored at ``LOG_PATH``.

    Returns:
        The decoded JSON object (run key -> run record), or an empty dict
        when the log file does not exist yet.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if LOG_PATH.exists():
        return json.loads(LOG_PATH.read_text(encoding='utf-8'))
    return {}


def save_runs(runs: Dict[str, dict]) -> None:
    """Persist ``runs`` to ``LOG_PATH`` as indented JSON.

    The payload is written to a sibling temporary file first and then moved
    over the real log, so an interruption in the middle of the write cannot
    leave a truncated log behind (which would make ``load_runs`` fail on the
    next resume).

    Args:
        runs: Mapping of run key to a JSON-serialisable run record.
    """
    tmp_path = LOG_PATH.with_name(LOG_PATH.name + '.tmp')
    tmp_path.write_text(json.dumps(runs, indent=2), encoding='utf-8')
    tmp_path.replace(LOG_PATH)

class EarlyStoppingScheduler:
    """Patience-based early stopping for a metric where higher is better.

    Attributes are public because the experiment loop reads ``best`` and
    ``epochs_without_improve`` when it persists run state.
    """

    def __init__(self, patience: int = 5):
        # Number of consecutive non-improving updates tolerated before stopping.
        self.patience = patience
        # Best metric seen so far; None until the first update.
        self.best = None
        # Consecutive updates that did not beat ``best``.
        self.epochs_without_improve = 0

    def should_stop(self, metric: float) -> bool:
        """Record ``metric`` and report whether training should stop.

        A value strictly greater than the best so far (or the very first
        value) resets the counter; a tie counts as no improvement.

        Returns:
            True once ``patience`` consecutive updates failed to improve.
        """
        improved = self.best is None or metric > self.best
        if not improved:
            self.epochs_without_improve += 1
            return self.epochs_without_improve >= self.patience

        self.best = metric
        self.epochs_without_improve = 0
        return False


## 2. Definición de la pérdida SICReg

La implementación sigue el esquema del ejemplo mínimo (SIGReg) pero usando el nombre **SICReg** para mantener la convención del enunciado.


In [2]:
class SICReg(nn.Module):
    """Characteristic-function regulariser on random 1-D projections.

    For each random unit direction, the empirical characteristic function of
    the projected samples is compared, on a grid of frequencies ``t``, with
    ``exp(-t**2 / 2)`` (the characteristic function of a standard normal).
    The squared discrepancy is integrated over the grid with fixed quadrature
    weights and averaged over directions.

    Args:
        knots: Number of grid points on ``[0, t_max]`` (at least 2).
        num_slices: Number of random projection directions drawn per call.
        t_max: Upper end of the frequency grid.

    Raises:
        ValueError: If ``knots`` is smaller than 2 (the grid spacing would be
            undefined).
    """

    def __init__(self, knots: int = 17, num_slices: int = 256, t_max: float = 3.0):
        super().__init__()
        if knots < 2:
            raise ValueError(f'knots must be >= 2, got {knots}')
        self.num_slices = num_slices

        t = torch.linspace(0, t_max, knots, dtype=torch.float32)
        dt = t_max / (knots - 1)
        # Interior nodes weigh 2*dt and the two end nodes dt, i.e. twice the
        # trapezoid weights on [0, t_max] (presumably accounting for the
        # mirrored negative half of the frequency axis).
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer('t', t)
        self.register_buffer('phi', window)
        # The Gaussian window is folded into the quadrature weights.
        self.register_buffer('weights', weights * window)

    def forward(self, proj: torch.Tensor) -> torch.Tensor:
        """Return the scalar regularisation statistic for ``proj``.

        Args:
            proj: Embeddings of shape ``[..., N, D]``; statistics are taken
                over the ``N`` axis. The original comment documents
                ``[V, B, D]``; the training loop in this notebook passes
                ``[B, D]``.

        Returns:
            A 0-dim tensor: the statistic averaged over directions (and over
            any leading axes).
        """
        dtype = proj.dtype
        # Fresh random unit directions on every call, created in the input's
        # dtype so that non-float32 inputs do not hit a matmul dtype mismatch.
        directions = torch.randn(
            proj.size(-1), self.num_slices, device=proj.device, dtype=dtype
        )
        directions = directions / directions.norm(p=2, dim=0)

        # Casting is a no-op for float32 inputs.
        t = self.t.to(dtype)
        phi = self.phi.to(dtype)
        weights = self.weights.to(dtype)

        x_t = (proj @ directions).unsqueeze(-1) * t
        # Real part is compared against phi, imaginary part against zero;
        # mean(-3) averages over the sample axis.
        err = (x_t.cos().mean(-3) - phi).square() + x_t.sin().mean(-3).square()
        statistic = (err @ weights) * proj.size(-2)
        return statistic.mean()


## 3. Modelo ResNet18 con cabeza de proyección


In [3]:
class ResNet18WithHead(nn.Module):
    """ResNet18 trunk feeding a linear classifier and an MLP projector.

    The torchvision ResNet18 is created without pretrained weights and its
    final fully connected layer is replaced by an identity, so the trunk
    outputs the pooled features directly.

    Args:
        num_classes: Output size of the classification head.
        proj_dim: Output size of the projection head.
    """

    def __init__(self, num_classes: int, proj_dim: int = 128):
        super().__init__()
        trunk = models.resnet18(weights=None)
        # Input width of the layer being removed (512 for ResNet18).
        feature_dim = trunk.fc.in_features
        trunk.fc = nn.Identity()
        self.backbone = trunk

        self.classifier = nn.Linear(feature_dim, num_classes)

        # Layer order (and therefore state-dict indices 0 and 2) must stay
        # stable so that saved checkpoints keep loading.
        projector_layers = [
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, proj_dim),
        ]
        self.projector = nn.Sequential(*projector_layers)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(logits, proj)`` computed from the shared trunk features."""
        shared = self.backbone(x)
        return self.classifier(shared), self.projector(shared)


## 4. Datasets y DataLoaders

Se construyen transformaciones específicas para cada dataset. Para **MNIST**, que tiene un solo canal, la imagen se replica en 3 canales para adaptarla a la entrada de ResNet.


In [4]:
@dataclass
class DatasetConfig:
    """Static description of one dataset entry in the comparison."""

    # Label used in logs, run keys, the summary table and plot titles.
    name: str
    # Dataset class; the experiment loop calls it with
    # root/train/download/transform keyword arguments.
    dataset_cls: type
    # Size of the classifier output.
    num_classes: int
    # Side length handed to the crop/resize transforms.
    image_size: int


DATASETS = [
    DatasetConfig(
        name='MNIST',
        dataset_cls=datasets.MNIST,
        num_classes=10,
        image_size=28,
    ),
    DatasetConfig(
        name='CIFAR10',
        dataset_cls=datasets.CIFAR10,
        num_classes=10,
        image_size=32,
    ),
    DatasetConfig(
        name='CIFAR100',
        dataset_cls=datasets.CIFAR100,
        num_classes=100,
        image_size=32,
    ),
]

def build_transforms(image_size: int, is_train: bool, grayscale: bool = False):
    """Build the preprocessing pipeline for one dataset split.

    Args:
        image_size: Target side length for cropping/resizing.
        is_train: If True, use random resized crop (scale 0.8-1.0) plus
            random horizontal flip; otherwise resize and centre-crop.
        grayscale: If True, repeat the tensor along the channel axis three
            times after conversion to a tensor.

    Returns:
        A ``transforms.Compose`` applying the spatial stage, then
        ``ToTensor`` and, optionally, the channel replication.
    """
    if is_train:
        spatial_ops = [
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        spatial_ops = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
    spatial_stage = transforms.Compose(spatial_ops)

    tensor_ops = [transforms.ToTensor()]
    if grayscale:
        tensor_ops.append(transforms.Lambda(lambda img: img.repeat(3, 1, 1)))

    return transforms.Compose([spatial_stage, *tensor_ops])

class MultiViewDataset(Dataset):
    """Wrap a dataset so each item yields several transformed views.

    Args:
        base_dataset: Indexable dataset returning ``(sample, label)`` pairs.
        views: Number of times ``transform`` is applied per sample.
        transform: Callable mapping a raw sample to a tensor.
    """

    def __init__(self, base_dataset: Dataset, views: int, transform):
        self.base_dataset = base_dataset
        self.views = views
        self.transform = transform

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``(stacked_views, label)`` for the sample at ``idx``.

        The views are stacked along a new leading axis, in the order in
        which the transform was applied.
        """
        sample, label = self.base_dataset[idx]
        stacked = []
        for _ in range(self.views):
            stacked.append(self.transform(sample))
        return torch.stack(stacked), label


## 5. Funciones de entrenamiento y evaluación


In [11]:
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled. The model
    is expected to return a ``(logits, proj)`` pair; only the logits are
    used.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits, _proj = model(inputs)
            hits = logits.argmax(dim=1) == targets
            n_correct += hits.sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen

def train_one_epoch_sicreg(model, loader, optimizer, sicreg: SICReg, lamb: float):
    """Train ``model`` for one epoch on cross-entropy plus ``lamb`` * SICReg.

    Args:
        model: Network returning ``(logits, proj)``.
        loader: Training DataLoader yielding ``(inputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        sicreg: Regulariser applied to the projection output.
        lamb: Weight of the regulariser in the total loss.

    Returns:
        Tuple ``(total_loss, cls_loss, sicreg_loss)``; each is the
        batch-size-weighted sum divided by ``len(loader.dataset)``.
    """
    model.train()
    # Running sums in the order: total, classification, regulariser.
    sums = [0.0, 0.0, 0.0]
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits, proj = model(inputs)
        cls_loss = F.cross_entropy(logits, targets)
        sicreg_loss = sicreg(proj)
        loss = cls_loss + lamb * sicreg_loss
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        for slot, term in enumerate((loss, cls_loss, sicreg_loss)):
            sums[slot] += term.item() * batch_size

    n_samples = len(loader.dataset)
    return tuple(total / n_samples for total in sums)

def train_one_epoch_baseline(model, loader, optimizer):
    """Train ``model`` for one epoch with plain cross-entropy.

    Args:
        model: Network returning ``(logits, proj)``; the projection is
            ignored here.
        loader: Training DataLoader yielding ``(inputs, labels)``.
        optimizer: Optimizer stepping the model parameters.

    Returns:
        The batch-size-weighted loss sum divided by ``len(loader.dataset)``.
    """
    model.train()
    loss_sum = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits, _proj = model(inputs)
        batch_loss = F.cross_entropy(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * targets.size(0)
    return loss_sum / len(loader.dataset)




## 6. Ejecutar experimentos

Ajustá `EPOCHS`, `BATCH_SIZE`, `LR`, `LAMB` o `EARLY_STOP_PATIENCE` según los recursos disponibles.


In [None]:
from tqdm.notebook import tqdm

EPOCHS = 10
BATCH_SIZE = 256
LR = 1e-3
LAMB = 0.05
EARLY_STOP_PATIENCE = 5

# Keys of the per-run ``metrics`` dict, grouped by training phase.
BASELINE_METRIC_KEYS = ('baseline_loss', 'baseline_acc')
SICREG_METRIC_KEYS = ('sicreg_total_loss', 'sicreg_cls_loss', 'sicreg_loss', 'sicreg_acc')


def _summary_row(dataset_name: str, metrics: Dict[str, List[float]]) -> dict:
    """Build one row of the comparison table from a run's metric lists."""
    baseline_accs = metrics.get('baseline_acc') or []
    sicreg_accs = metrics.get('sicreg_acc') or []
    return {
        'Dataset': dataset_name,
        'Baseline Acc (last)': baseline_accs[-1] if baseline_accs else None,
        'SICReg Acc (last)': sicreg_accs[-1] if sicreg_accs else None,
        'Baseline Acc (best)': max(baseline_accs) if baseline_accs else None,
        'SICReg Acc (best)': max(sicreg_accs) if sicreg_accs else None,
    }


def _replay_early_stopper(accs: List[float], patience: int) -> EarlyStoppingScheduler:
    """Rebuild early-stopping state by feeding it the recorded accuracies."""
    stopper = EarlyStoppingScheduler(patience=patience)
    for acc in accs:
        stopper.should_stop(acc)
    return stopper


def _early_stop_reached(stopper: EarlyStoppingScheduler) -> bool:
    """Return True if ``stopper`` has already exhausted its patience."""
    return stopper.epochs_without_improve >= stopper.patience


def _restore_checkpoint(path: Path, model: nn.Module, optimizer: torch.optim.Optimizer) -> int:
    """Load model/optimizer state from ``path``; return the last finished epoch (0 if absent)."""
    if not path.exists():
        return 0
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    return int(checkpoint['epoch'])


def _save_checkpoint(path: Path, epoch: int, model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
    """Write the model/optimizer state reached after ``epoch`` to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }, path)


def _make_record(
    status: str,
    baseline_status: str,
    metrics: Dict[str, List[float]],
    baseline_stopper: EarlyStoppingScheduler,
    sicreg_stopper: EarlyStoppingScheduler = None,
) -> dict:
    """Assemble the JSON record stored in the run log for one run key.

    Epoch counters are derived from the metric list lengths so they can
    never disagree with the stored metrics.
    """
    return {
        'status': status,
        'epoch': len(metrics['sicreg_acc']),
        'best_acc': sicreg_stopper.best if sicreg_stopper else None,
        'epochs_without_improve': sicreg_stopper.epochs_without_improve if sicreg_stopper else 0,
        'baseline_epoch': len(metrics['baseline_acc']),
        'baseline_best_acc': baseline_stopper.best,
        'baseline_epochs_without_improve': baseline_stopper.epochs_without_improve,
        'baseline_status': baseline_status,
        'metrics': metrics,
    }


results: Dict[str, Dict[str, List[float]]] = {}
summary_rows = []
runs = load_runs()

for cfg in tqdm(DATASETS, desc='Datasets'):
    print(f'==> Dataset: {cfg.name}')
    grayscale = cfg.name == 'MNIST'
    train_tfm = build_transforms(cfg.image_size, is_train=True, grayscale=grayscale)
    test_tfm = build_transforms(cfg.image_size, is_train=False, grayscale=grayscale)

    train_base = cfg.dataset_cls(root='data', train=True, download=True, transform=train_tfm)
    train_raw = cfg.dataset_cls(root='data', train=True, download=True, transform=None)
    test_base = cfg.dataset_cls(root='data', train=False, download=True, transform=test_tfm)

    baseline_loader = DataLoader(train_base, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
    test_loader = DataLoader(test_base, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

    # NOTE(review): the multi-view loader is built but never consumed; SICReg
    # is trained below on ``baseline_loader`` (single view), and
    # train_one_epoch_sicreg does not handle an extra view axis. Kept as-is
    # until the intended setup is confirmed.
    mv_dataset = MultiViewDataset(train_raw, views=2, transform=train_tfm)
    sicreg_loader = DataLoader(mv_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)

    run_key = f'{cfg.name}-seed{SEED}-epochs{EPOCHS}-lr{LR}-lamb{LAMB}-bs{BATCH_SIZE}'
    run_state = runs.get(run_key)

    if run_state and run_state.get('status') == 'completed':
        print(f'  [SICReg] Resultado ya completado, se omite: {run_key}')
        results[cfg.name] = run_state['metrics']
        summary_rows.append(_summary_row(cfg.name, run_state['metrics']))
        continue

    # Work on copies of the logged metric lists; they are reconciled with the
    # checkpoints below before any new epoch is appended.
    saved_metrics = run_state.get('metrics', {}) if run_state else {}
    metrics = {
        key: list(saved_metrics.get(key, []))
        for key in BASELINE_METRIC_KEYS + SICREG_METRIC_KEYS
    }

    # ---- Baseline training -------------------------------------------------
    baseline_model = ResNet18WithHead(cfg.num_classes).to(device)
    baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=LR)
    baseline_checkpoint_path = CHECKPOINT_DIR / f'{run_key}-baseline.pt'
    baseline_done = bool(run_state) and run_state.get('baseline_status') == 'completed'

    if baseline_done:
        print(f'  [Baseline] Resultado ya completado, se omite: {run_key}')
        baseline_early_stopper = _replay_early_stopper(metrics['baseline_acc'], EARLY_STOP_PATIENCE)
    else:
        # Checkpoints are only trusted when the log knows about this run.
        last_epoch = (
            _restore_checkpoint(baseline_checkpoint_path, baseline_model, baseline_opt)
            if run_state else 0
        )
        if last_epoch:
            print(f'  [Baseline] Reanudando desde época {last_epoch + 1}')
        # The log is written before the checkpoint, so it may be ahead of the
        # weights (or the checkpoint may be missing): drop epochs that the
        # restored weights have not actually seen.
        for key in BASELINE_METRIC_KEYS:
            del metrics[key][last_epoch:]
        baseline_early_stopper = _replay_early_stopper(metrics['baseline_acc'], EARLY_STOP_PATIENCE)

        # Do not train past an early stop that was already triggered.
        if not _early_stop_reached(baseline_early_stopper):
            for epoch in tqdm(range(last_epoch + 1, EPOCHS + 1), desc='Training Baseline'):
                loss = train_one_epoch_baseline(baseline_model, baseline_loader, baseline_opt)
                acc = evaluate(baseline_model, test_loader)
                metrics['baseline_loss'].append(loss)
                metrics['baseline_acc'].append(acc)
                print(f'  [Baseline] Epoch {epoch}/{EPOCHS} - loss: {loss:.4f} - acc: {acc:.4f}')

                should_stop = baseline_early_stopper.should_stop(acc)
                runs[run_key] = _make_record('in_progress', 'in_progress', metrics, baseline_early_stopper)
                save_runs(runs)
                _save_checkpoint(baseline_checkpoint_path, epoch, baseline_model, baseline_opt)
                if should_stop:
                    print(f'  [Baseline] Early stopping activado en época {epoch}')
                    break

        runs[run_key] = _make_record('in_progress', 'completed', metrics, baseline_early_stopper)
        save_runs(runs)

    # ---- SICReg training ---------------------------------------------------
    sicreg_model = ResNet18WithHead(cfg.num_classes).to(device)
    sicreg_opt = torch.optim.Adam(sicreg_model.parameters(), lr=LR)
    sicreg_loss_fn = SICReg().to(device)
    checkpoint_path = CHECKPOINT_DIR / f'{run_key}.pt'

    last_epoch = _restore_checkpoint(checkpoint_path, sicreg_model, sicreg_opt) if run_state else 0
    if last_epoch:
        print(f'  [SICReg] Reanudando desde época {last_epoch + 1}')
    for key in SICREG_METRIC_KEYS:
        del metrics[key][last_epoch:]
    early_stopper = _replay_early_stopper(metrics['sicreg_acc'], EARLY_STOP_PATIENCE)

    if not _early_stop_reached(early_stopper):
        for epoch in tqdm(range(last_epoch + 1, EPOCHS + 1), desc='Training with SICReg'):
            loss, cls_loss, sicreg_loss = train_one_epoch_sicreg(
                sicreg_model, baseline_loader, sicreg_opt, sicreg_loss_fn, LAMB
            )
            acc = evaluate(sicreg_model, test_loader)
            metrics['sicreg_total_loss'].append(loss)
            metrics['sicreg_cls_loss'].append(cls_loss)
            metrics['sicreg_loss'].append(sicreg_loss)
            metrics['sicreg_acc'].append(acc)
            print(f'  [SICReg]  Epoch {epoch}/{EPOCHS} - loss: {loss:.4f} - acc: {acc:.4f} - cls: {cls_loss:.4f} - sicreg: {sicreg_loss:.4f}')

            should_stop = early_stopper.should_stop(acc)
            # The baseline phase is always finished by the time this loop
            # runs, so its status is recorded as completed rather than copied
            # from the (possibly stale) state loaded at start-up.
            runs[run_key] = _make_record(
                'in_progress', 'completed', metrics, baseline_early_stopper, early_stopper
            )
            save_runs(runs)
            _save_checkpoint(checkpoint_path, epoch, sicreg_model, sicreg_opt)
            if should_stop:
                print(f'  [SICReg] Early stopping activado en época {epoch}')
                break

    runs[run_key] = _make_record('completed', 'completed', metrics, baseline_early_stopper, early_stopper)
    save_runs(runs)

    # Same key set as the cached-run branch above, so downstream cells see a
    # uniform structure.
    results[cfg.name] = metrics
    summary_rows.append(_summary_row(cfg.name, metrics))


Datasets:   0%|          | 0/3 [00:00<?, ?it/s]

==> Dataset: MNIST


Training with SICReg:   0%|          | 0/10 [00:00<?, ?it/s]

  [SICReg]  Epoch 1/10 - loss: 0.7582 - acc: 0.9600 - cls: 0.2600 - sicreg: 9.9624
  [SICReg]  Epoch 2/10 - loss: 0.3064 - acc: 0.9639 - cls: 0.1018 - sicreg: 4.0934
  [SICReg]  Epoch 3/10 - loss: 0.2407 - acc: 0.9768 - cls: 0.0808 - sicreg: 3.1983


## 7. Tabla comparativa de generalización


In [None]:
# One row per dataset with last/best test accuracy for both methods.
summary_df = pd.DataFrame(summary_rows)
# Bare expression so the notebook renders the table as the cell output.
summary_df


## 8. Curvas de aprendizaje por dataset


In [None]:
# One figure per dataset comparing test accuracy of both methods.
curves = (('baseline_acc', 'Baseline'), ('sicreg_acc', 'SICReg'))
for dataset_name, data in results.items():
    plt.figure(figsize=(6, 4))
    for metric_key, label in curves:
        accs = data[metric_key]
        # Each curve gets its own x axis: with independent early stopping the
        # two methods can run for a different number of epochs, and sharing
        # one axis would make plt.plot fail on mismatched lengths.
        plt.plot(list(range(1, len(accs) + 1)), accs, label=label)
    plt.title(f'Accuracy vs Epochs - {dataset_name}')
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy')
    plt.legend()
    plt.grid(True)
    plt.show()


---
**Notas:**
- Para resultados más robustos aumentá `EPOCHS` y considerá usar un scheduler.
- Podés agregar más datasets de `torchvision` extendiendo la lista `DATASETS`.
