In [193]:
!pip -q install timm

import torch, random, math
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
device


'cuda'

In [194]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])

train_ds = datasets.CIFAR10(root=".", train=True, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)

imgs, labels = next(iter(train_loader))
imgs.shape, labels.shape


(torch.Size([32, 3, 224, 224]), torch.Size([32]))

In [195]:
GRID=14

def sample_block_mask(grid=14,scale=(0.15,0.20),aspect=(0.75,1.5)):
    area= grid*grid
    target_area=random.uniform(*scale) * area
    asp=random.uniform(*aspect)


    h=int(round(math.sqrt(target_area * asp)))
    w=int(round(math.sqrt(target_area/asp)))
    h=max(1,min(grid,h))
    w=max(1,min(grid,w))


    top=random.randint(0,grid-h)
    left=random.randint(0,grid-w)
    mask=torch.zeros(grid,grid,dtype=torch.bool)
    mask[top:top+h,left:left+w] = True

    return mask



def sample_context_mask(grid=14,scale=(0.85,1.0)):
    area= grid*grid
    ctx_area=random.uniform(*scale)*area
    side=int(round(math.sqrt(ctx_area)))
    side=max(1,min(grid,side))


    top=random.randint(0,grid-side)
    left=random.randint(0,grid-side)

    mask=torch.zeros(grid,grid,dtype=torch.bool)
    mask[top:top+side,left:left+side]=True
    return mask


def remove_overlap(ctx_mask, target_masks):
    for tm in target_masks:
        ctx_mask = ctx_mask & (~tm)
    return ctx_mask

def mask_to_idx(mask2d):
    return torch.where(mask2d.flatten())[0]

In [196]:
M = 4
targets = [sample_block_mask(GRID) for _ in range(M)]
ctx = sample_context_mask(GRID)
ctx = remove_overlap(ctx, targets)

ctx_idx = mask_to_idx(ctx)
tgt_idx = mask_to_idx(torch.stack(targets).any(dim=0))

ctx_idx.numel(), tgt_idx.numel()


(123, 73)

**Minimal ViT encode**

In [197]:
class TinyViT(nn.Module):
    def __init__(self, dim=256, depth=4, heads=8):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, depth)

    def forward(self, x):
        # x: [B, N, D]
        return self.encoder(x)


**Patch embedding**

In [198]:
class PatchEmbed(nn.Module):
    def __init__(self,img_size=224,patch=16,dim=256):
        super().__init__()
        self.grid=img_size // patch
        self.proj=nn.Conv2d(3,dim,kernel_size=patch,stride=patch)


    def forward(self,x):
        x=self.proj(x)
        x=x.flatten(2).transpose(1,2)
        return x

**Build student / teacher / predictor**

In [199]:
patch_embed = PatchEmbed().to(device)
context_encoder = TinyViT(dim=DIM).to(device)
target_encoder  = TinyViT(dim=DIM).to(device)
predictor       = TinyViT(dim=DIM).to(device)

In [200]:
optimizer = torch.optim.AdamW(
    list(context_encoder.parameters()) +
    list(predictor.parameters()),
    lr=1e-4,
    weight_decay=1e-4
)

**Freeze Teacher**

In [201]:
for p in target_encoder.parameters():
    p.requires_grad=False

**Mask Tokens**

In [202]:
mask_tokens=nn.Parameter(torch.zeros(1,1,DIM))

**Forward pass sanity check**

In [204]:
B=2
dummy_img=torch.randn(B,3,224,224)

tokens=patch_embed(dummy_img.device)
print("All Tokens: ",tokens.shape)

TypeError: conv2d() received an invalid combination of arguments - got (torch.device, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, tuple of ints padding = 0, tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!torch.device!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, str padding = "valid", tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!torch.device!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)


**Context-only tokens**

In [None]:
ctx_tokens=tokens[:,ctx_idx]
print("Context tokens: ",ctx_tokens.shape)

In [None]:
M = 4
target_masks = [sample_block_mask(GRID) for _ in range(M)]


