# I-JEPA Tutorial Notebook (Standard Setup, CIFAR-10)
# --------------------------------------------------
# This notebook follows the *original I-JEPA design philosophy*:
# - Context encoder (student)
# - Target encoder (teacher, EMA)
# - No reconstruction loss
# - Predict latent representations of masked target regions
# - No negative samples, no contrastive loss


In [None]:
# ==================================================
# 1. Imports & setup
# ==================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from timm.models.vision_transformer import VisionTransformer
from tqdm import tqdm
import math, random, os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# ==================================================
# 2. Hyperparameters (CIFAR-friendly, I-JEPA style)
# ==================================================
IMG_SIZE = 32
PATCH = 8                     # 4x4 grid => 16 patches (important for blocks)
NUM_PATCHES = (IMG_SIZE // PATCH) ** 2

CTX_RATIO = 0.5               # fraction of patches kept as context
CTX_KEEP = int(CTX_RATIO * NUM_PATCHES)

EMB_DIM = 384
DEPTH_ENC = 6
DEPTH_PRED = 3
HEADS = 6
MLP_RATIO = 4

BATCH = 512
EPOCHS = 90
BASE_LR = 1.5e-4
EMA_MOMENTUM = 0.996


In [None]:
# ==================================================
# 3. Dataset (CIFAR-10)
# ==================================================
transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

train_set = datasets.CIFAR10(root='.', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=BATCH, shuffle=True,
                          drop_last=True, num_workers=4)


# Models 

In [None]:
# ==================================================
# 4. Encoders
# ==================================================
class IJEPAViT(nn.Module):
    """
    ViT backbone without class token.
    Outputs patch-level embeddings.
    """
    def __init__(self):
        super().__init__()
        self.vit = VisionTransformer(
            img_size=IMG_SIZE,
            patch_size=PATCH,
            in_chans=3,
            embed_dim=EMB_DIM,
            depth=DEPTH_ENC,
            num_heads=HEADS,
            mlp_ratio=MLP_RATIO,
            num_classes=0,
            global_pool=''
        )

    def forward(self, x, patch_idx=None):
        x = self.vit.patch_embed(x)              # B, N, D
        x = x + self.vit.pos_embed[:, 1:, :]
        B, N, D = x.shape

        if patch_idx is not None:
            x = x[torch.arange(B)[:, None], patch_idx]

        x = self.vit.blocks(x)
        x = self.vit.norm(x)
        return x                                 # B, n_patches, D

In [None]:
# ==================================================
# 5. Predictor (latent-space prediction head)
# ==================================================
class Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.mask_token = nn.Parameter(torch.zeros(1, 1, EMB_DIM))
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=EMB_DIM,
                nhead=HEADS,
                dim_feedforward=EMB_DIM * MLP_RATIO,
                batch_first=True
            ) for _ in range(DEPTH_PRED)
        ])
        self.norm = nn.LayerNorm(EMB_DIM)

    def forward(self, context, queries):
        x = torch.cat([context, queries], dim=1)
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x[:, -queries.shape[1]:])


# Model Initialization

In [None]:
# ==================================================
# 6. Mask sampling (I-JEPA-style)
# ==================================================

def sample_context_patches(B, N, keep):
    return torch.stack([
        torch.randperm(N, device=device)[:keep]
        for _ in range(B)
    ])


def sample_target_patches(B, N, num_targets=4):
    return torch.stack([
        torch.randperm(N, device=device)[:num_targets]
        for _ in range(B)
    ])


In [None]:
# ==================================================
# 7. Model instantiation
# ==================================================
context_encoder = IJEPAViT().to(device)
target_encoder = IJEPAViT().to(device)
predictor = Predictor().to(device)

# Initialize target encoder = context encoder
with torch.no_grad():
    for ps, pt in zip(context_encoder.parameters(), target_encoder.parameters()):
        pt.copy_(ps)
        pt.requires_grad = False

optimizer = torch.optim.AdamW(
    list(context_encoder.parameters()) + list(predictor.parameters()),
    lr=BASE_LR, weight_decay=0.05
)

scaler = torch.cuda.amp.GradScaler()


# Training loop (90 epochs, cosine LR, EMA)

In [None]:
# ==================================================
# 8. Training loop (I-JEPA loss)
# ==================================================

def cosine_schedule(step, total_steps):
    return 0.5 * (1 + math.cos(math.pi * step / total_steps))

steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * EPOCHS

for epoch in range(EPOCHS):
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

    for it, (x, _) in enumerate(pbar):
        x = x.to(device)
        B = x.size(0)

        ctx_idx = sample_context_patches(B, NUM_PATCHES, CTX_KEEP)
        tgt_idx = sample_target_patches(B, NUM_PATCHES, num_targets=4)

        # Context encoding
        z_ctx = context_encoder(x, ctx_idx)

        # Target encoding (EMA, no grad)
        with torch.no_grad():
            z_tgt = target_encoder(x)
            z_tgt = z_tgt[torch.arange(B)[:, None], tgt_idx]

        # Predictor
        queries = predictor.mask_token.expand(B, tgt_idx.size(1), -1)
        z_pred = predictor(z_ctx, queries)

        # I-JEPA loss (cosine similarity)
        z_pred = F.normalize(z_pred, dim=-1)
        z_tgt = F.normalize(z_tgt, dim=-1)
        loss = 2 - 2 * (z_pred * z_tgt).sum(dim=-1).mean()

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # EMA update
        with torch.no_grad():
            step = epoch * steps_per_epoch + it
            m = 1 - (1 - EMA_MOMENTUM) * cosine_schedule(step, total_steps)
            for ps, pt in zip(context_encoder.parameters(), target_encoder.parameters()):
                pt.mul_(m).add_((1 - m) * ps)

        running_loss += loss.item()
        pbar.set_postfix(loss=loss.item())

    print(f"Epoch {epoch:02d} | Avg Loss: {running_loss / steps_per_epoch:.4f}")


# Linear evaluation (frozen target encoder)

In [8]:
from sklearn.linear_model import LogisticRegression


# @torch.no_grad()
# def extract_features(loader, encoder):
#     feats, labels = [], []
#     for x, y in loader:
#         x = x.to(device)
#         #z = encoder(x).mean(dim=1)
#         z = encoder(x); z = z.reshape(B,4,4,-1).mean([1,2])
#         feats.append(F.normalize(z, dim=1).cpu())
#         labels.append(y)
#     return torch.cat(feats), torch.cat(labels)

@torch.no_grad()
def extract_features(loader, encoder):
    feats, labels = [], []
    for x, y in loader:
        x = x.to(device)
        z = encoder(x)                      # B, 16, 384
        z = z.reshape(z.size(0), 4, 4, -1)  # 4x4 spatial grid
        z = z.mean(dim=[1, 2])              # spatial average pooling
        feats.append(F.normalize(z, dim=1).cpu())
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)

train_loader_eval = DataLoader(datasets.CIFAR10('.', train=True, transform=transform),
                               batch_size=500, shuffle=False)
test_loader_eval  = DataLoader(datasets.CIFAR10('.', train=False, transform=transform),
                               batch_size=500, shuffle=False)

z_tr, y_tr = extract_features(train_loader_eval, target_encoder)
z_te, y_te = extract_features(test_loader_eval, target_encoder)

clf = LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)
clf.fit(z_tr, y_tr)
print("Linear probe accuracy:", clf.score(z_te, y_te))

Linear probe accuracy: 0.5142
