# LeJEPA completo (SICReg + JEPA) en MPI3D / dSprites

Este notebook entrena **LeJEPA completo** usando el esquema de `MINIMAL.md` y el loader reducido de `dataset.py` para los datasets **MPI3D** y **dSprites**.  
Luego entrena *probes* lineales (uno por factor de variación) sobre los embeddings del backbone congelado y reporta la precisión de cada factor en el split de test.

## 1. Imports

In [None]:
import math
from dataclasses import dataclass
from pathlib import Path

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset
from torchvision.ops import MLP
from torchvision.transforms import v2
from tqdm import tqdm

from dataset import IDGBenchmarkDataset, IDGDatasetName, IDGSplitName

## 2. Configuración

In [None]:
@dataclass
class Config:
    """Hyper-parameters for LeJEPA pre-training and the per-factor linear probes.

    Fields are positional arguments of the generated ``__init__``, so their
    order must not change. Obviously invalid values are rejected in
    ``__post_init__`` instead of failing later inside the training loop.
    """

    # Data
    data_root: str = "./data"
    dataset: IDGDatasetName = "dsprites"  # "dsprites" or "mpi3d"
    split: IDGSplitName = "random"
    # Encoder / projector
    backbone: str = "vit_small_patch8_224"  # timm model name
    img_size: int = 64
    proj_dim: int = 128
    # LeJEPA pre-training
    views: int = 2  # augmented views generated per image
    epochs: int = 100
    batch_size: int = 256
    lr: float = 2e-3
    weight_decay: float = 5e-2
    lamb: float = 0.02  # loss = lamb * SICReg + (1 - lamb) * invariance
    num_workers: int = 8
    device: str = "cuda"
    amp: bool = True  # bfloat16 autocast during pre-training
    # Linear probes
    probe_epochs: int = 50
    probe_lr: float = 1e-3
    probe_weight_decay: float = 1e-6
    probe_batch_size: int = 256

    def __post_init__(self) -> None:
        """Validate numeric fields.

        Raises:
            ValueError: if a size/count is not positive, an epoch count,
                worker count, learning rate or weight decay is negative, or
                ``lamb`` lies outside ``[0, 1]``.
        """
        for name in ("img_size", "proj_dim", "views", "batch_size", "probe_batch_size"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        # Zero epochs is allowed (e.g. to probe a randomly initialised encoder).
        for name in (
            "epochs",
            "probe_epochs",
            "num_workers",
            "lr",
            "weight_decay",
            "probe_lr",
            "probe_weight_decay",
        ):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be >= 0, got {value}")
        # The training loop mixes the terms as lamb * SICReg + (1 - lamb) * invariance;
        # outside [0, 1] one of the two terms would get a negative weight.
        if not 0.0 <= self.lamb <= 1.0:
            raise ValueError(f"lamb must be in [0, 1], got {self.lamb}")

cfg = Config()

## 3. Dataset + augmentaciones

In [None]:
class MultiViewIDGDataset(Dataset):
    """Wrap ``IDGBenchmarkDataset`` so that each item yields several augmented views.

    Args:
        root, dataset, split, mode: forwarded unchanged to ``IDGBenchmarkDataset``.
        views: number of times ``transform`` is applied to each image.
        transform: callable applied independently for every view.

    Each item is ``(stacked_views, latents)``. ``stacked_views`` stacks the
    transformed images along a new leading axis of length ``views``; the third
    element returned by the wrapped dataset is dropped.
    """

    def __init__(self, root, dataset, split, mode, views, transform):
        self.views = views
        self.transform = transform
        # Augmentation happens per view in __getitem__, so the base dataset gets none.
        base_kwargs = dict(
            root=root,
            dataset=dataset,
            split=split,
            mode=mode,
            image_as_float=True,
            latents_dtype=torch.long,
            transform=None,
        )
        self.ds = IDGBenchmarkDataset(**base_kwargs)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        image, factors, _extra = self.ds[idx]
        augmented = [self.transform(image) for _view in range(self.views)]
        return torch.stack(augmented), factors


def build_transforms(img_size: int, train: bool):
    """Return the torchvision ``v2`` pipeline for pre-training or evaluation.

    Args:
        img_size: output side length of the (square) images.
        train: if True, build the stochastic augmentation pipeline used to
            create views: random resized crop covering 60-100% of the area,
            random horizontal flip, and colour jitter applied with p=0.3.
            Otherwise, build a deterministic ``Resize(img_size)`` followed by
            ``CenterCrop(img_size)``.

    NOTE(review): random cropping/rescaling and horizontal flips change the
    object's apparent position, scale and orientation. If those are among the
    probed latent factors, the invariance objective pushes the encoder to
    discard them. Confirm this is intended.
    """
    if not train:
        return v2.Compose([v2.Resize(img_size), v2.CenterCrop(img_size)])

    jitter = v2.ColorJitter(0.4, 0.4, 0.4, 0.1)
    augmentations = [
        v2.RandomResizedCrop(img_size, scale=(0.6, 1.0)),
        v2.RandomHorizontalFlip(),
        v2.RandomApply([jitter], p=0.3),
    ]
    return v2.Compose(augmentations)

# Stochastic augmentation for pre-training views; deterministic resize + center
# crop for the probing stage.
train_transform = build_transforms(cfg.img_size, train=True)
test_transform = build_transforms(cfg.img_size, train=False)

# Every training sample yields `cfg.views` independently augmented copies.
train_ds = MultiViewIDGDataset(
    root=cfg.data_root,
    dataset=cfg.dataset,
    split=cfg.split,
    mode="train",
    views=cfg.views,
    transform=train_transform,
)

# DataLoader rejects persistent_workers=True when there are no worker processes.
_loader_worker_opts = {
    "num_workers": cfg.num_workers,
    "pin_memory": True,
    "persistent_workers": cfg.num_workers > 0,
}
# drop_last keeps every batch full-size. Presumably this is because the projector
# uses BatchNorm1d and the SICReg statistic is scaled by the batch size.
train_loader = DataLoader(
    train_ds,
    batch_size=cfg.batch_size,
    shuffle=True,
    drop_last=True,
    **_loader_worker_opts,
)

## 4. Modelo LeJEPA completo (SICReg + pérdida JEPA)

In [None]:
class SICReg(nn.Module):
    """Sketched isotropic-Gaussian regularizer (the LeJEPA anti-collapse term).

    Projections are sliced along random unit directions. For every slice, the
    empirical characteristic function along the batch axis is compared with
    the standard-normal one, ``exp(-t**2 / 2)``. The squared difference is
    integrated over ``knots`` points in ``[0, t_max]`` with a Gaussian window
    (an Epps-Pulley style statistic).

    Args:
        knots: number of quadrature points on the ``t`` grid (must be >= 2).
        t_max: upper end of the ``t`` grid (must be > 0).
        num_slices: maximum number of random directions drawn per call. The
            effective number is ``min(num_slices, proj_dim)``. Defaults to 256.

    Raises:
        ValueError: on invalid ``knots``, ``t_max`` or ``num_slices``.
    """

    def __init__(self, knots: int = 17, t_max: float = 3.0, num_slices: int = 256):
        super().__init__()
        if knots < 2:
            raise ValueError(f"knots must be >= 2, got {knots}")
        if t_max <= 0:
            raise ValueError(f"t_max must be > 0, got {t_max}")
        if num_slices < 1:
            raise ValueError(f"num_slices must be >= 1, got {num_slices}")
        self.num_slices = num_slices
        t = torch.linspace(0, t_max, knots, dtype=torch.float32)
        dt = t_max / (knots - 1)
        # Trapezoid weights, doubled: the integrand is even in t, so the integral
        # over the whole real line is twice the integral over [0, t_max].
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # exp(-t^2/2) serves both as the N(0, 1) characteristic function and as the window.
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj: torch.Tensor) -> torch.Tensor:
        """Compute the regularization loss.

        Args:
            proj: tensor of shape ``(..., batch, dim)``. The statistic is computed
                over the batch axis (-2) independently for each leading index.

        Returns:
            Scalar tensor: the per-slice statistic (scaled by the batch size),
            averaged over slices and leading dimensions.
        """
        proj_dim = proj.size(-1)
        sketch_dim = min(self.num_slices, proj_dim)
        # Fresh random unit-norm directions on every call. Normalize in float32,
        # then match proj's dtype so the matmul also works outside autocast
        # (e.g. bf16 or float64 inputs). Under autocast the result is unchanged.
        A = torch.randn(proj_dim, sketch_dim, device=proj.device)
        A = A.div_(A.norm(p=2, dim=0)).to(proj.dtype)
        x_t = (proj @ A).unsqueeze(-1) * self.t  # (..., batch, slices, knots)
        # |empirical CF - Gaussian CF|^2 = (mean cos - phi)^2 + (mean sin)^2, over the batch axis.
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        statistic = (err @ self.weights.to(err.dtype)) * proj.size(-2)
        return statistic.mean()


