# V-JEPA 2 Visualizer Decoder
Trains a CNN decoder that maps V-JEPA 2 latent patch tokens → reconstructed video frames.

**Runtime:** Select a **T4 GPU** via Runtime → Change runtime type before running the cells below.

In [None]:
# Cell 1: Install dependencies
# Notebook shell escape (`!`): quietly (-q) pip-installs the listed packages,
# including the ones later cells import (transformers, cv2 via
# opencv-python-headless) or invoke as a command-line tool (yt-dlp).
!pip install -q transformers accelerate yt-dlp opencv-python-headless

In [None]:
# Cell 2: Load V-JEPA 2 encoder (frozen)
import torch
from transformers import AutoModel

# Prefer the GPU when torch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'Device: {DEVICE}')

# Load the pretrained checkpoint, move it to DEVICE in half precision and
# switch it to eval mode.
encoder = AutoModel.from_pretrained('facebook/vjepa2-vitl-fpc64-256', trust_remote_code=True)
encoder = encoder.to(DEVICE, dtype=torch.float16)
encoder = encoder.eval()

# Freeze every weight so no gradients are tracked for the encoder.
for weight in encoder.parameters():
    weight.requires_grad_(False)

print('Encoder loaded and frozen')
print(f'Params: {sum(p.numel() for p in encoder.parameters()):,}')

In [None]:
# Cell 3: Build training dataset
# Download Big Buck Bunny, extract 8-frame clips, encode with V-JEPA 2
import os, subprocess, cv2, numpy as np
from torchvision import transforms
from pathlib import Path
from PIL import Image

# Directory that receives one {'e': embedding, 't': target frame} file per clip.
DDIR = Path('/tmp/jd')
DDIR.mkdir(parents=True, exist_ok=True)

# Encoder input: 256x256 tensor normalised with the usual ImageNet mean / std.
ET = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Reconstruction target: 256x256 tensor with pixel values left in [0, 1].
TT = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Download video (skipped when a plausible file is already on disk)
vp = '/tmp/bbb.mp4'
if not os.path.exists(vp) or os.path.getsize(vp) < 10000:
    print('Downloading Big Buck Bunny...')
    subprocess.run([
        'yt-dlp', '--quiet',
        '-f', 'bestvideo[height<=360][ext=mp4]/best[height<=360]',
        '--download-sections', '*0:00-0:40',
        '-o', vp,
        'https://www.youtube.com/watch?v=_FjuOVeahA8'
    ], timeout=90)

# Extract clips and encode
cap = cv2.VideoCapture(vp)
if not cap.isOpened():
    # Fail loudly here: otherwise every clip is skipped and the training cell
    # later breaks on an empty dataset with a far less helpful message.
    cap.release()
    raise RuntimeError(f'Could not open video {vp!r}; the download probably failed.')

n_saved = 0
try:
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # 40 evenly spaced clip starts; the last start stays 60 frames before the
    # end, which leaves room for the 8 frames sampled at a stride of 4.
    starts = np.linspace(0, max(0, total - 60), 40, dtype=int)

    for ci, s in enumerate(starts):
        fe, ft = [], []  # encoder inputs / reconstruction targets of this clip
        for t in range(8):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(s + t * 4))
            ret, fr = cap.read()
            if not ret:
                break
            img = Image.fromarray(cv2.cvtColor(fr, cv2.COLOR_BGR2RGB))
            fe.append(ET(img))
            ft.append(TT(img))
        if len(fe) < 8:
            # Incomplete clip (read failure / end of video): skip it.
            continue
        # torch.stack gives [T, C, H, W]; add the batch axis -> [1, T, C, H, W].
        # The Hugging Face V-JEPA 2 model expects frames before channels and
        # swaps the two axes internally, so no permute is applied here.
        vid = torch.stack(fe).unsqueeze(0).to(DEVICE, dtype=torch.float16)
        with torch.no_grad():
            emb = encoder(pixel_values_videos=vid).last_hidden_state[0].cpu().float()
        # Target is the 4th of the 8 sampled frames (roughly the clip centre).
        torch.save({'e': emb, 't': ft[3]}, DDIR / f'p{ci:04d}.pt')
        n_saved += 1