**Target tokens (teacher)**

In [None]:
with torch.no_grad():
    target_repr=target_encoder(tokens)
    target_repr=target_repr[:,tgt_idx]
print("Target Repr: ",target_repr.shape)

In [None]:
def extract_target_repr(encoder_outputs, target_masks):
    """
    encoder_outputs: [B, N, D]
    target_masks: list of [grid, grid] boolean masks
    """
    B, N, D = encoder_outputs.shape
    grid = int(N ** 0.5)

    outputs = []
    for mask in target_masks:
        flat_mask = mask.flatten()          # [N]
        selected = encoder_outputs[:, flat_mask, :]  # [B, num_patches, D]
        outputs.append(selected)

    # concatenate all target blocks
    return torch.cat(outputs, dim=1)        # [B, total_targets, D]


**Predictor input**

In [None]:
mask_tokens=mask_tokens.expand(B,tgt_idx.numel(),DIM)
pred_input=torch.cat([ctx_tokens,mask_tokens],dim=1)
print("Predictor Input : ",pred_input.shape)

**Predictor Output**

In [None]:
pred_out=predictor(pred_input)
predicted_targets=pred_out[:,-tgt_idx.numel():]

print("Predicted Targets: ",predicted_targets.shape)

**L2 loss (the learning signal)**

In [None]:
pred_targets = predictor(ctx_tokens)

T_pred = pred_targets.shape[1]
T_tgt  = target_repr.shape[1]
T = min(T_pred, T_tgt)   # ðŸ”‘ critical line

pred_targets = pred_targets[:, :T, :]
target_repr  = target_repr[:, :T, :]

loss = ((pred_targets - target_repr) ** 2).mean()

**Backprop (ONLY through context + predictor)**

In [None]:
optimizer.zero_grad()
loss.backward()
optimizer.step()


**EMA update**

In [None]:
@torch.no_grad()
def ema_update(target, online, momentum=0.996):
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.data = momentum * p_t.data + (1 - momentum) * p_o.data


In [None]:
ema_update(target_encoder, context_encoder)


**Sanity Check:**   

We want to verify one thing only:

Does the loss go down over a few steps?

If it doesnâ€™t â†’ something is fundamentally wrong.
If it does â†’ JEPA is working.


In [None]:
y_full=tokens

In [None]:
B, N, D = tokens.shape
keep_ratio = 0.8
num_keep = int(N * keep_ratio)

perm = torch.randperm(N)
keep_idx = perm[:num_keep]

x_ctx = tokens[:, keep_idx] 

In [None]:
ctx_tokens=context_encoder(x_ctx)

In [None]:
losses = []

for step in range(20):
    optimizer.zero_grad()

    # ----------------------------------
    # 1) Get a FRESH batch & embeddings
    # ----------------------------------
    imgs, _ = next(iter(train_loader))
    imgs = imgs.to(device)

    tokens = patch_embed(imgs)        # [B, 196, D]
    B, N, D = tokens.shape

    # ----------------------------------
    # 2) Random mask (JEPA-style)
    # ----------------------------------
    keep_ratio = 0.8
    num_keep = int(N * keep_ratio)

    perm = torch.randperm(N)
    keep_idx = perm[:num_keep]
    pred_idx = perm[num_keep:]

    x_ctx = tokens[:, keep_idx]       # context tokens
    y_tgt = tokens[:, pred_idx]       # target tokens (same order!)

    # ----------------------------------
    # 3) Teacher (NO grad)
    # ----------------------------------
    with torch.no_grad():
        target_repr = target_encoder(y_tgt)

    # ----------------------------------
    # 4) Student
    # ----------------------------------
    ctx_repr = context_encoder(x_ctx)
    pred_repr = predictor(ctx_repr)

    # ----------------------------------
    # 5) Align lengths (safe & correct)
    # ----------------------------------
    T = min(pred_repr.shape[1], target_repr.shape[1])
    pred_repr   = pred_repr[:, :T]
    target_repr = target_repr[:, :T]

    # ----------------------------------
    # 6) L2 loss in representation space
    # ----------------------------------
    loss = F.mse_loss(pred_repr, target_repr)

    # ----------------------------------
    # 7) Backprop ONLY student
    # ----------------------------------
    loss.backward()
    optimizer.step()

    # ----------------------------------
    # 8) EMA teacher update
    # ----------------------------------
    ema_update(target_encoder, context_encoder, momentum=0.996)

    losses.append(loss.item())
    print(f"Step {step}: loss = {loss.item():.4f}")