class ViTEncoder(nn.Module):
    """timm ViT backbone followed by an MLP projector.

    Args:
        backbone: timm model name, instantiated without pretrained weights.
        img_size: input resolution passed to ``timm.create_model``.
        proj_dim: output width of the projector.

    The backbone's head is created with ``num_classes=512``, so it outputs
    512-d embeddings. These feed the projector (512 -> 2048 -> 2048 ->
    proj_dim, with BatchNorm1d as the normalization layer).
    """

    def __init__(self, backbone: str, img_size: int, proj_dim: int):
        super().__init__()
        vit_kwargs = dict(
            pretrained=False,
            num_classes=512,
            drop_path_rate=0.1,
            img_size=img_size,
        )
        self.backbone = timm.create_model(backbone, **vit_kwargs)
        hidden = [2048, 2048, proj_dim]
        self.proj = MLP(512, hidden, norm_layer=nn.BatchNorm1d)

    def forward(self, x):
        """Encode a batch of multi-view images.

        Args:
            x: tensor whose first two axes are ``(batch, views)``.

        Returns:
            ``(emb, proj)``. ``emb`` holds the backbone embeddings with the batch
            and view axes merged into one (batch-major). ``proj`` has shape
            ``(views, batch, proj_dim)``.
        """
        batch, n_views = x.shape[:2]
        flat_images = x.flatten(0, 1)
        emb = self.backbone(flat_images)
        per_sample = self.proj(emb).reshape(batch, n_views, -1)
        return emb, per_sample.permute(1, 0, 2)