finally:
    # Release the capture handle even if encoding or saving raised.
    cap.release()

print(f'Dataset: {n_saved} pairs saved to {DDIR}')

In [None]:
# Cell 4: Define CNN Decoder + train 25 epochs
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

class DS(Dataset):
    """Dataset over the ``p*.pt`` files found directly inside a directory.

    Every file is loaded with ``torch.load`` and must provide the entries
    ``'e'`` and ``'t'``; indexing returns them as the tuple ``(e, t)``.
    """

    def __init__(self, d):
        # Sorted so that a given index always maps to the same file.
        files = list(Path(d).glob('p*.pt'))
        files.sort()
        self.f = files

    def __len__(self):
        """Return the number of sample files that were found."""
        return len(self.f)

    def __getitem__(self, i):
        """Load sample file ``i`` and return its ``(e, t)`` pair."""
        sample = torch.load(self.f[i])
        emb, target = sample['e'], sample['t']
        return emb, target

# Mini-batches of 4 stored pairs, reshuffled by the loader on every pass.
train_set = DS('/tmp/jd')
loader = DataLoader(dataset=train_set, batch_size=4, shuffle=True)
print(f'Loader: {len(loader)} batches/epoch')


class Decoder(nn.Module):
    """
    Maps V-JEPA 2 tokens [B, T*S*S, D] -> RGB frame [B, 3, 16*S, 16*S].

    With the defaults this is [B, 1024, 1024] -> [B, 3, 256, 256].
    Token layout: 1024 = 4 temporal x 16 spatial x 16 spatial.

    Args:
        D: Channel width of the incoming tokens.
        H: Channel width after the linear projection (input width of the CNN).
        T: Number of temporal token slices; they are averaged away in forward.
        S: Side length of the square spatial token grid.
    """
    def __init__(self, D=1024, H=384, T=4, S=16):
        super().__init__()
        self.T = T
        self.S = S
        self.proj = nn.Sequential(
            nn.Linear(D, H),
            nn.LayerNorm(H),
            nn.GELU()
        )

        def up(i, o):
            # One 2x upsampling stage: transposed conv (kernel 4, stride 2,
            # padding 1) doubles the resolution, a 3x3 conv then refines it.
            return nn.Sequential(
                nn.ConvTranspose2d(i, o, 4, 2, 1),
                nn.GroupNorm(8, o),
                nn.GELU(),
                nn.Conv2d(o, o, 3, padding=1),
                nn.GroupNorm(8, o),
                nn.GELU()
            )

        # Four doublings: S -> 16*S (16 -> 256 with the default grid).
        self.cnn = nn.Sequential(
            up(H, 256),    # 16 -> 32
            up(256, 128),  # 32 -> 64
            up(128, 64),   # 64 -> 128
            up(64, 32),    # 128 -> 256
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Sigmoid()   # squash to [0, 1] like the target frames
        )

    def forward(self, x):
        """Decode tokens ``x`` of shape [B, T*S*S, D] into frames.

        Raises:
            ValueError: If the token count of ``x`` is not ``T * S * S``.
        """
        B = x.shape[0]
        expected = self.T * self.S * self.S
        if x.shape[1] != expected:
            # Without this check a wrong token count would be absorbed into
            # the channel axis by the reshape and fail later inside the CNN.
            raise ValueError(
                f'expected {expected} tokens ({self.T} x {self.S} x {self.S}), '
                f'got {x.shape[1]}'
            )
        # Project + temporal mean pool: [B,T*S*S,D] -> [B,S,S,H] -> [B,H,S,S]
        x = self.proj(x).reshape(B, self.T, self.S, self.S, -1)
        x = x.mean(1).permute(0, 3, 1, 2)
        return self.cnn(x)


# Build the decoder and place it on the compute device.
dec = Decoder()
dec = dec.to(DEVICE)
print(f'Decoder params: {sum(p.numel() for p in dec.parameters()):,}')

# Criterion, optimizer over the decoder weights, and a cosine learning-rate
# schedule with T_max=25.
mse = nn.MSELoss()
opt = torch.optim.AdamW(params=dec.parameters(), lr=2e-4, weight_decay=1e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=25)