In [205]:
# =========================
# JEPA TOY â€” CLEAN & WORKING
# =========================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import random

# -------------------------
# Device
# -------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# -------------------------
# Data
# -------------------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = datasets.CIFAR10(root=".", train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

imgs, _ = next(iter(loader))
imgs = imgs.to(device)

# -------------------------
# Patch Embedding
# -------------------------
class PatchEmbed(nn.Module):
    def __init__(self, patch=16, dim=256):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        x = self.proj(x)                 # [B, D, 14, 14]
        x = x.flatten(2).transpose(1,2) # [B, 196, D]
        return x

# -------------------------
# Tiny Transformer
# -------------------------
class TinyViT(nn.Module):
    def __init__(self, dim=256, depth=4, heads=8):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, depth)

    def forward(self, x):
        return self.encoder(x)

# -------------------------
# JEPA Components
# -------------------------
DIM = 256

patch_embed = PatchEmbed().to(device)
context_encoder = TinyViT(DIM).to(device)
target_encoder  = TinyViT(DIM).to(device)
predictor       = TinyViT(DIM).to(device)

# Freeze teacher
for p in target_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.AdamW(
    list(context_encoder.parameters()) +
    list(predictor.parameters()),
    lr=1e-4
)

# -------------------------
# Random Context Mask
# -------------------------
def random_context(tokens, keep_ratio=0.8):
    B, N, D = tokens.shape
    k = int(N * keep_ratio)
    idx = torch.randperm(N, device=tokens.device)[:k]
    return tokens[:, idx]

# -------------------------
# EMA Update
# -------------------------
@torch.no_grad()
def ema_update(teacher, student, m=0.996):
    for pt, ps in zip(teacher.parameters(), student.parameters()):
        pt.data = m * pt.data + (1 - m) * ps.data

# -------------------------
# Training Loop
# -------------------------
print("\nTraining...\n")
for step in range(20):
    optimizer.zero_grad()

    # Patchify
    tokens = patch_embed(imgs)            # [B, 196, D]

    # Context masking
    x_ctx = random_context(tokens)        # [B, ~156, D]

    # Student forward
    ctx_repr = context_encoder(x_ctx)
    pred = predictor(ctx_repr)

    # Teacher forward (NO GRAD)
    with torch.no_grad():
        target = target_encoder(tokens)

    # Shape alignment
    T = min(pred.shape[1], target.shape[1])
    loss = F.mse_loss(pred[:, :T], target[:, :T])

    # Backprop
    loss.backward()
    optimizer.step()

    # EMA
    ema_update(target_encoder, context_encoder)

    print(f"Step {step:02d} | Loss: {loss.item():.4f}")

print("\nDone.")

Device: cuda

Training...

Step 00 | Loss: 1.8689
Step 01 | Loss: 1.2484
Step 02 | Loss: 0.9213
Step 03 | Loss: 0.7102
Step 04 | Loss: 0.5636
Step 05 | Loss: 0.4562
Step 06 | Loss: 0.3702
Step 07 | Loss: 0.3047
Step 08 | Loss: 0.2565
Step 09 | Loss: 0.2226
Step 10 | Loss: 0.1995
Step 11 | Loss: 0.1826
Step 12 | Loss: 0.1711
Step 13 | Loss: 0.1645
Step 14 | Loss: 0.1602
Step 15 | Loss: 0.1581
Step 16 | Loss: 0.1563
Step 17 | Loss: 0.1551
Step 18 | Loss: 0.1537
Step 19 | Loss: 0.1523

Done.