# Encoder and optimizer.
device = torch.device(cfg.device)
model = ViTEncoder(cfg.backbone, cfg.img_size, cfg.proj_dim).to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)

# Per-batch LR schedule: one epoch of linear warm-up starting at 1% of cfg.lr,
# then cosine decay to an absolute floor of 1e-4 over the remaining steps.
# The warm-up scheduler must be constructed before the cosine one.
steps_per_epoch = len(train_loader)
warmup_steps = max(1, steps_per_epoch)
total_steps = steps_per_epoch * cfg.epochs
warmup_sched = LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps)
cosine_sched = CosineAnnealingLR(
    optimizer,
    T_max=max(1, total_steps - warmup_steps),
    eta_min=1e-4,
)
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_sched, cosine_sched],
    milestones=[warmup_steps],
)

# NOTE(review): loss scaling is typically unnecessary with the bfloat16 autocast
# used in the training loop; kept as-is.
scaler = GradScaler(enabled=cfg.amp)
sicreg = SICReg().to(device)

## 5. Entrenamiento LeJEPA

In [None]:
for epoch in range(cfg.epochs):
    model.train()
    # Running sums of the per-step losses, displayed as running means.
    inv_loss_sum = 0.0
    sicreg_sum = 0.0
    lejepa_sum = 0.0
    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{cfg.epochs}")
    # Count steps with enumerate. tqdm only syncs `progress.n` when it redraws
    # (rate-limited by mininterval/miniters), so `progress.n + 1` lags the real
    # number of processed batches and inflates the averages.
    for step, (views, _) in enumerate(progress, start=1):
        views = views.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type=device.type, dtype=torch.bfloat16, enabled=cfg.amp):
            # proj: (views, batch, proj_dim)
            _, proj = model(views)
            # Invariance term: distance of each view's projection to the mean over views.
            inv_loss = (proj.mean(0) - proj).square().mean()
            sicreg_loss = sicreg(proj)
            lejepa_loss = sicreg_loss * cfg.lamb + inv_loss * (1 - cfg.lamb)
        scaler.scale(lejepa_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # the LR schedule is defined per batch

        inv_loss_sum += inv_loss.item()
        sicreg_sum += sicreg_loss.item()
        lejepa_sum += lejepa_loss.item()
        progress.set_postfix(
            inv=inv_loss_sum / step,
            sicreg=sicreg_sum / step,
            lejepa=lejepa_sum / step,
        )

## 6. Entrenar probes por factor

In [None]:
def _make_probe_dataset(mode):
    """Build the un-augmented probing dataset for ``mode`` ("train" or "test")."""
    return IDGBenchmarkDataset(
        root=cfg.data_root,
        dataset=cfg.dataset,
        split=cfg.split,
        mode=mode,
        image_as_float=True,
        latents_dtype=torch.long,
        transform=test_transform,
    )


def _make_probe_loader(ds, shuffle):
    """Wrap ``ds`` in a DataLoader with the probing batch size and worker settings."""
    return DataLoader(
        ds,
        batch_size=cfg.probe_batch_size,
        shuffle=shuffle,
        num_workers=cfg.num_workers,
        pin_memory=True,
        persistent_workers=cfg.num_workers > 0,
    )


# Both splits use the deterministic test-time transform.
probe_train_ds = _make_probe_dataset("train")
probe_test_ds = _make_probe_dataset("test")
probe_train_loader = _make_probe_loader(probe_train_ds, shuffle=True)
probe_test_loader = _make_probe_loader(probe_test_ds, shuffle=False)

# Classes per factor = max training label + 1. This reads the dataset's private
# `_labels` attribute; a single-factor label array is promoted to 2-D first.
labels = probe_train_ds._labels
if labels.ndim == 1:
    labels = labels[:, None]
num_classes = [int(labels[:, col].max()) + 1 for col in range(labels.shape[1])]

# One linear head per factor on top of the 512-d backbone embedding.
probes = nn.ModuleList(nn.Linear(512, k) for k in num_classes).to(device)
probe_opt = torch.optim.AdamW(
    probes.parameters(),
    lr=cfg.probe_lr,
    weight_decay=cfg.probe_weight_decay,
)

# The encoder is frozen during probing. Pre-training left it in train mode, which
# keeps drop-path (drop_path_rate=0.1) active, making the features stochastic,
# and lets the projector's BatchNorm update its running statistics even under
# torch.no_grad(). Switch to eval mode before extracting features.
model.eval()
for epoch in range(cfg.probe_epochs):
    probes.train()
    running = 0.0
    progress = tqdm(probe_train_loader, desc=f"Probe {epoch + 1}/{cfg.probe_epochs}")
    # enumerate gives the true step count; tqdm's `progress.n` lags between redraws.
    for step, (imgs, latents, _) in enumerate(progress, start=1):
        imgs = imgs.to(device, non_blocking=True)
        latents = latents.to(device, non_blocking=True)
        if latents.ndim == 1:
            # Single-factor data: make it (batch, 1) like `labels` above.
            latents = latents[:, None]
        with torch.no_grad():
            # Add a singleton views axis: the encoder expects (batch, views, ...).
            emb, _ = model(imgs[:, None])
        # Heads are independent; summing their losses trains all of them at once.
        losses = [F.cross_entropy(head(emb), latents[:, i]) for i, head in enumerate(probes)]
        loss = sum(losses)
        probe_opt.zero_grad(set_to_none=True)
        loss.backward()
        probe_opt.step()
        running += loss.item()
        progress.set_postfix(loss=running / step)

# The encoder must be in eval mode as well. Otherwise drop-path stays active and
# the test features (hence the accuracies) are stochastic.
model.eval()
probes.eval()
correct = [0 for _ in num_classes]
total = 0
with torch.inference_mode():
    for imgs, latents, _ in tqdm(probe_test_loader, desc="Eval probes"):
        imgs = imgs.to(device, non_blocking=True)
        latents = latents.to(device, non_blocking=True)
        if latents.ndim == 1:
            # Single-factor data: make it (batch, 1) like `labels`.
            latents = latents[:, None]
        emb, _ = model(imgs[:, None])
        for i, head in enumerate(probes):
            pred = head(emb).argmax(dim=1)
            correct[i] += (pred == latents[:, i]).sum().item()
        total += latents.size(0)

if total == 0:
    raise RuntimeError("probe_test_loader produced no samples; cannot compute accuracy.")
# Per-factor top-1 accuracy on the test split.
for i, c in enumerate(correct):
    print(f"Factor {i}: {c / total:.4f}")