def loss_fn(p, t):
    """Pixel loss plus a 0.1-weighted loss on neighbouring-pixel differences.

    Both terms use the module-level ``mse`` criterion. The difference term
    compares first-order differences of ``p`` and ``t`` along dim 2 and dim 3,
    which rewards reproducing edges rather than only average colour.

    Args:
        p: Predicted batch, indexed as a 4-D tensor.
        t: Target batch with the same shape as ``p``.

    Returns:
        ``mse(p, t) + 0.1 * (difference loss along dim 2 + along dim 3)``.
    """
    edge_loss = 0
    for axis in (2, 3):
        # torch.diff(x, dim=k) is x[1:] - x[:-1] taken along dimension k.
        edge_loss = edge_loss + mse(torch.diff(p, dim=axis), torch.diff(t, dim=axis))
    return mse(p, t) + 0.1 * edge_loss


# --- Train ---
print('Training...')
hist = []  # mean training loss of each epoch
for ep in range(25):
    dec.train()
    batch_losses = []
    for tokens, target in loader:
        tokens = tokens.to(DEVICE)
        target = target.to(DEVICE)
        loss = loss_fn(dec(tokens), target)
        opt.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(dec.parameters(), 1.0)
        opt.step()
        batch_losses.append(loss.item())
    # The learning-rate schedule advances once per epoch, not per batch.
    sched.step()
    hist.append(sum(batch_losses) / len(loader))
    if (ep + 1) % 5 == 0:
        print(f'Epoch {ep+1:2d}/25 | loss={hist[-1]:.5f}')

weights_path = '/tmp/vjepa_dec.pt'
torch.save(dec.state_dict(), weights_path)
print('Weights saved to /tmp/vjepa_dec.pt')

# Loss curve: per-epoch mean training loss collected in `hist`.
loss_fig, loss_ax = plt.subplots(figsize=(7, 2.5))
loss_ax.plot(hist, lw=2, color='steelblue')
loss_ax.grid(alpha=0.3)
loss_ax.set_title('V-JEPA 2 Decoder — Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_fig.tight_layout()
plt.show()

In [None]:
# Cell 5: Visualise reconstructions
dec.eval()
ds = DS('/tmp/jd')
if len(ds) == 0:
    raise RuntimeError('No samples found in /tmp/jd; build the dataset first.')

# Draw up to 5 distinct samples; asking np.random.choice for more items than
# exist (with replace=False) would raise on a small dataset.
n_show = min(5, len(ds))
idxs = np.random.choice(len(ds), n_show, replace=False)

# squeeze=False keeps `ax` two-dimensional even when a single column is drawn.
fig, ax = plt.subplots(2, n_show, figsize=(18, 7), squeeze=False)
for col, idx in enumerate(idxs):
    e, t = ds[int(idx)]
    with torch.no_grad():
        pred = dec(e.unsqueeze(0).to(DEVICE))[0].cpu()
    # [C, H, W] tensors -> [H, W, C] arrays for imshow.
    orig = t.permute(1, 2, 0).numpy()
    rec = pred.clamp(0, 1).permute(1, 2, 0).numpy()
    # PSNR for a peak value of 1; the epsilon guards against log10(0).
    psnr = -10 * np.log10(((orig - rec) ** 2).mean() + 1e-8)

    panels = (
        (0, orig, f'Original #{idx}'),
        (1, rec, f'PSNR {psnr:.1f} dB'),
    )
    for row, img, title in panels:
        panel = ax[row, col]
        panel.imshow(img)
        panel.set_title(title, fontsize=8)
        # Hide ticks and frame instead of calling axis('off'): turning the
        # axis off would also hide the row labels set below.
        panel.set_xticks([])
        panel.set_yticks([])
        for spine in panel.spines.values():
            spine.set_visible(False)

ax[0, 0].set_ylabel('Ground Truth', fontsize=10)
ax[1, 0].set_ylabel('V-JEPA 2 → Decoder', fontsize=10)
plt.suptitle(
    'V-JEPA 2 Latent Space → Reconstructed Frames\n'
    '(CNN Decoder trained on MSE + Gradient Loss, 25 epochs)',
    fontsize=12, fontweight='bold'
)
plt.tight_layout()
plt.show()
print('Done!')