<table style="background-color:#FFFFFF">   
  <tr>     
  <td><img src="https://upload.wikimedia.org/wikipedia/commons/9/95/Logo_EPFL_2019.svg" width="150"/>
  </td>     
  <td>
  <h1> <b>CS-461: Foundation Models and Generative AI</b> </h1>
  Prof. Charlotte Bunne  
  </td>   
  </tr>
</table>

# 📚 Exercise Session (Coding Part) - 11

Overview of the coding part:

* [**TASK 1:** Implementing Joint-Embedding Predictive Architecture (JEPA)](#task_name_1)
    - [Subtask A: I-JEPA](#subtask_name_1_A)
    - [Subtask B: V-JEPA](#subtask_name_1_B)

Let's import some libraries that we will often use during the session:

In [None]:
import copy
import torch
import math
from functools import partial
import numpy as np
import torch.nn.functional as F

import torch.nn as nn
from utils import *  # ViT helpers used below: get_2d_sincos_pos_embed, trunc_normal_, DropPath, repeat_interleave_batch

<a name="task_name_1"></a>
## Task 1: Implementing Joint-Embedding Predictive Architecture (JEPA)

In this exercise, we will look into the implementation details of JEPA. We will start with I-JEPA ([Link to paper](https://arxiv.org/pdf/2301.08243)) and then move on to V-JEPA ([Link to paper](https://arxiv.org/pdf/2404.08471)). We won't train actual models, since meaningful results typically require training on ImageNet-1k or on large amounts of video, but the code should give you a clear, hands-on understanding of how these architectures are implemented and trained in practice.

**Background**: 
- I-JEPA aims to learn **static** visual representations. It learns them from **images** by predicting, from a visible context block, the latent representations of masked regions of the same image. Since it has no temporal modeling, it cannot capture how the world evolves, but the same underlying idea scales from static images to videos.
- Building on I-JEPA, V-JEPA trains on **video** frames. To predict masked spatio-temporal regions, the model must infer the underlying state of the world that drives how the video evolves.

<a name="subtask_name_1_A"></a>
### Task 1.A: I-JEPA

<img src="https://scontent-zrh1-1.xx.fbcdn.net/v/t39.2365-6/353824985_215991878033819_2765220267220815437_n.png?_nc_cat=106&ccb=1-7&_nc_sid=e280be&_nc_ohc=C35AdiBHEPIQ7kNvwG9_1L1&_nc_oc=AdnV-LWXrT7MAdYsfPcR1eEmxwRoX_RlmXeTjCmUWIJ81DvjYyPbNfxeDZ3V2Mf7agKE1kK7Gmr0bgejZj5GOawF&_nc_zt=14&_nc_ht=scontent-zrh1-1.xx&_nc_gid=keNBHgcG3wwfMSPpHGMONA&oh=00_AfgBTCoNceDfXZNOf9Ijb6rQcogZkjCEpbfG_4xpJnUajg&oe=694127D1" width="800">

I-JEPA can be broken into the following modules:
1. Context Encoder (a neural network)
2. Target Encoder (an EMA-updated copy of the context encoder)
3. Predictor (another neural network)
4. Loss Function (the dashed lines in the figure, e.g. an L1 or L2 distance)
5. Masking Strategy (produces the context and the targets from the image)

Let's look into them one by one.

#### 1. Context Encoder
The context encoder is typically implemented as a standard Vision Transformer (ViT), which we have already seen in previous sessions (e.g. Exercise 7). It contains:
- **Patch Embedding**: a 224×224 image is cut into a 14×14 grid of non-overlapping 16×16-pixel patches (196 patches in total), and each patch is linearly projected to an embedding vector.
- **Position Embeddings**: a positional embedding added to each patch embedding (here a fixed 2D sine-cosine embedding).
- **Transformer Blocks**: self-attention + MLP layers.
- **Masking function**: keeps only the context patches, which are then forwarded through the Transformer blocks.
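As a quick sanity check of the patch-embedding arithmetic, the sketch below uses a bare `nn.Conv2d` whose kernel size and stride both equal the patch size (the same trick the `PatchEmbed` class in the cell below relies on). The embedding width of 192 is an assumption chosen to match the `vit_tiny` configuration used later.

```python
import torch
import torch.nn as nn

img_size, patch_size, embed_dim = 224, 16, 192  # 192 matches the ViT-Tiny width used later

# A convolution whose kernel size equals its stride cuts the image into
# non-overlapping patches and linearly projects each one to embed_dim.
proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, 3, img_size, img_size)   # [B, C, H, W]
feat = proj(x)                              # [B, D, H/P, W/P]
tokens = feat.flatten(2).transpose(1, 2)    # [B, N, D] with N = (H/P) * (W/P)

grid = img_size // patch_size
print(feat.shape)    # torch.Size([2, 192, 14, 14])
print(tokens.shape)  # torch.Size([2, 196, 192])
```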

**TODO**: 
- Implement the patch embedding, which converts an image into a sequence of patch embeddings that can be processed by a Transformer.
- Implement the masking function. Note that here the 'masks' specify patches that we want to *keep*.
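The semantics of "masks specify the patches to keep" can be illustrated with a small standalone re-implementation. `gather_patches` is a hypothetical name that mirrors `apply_masks` from the cell below. Each mask produces its own batch-sized chunk, so all masks in the list must keep the same number of patches for the final `torch.cat` to work.

```python
import torch

def gather_patches(x, masks):
    """Keep the patches listed in each index tensor of `masks`.

    x: [B, N, D]; masks: list of 1D LongTensors, each with K indices in [0, N).
    Returns [len(masks) * B, K, D], one batch-sized chunk per mask.
    """
    B, _, D = x.shape
    out = []
    for m in masks:
        # broadcast the indices to [B, K, D] so gather picks whole feature vectors
        idx = m.view(1, -1, 1).expand(B, -1, D)
        out.append(torch.gather(x, dim=1, index=idx))
    return torch.cat(out, dim=0)

# toy input: B=2 images, N=6 patches, D=3 features; every feature equals its patch index
x = torch.arange(6, dtype=torch.float32).view(1, 6, 1).expand(2, 6, 3)
masks = [torch.tensor([0, 2]), torch.tensor([3, 5])]
kept = gather_patches(x, masks)
print(kept.shape)     # torch.Size([4, 2, 3])
print(kept[:, :, 0])  # rows 0-1 keep patches 0 and 2, rows 2-3 keep patches 3 and 5
```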

In [None]:
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.patch_shape = (img_size // patch_size, img_size // patch_size)
        
        # [B, C, H, W] -> [B, embed_dim, H/P, W/P]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # [B, C, H, W] -> [B, embed_dim, H/P, W/P] -> [B, embed_dim, H/P * W/P] -> [B, H/P * W/P, embed_dim]
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

def apply_masks(x, masks):
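    """
    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: *list* of tensors containing indices of patches in [N] to keep
    :return: tensor of shape [len(masks) * B, K, D], one batch-sized chunk per mask
             (all masks must keep the same number K of patches)
    """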
    B = x.shape[0]
    all_x = []
    for m in masks:
        mask_keep = m.unsqueeze(0).unsqueeze(-1).expand(B, -1, x.size(-1))
        all_x += [torch.gather(x, dim=1, index=mask_keep)]
    return torch.cat(all_x, dim=0)

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                    drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        # --
        self.patch_embed = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # --
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # --
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ------
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, masks=None):
        if masks is not None:
            if not isinstance(masks, list):
                masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)
        B, N, D = x.shape

        # -- add positional embedding to x
        x = x + self.pos_embed

        # -- mask x
        if masks is not None:
            x = apply_masks(x, masks)

        # -- fwd prop
        for i, blk in enumerate(self.blocks):
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x


To instantiate a model, we simply do:

In [None]:
def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

context_encoder = vit_tiny()

#### 2. Target Encoder
The target encoder's weights are identical to the context encoder's weights at initialization and are updated via an exponential moving average (EMA) of the context-encoder weights thereafter. We have seen the same idea in many self-supervised learning approaches so far.

**TODO**: Implement the EMA update (a classic exercise; it never hurts to practice it).

In [None]:
target_encoder = copy.deepcopy(context_encoder)

In [None]:
@torch.no_grad()
def ema_update(context_encoder, target_encoder, beta):
    # target <- beta * target + (1 - beta) * context, parameter by parameter.
    # Only the parameters are touched: calling context_encoder.eval() here would
    # silently disable dropout / drop-path for the rest of the training epoch.
    for student_param, teacher_param in zip(context_encoder.parameters(), target_encoder.parameters()):
        teacher_param.mul_(beta).add_(student_param, alpha=1 - beta)
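A quick way to check an EMA implementation is to run it on a one-parameter model: with the student fixed at 1 and the teacher starting at 0, after k updates the teacher should equal 1 - β^k. `ema_step` below is a hypothetical standalone version of the update; it only modifies parameters and leaves the modules' train/eval modes alone.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def ema_step(student: nn.Module, teacher: nn.Module, beta: float) -> None:
    # teacher <- beta * teacher + (1 - beta) * student, parameter by parameter
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(beta).add_(p_s, alpha=1 - beta)

# one-parameter toy models: student fixed at 1.0, teacher starting at 0.0
student = nn.Linear(1, 1, bias=False)
teacher = nn.Linear(1, 1, bias=False)
nn.init.constant_(student.weight, 1.0)
nn.init.constant_(teacher.weight, 0.0)

history = []
for _ in range(3):
    ema_step(student, teacher, beta=0.9)
    history.append(teacher.weight.item())
print(history)  # approximately [0.1, 0.19, 0.271], i.e. 1 - 0.9**k
```

With β close to 1 (e.g. the 0.996 used in the training loop), the target encoder changes slowly, which is what makes it a stable source of prediction targets.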

#### 3. Predictor
The predictor is another ViT that takes the context encoder output and, conditioned on positional tokens, predicts the representations of a target block at a specific location.

**TODO**: Can you tell how the predictor ViT is different from the encoder ViT?

In [None]:
class VisionTransformerPredictor(nn.Module):
    """ Vision Transformer predictor """
    def __init__(
        self,
        num_patches,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=6,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        # --
        self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim),
                                                requires_grad=False)
        predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1],
                                                        int(num_patches**.5),
                                                        cls_token=False)
        self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0))
        # --
        self.predictor_blocks = nn.ModuleList([
            Block(
                dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
        # ------
        self.init_std = init_std
        trunc_normal_(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.predictor_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, masks_x, masks):
        assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'

        if not isinstance(masks_x, list):
            masks_x = [masks_x]

        if not isinstance(masks, list):
            masks = [masks]

        # -- Batch Size
        B = len(x) // len(masks_x)

        # -- map from encoder-dim to predictor-dim
        x = self.predictor_embed(x)

        # -- add positional embedding to x tokens
        x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
        x += apply_masks(x_pos_embed, masks_x)

        _, N_ctxt, D = x.shape

        # -- concat mask tokens to x
        pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
        pos_embs = apply_masks(pos_embs, masks)
        pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
        # --
        pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
        # --
        pred_tokens += pos_embs
        x = x.repeat(len(masks), 1, 1)
        x = torch.cat([x, pred_tokens], dim=1)

        # -- fwd prop
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- return preds for mask tokens
        x = x[:, N_ctxt:]
        x = self.predictor_proj(x)

        return x

Similar to the context encoder, we can instantiate it by:

In [None]:
def vit_predictor(num_patches=196, **kwargs):
    model = VisionTransformerPredictor(
        num_patches=num_patches, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    return model

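# embed_dim must match the context encoder's output width (192 for vit_tiny);
# the default of 768 would cause a shape mismatch in predictor_embed
predictor = vit_predictor(num_patches=context_encoder.patch_embed.num_patches,
                          embed_dim=context_encoder.embed_dim,
                          num_heads=context_encoder.num_heads)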
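To make the token bookkeeping in `VisionTransformerPredictor.forward` concrete, here is a shape-only sketch with random tensors and no attention blocks. The sizes (120 context patches, 25 patches per target block, M = 4 blocks) are illustrative assumptions. It mirrors the key step: each query is the shared mask token plus the positional embedding of one target location, and the queries are appended after the context tokens.

```python
import torch

B, N, D_pred, M = 2, 196, 384, 4           # batch, patches, predictor width, target blocks
num_ctx, num_tgt = 120, 25                 # kept context patches, patches per target block

pos = torch.randn(1, N, D_pred)            # stand-in for the fixed sin-cos predictor_pos_embed
mask_token = torch.zeros(1, 1, D_pred)     # a single learnable token shared by all targets
ctx = torch.randn(B, num_ctx, D_pred)      # context tokens after predictor_embed (+ their pos. emb.)
target_idx = [torch.randperm(N)[:num_tgt] for _ in range(M)]

# one query per target patch: mask token + positional embedding of that patch, block-major
queries = torch.cat([mask_token + pos[:, idx, :].expand(B, -1, -1) for idx in target_idx], dim=0)

# every target block is predicted from its own copy of the same context
seq = torch.cat([ctx.repeat(M, 1, 1), queries], dim=1)   # [M*B, num_ctx + num_tgt, D_pred]

# the predictor blocks would run on `seq`; only the last num_tgt outputs are kept
out = seq[:, num_ctx:]
print(queries.shape, seq.shape, out.shape)
```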

#### 4. Loss Function
The loss is simply the average distance between the predicted patch-level representations (i.e., the outputs of the predictor) and the target patch-level representations (i.e., the outputs of the target encoder).

**TODO**: In practice, JEPA uses a Smooth L1 loss instead of an L2 loss. What are the potential benefits?

In [None]:
def loss_fn(z, h):
    loss = F.smooth_l1_loss(z, h)
    return loss
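The difference between the two losses is easiest to see on a few hand-picked errors. PyTorch's `F.smooth_l1_loss` with `beta=1.0` (its default) is quadratic for |error| < 1 and linear beyond, so its per-element gradient never exceeds 1 in magnitude, whereas the L2 (MSE) loss and its gradient keep growing with the error.

```python
import torch
import torch.nn.functional as F

target = torch.zeros(3)
pred = torch.tensor([0.1, 1.0, 10.0], requires_grad=True)  # small, medium and outlier errors

# per-element losses (reduction='none' keeps them separate)
sl1 = F.smooth_l1_loss(pred, target, reduction='none', beta=1.0)
mse = F.mse_loss(pred, target, reduction='none')
print(sl1)  # ~[0.005, 0.5, 9.5]: quadratic below beta, linear above
print(mse)  # ~[0.01, 1.0, 100.0]: quadratic everywhere

# per-element gradients of the smooth L1 loss are capped at 1 in magnitude
sl1.sum().backward()
print(pred.grad)  # ~[0.1, 1.0, 1.0]; the MSE gradient 2 * error would be [0.2, 2.0, 20.0]
```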

#### 5. Masking strategy for the context and target

Having implemented the models and the loss function, we now look into how an input image is split into the input (context) and the targets.

**TODO**: We first need a function that returns the height and width of a rectangular mask, given its scale (the fraction of the image area it covers, between 0 and 1) and its aspect ratio.

**Hint:** Let $H$ and $W$ be `patch_h` and `patch_w`, and let $h_m$ and $w_m$ be the mask dimensions to compute.

- **Area constraint:** The mask should cover a fraction $s$ of the total area: $h_m \cdot w_m = s \cdot H \cdot W$
- **Aspect ratio constraint:** The mask should have aspect ratio $r$: $\frac{w_m}{h_m} = r$

Solve these two equations for $h_m$ and $w_m$.

In [None]:
def get_mask(patch_h, patch_w, scale, aspect_ratio):
    """
    :param patch_h: number of patches along height
    :param patch_w: number of patches along width
    :param scale: scale of the mask with respect to the image (0 to 1)
    :param aspect_ratio: aspect ratio of the mask (width / height)
    :return: height and width of the mask
    """
    mask_h = int(round(math.sqrt(scale * patch_h * patch_w / aspect_ratio)))
    mask_w = int(round(math.sqrt(scale * patch_h * patch_w * aspect_ratio)))
    
    # Clamp to valid dimensions (at least 1 patch, at most the full grid)
    mask_h = max(1, min(mask_h, patch_h))
    mask_w = max(1, min(mask_w, patch_w))

    return mask_h, mask_w
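Dividing and multiplying the two constraints gives $h_m^2 = sHW/r$ and $w_m^2 = sHW\,r$, which is what `get_mask` computes before rounding. The worked example below uses the 14×14 patch grid of a 224×224 image; `mask_size` is a hypothetical standalone copy of the formula, and the (scale, aspect ratio) pairs are illustrative. Note that rounding to whole patches makes the realized area fraction differ slightly from $s$.

```python
import math

def mask_size(H, W, s, r):
    # h_m = sqrt(s*H*W / r), w_m = sqrt(s*H*W * r), rounded and clamped to the grid
    h = min(max(1, round(math.sqrt(s * H * W / r))), H)
    w = min(max(1, round(math.sqrt(s * H * W * r))), W)
    return h, w

for s, r in [(0.15, 1.0), (0.15, 1.5), (0.85, 1.0)]:
    h, w = mask_size(14, 14, s, r)
    print(f"s={s}, r={r}: {h}x{w} = {h * w} patches ({h * w / 196:.1%} of 196)")
# s=0.15, r=1.0: 5x5 = 25 patches (12.8% of 196)
# s=0.15, r=1.5: 4x7 = 28 patches (14.3% of 196)
# s=0.85, r=1.0: 13x13 = 169 patches (86.2% of 196)
```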

The targets to be predicted by the predictor are then obtained by sampling M blocks of the same height and width at random locations in the image. We record the patch indices of each target block, together with the union of all target blocks (which is used when building the context).

**TODO**: Implement the code that samples the mask locations uniformly at random.

In [None]:
def get_target(patch_dim, aspect_ratio, scale, M):
    # get the patch-grid dimensions and the (shared) size of the target blocks
    patch_h, patch_w = patch_dim
    block_h, block_w = get_mask(patch_h, patch_w, scale, aspect_ratio)
    target_patches = []
    all_patches = []
    for _ in range(M):
        # sample a random top-left corner such that the block fits in the grid
        start_patch_h = torch.randint(0, patch_h - block_h + 1, (1,)).item()
        start_patch_w = torch.randint(0, patch_w - block_w + 1, (1,)).item()
        start_patch = start_patch_h * patch_w + start_patch_w

        # collect the (row-major) indices of the patches in the target block
        patches = []
        for i in range(block_h):
            for j in range(block_w):
                idx = start_patch + i * patch_w + j
                patches.append(idx)
                if idx not in all_patches:  # keep the union of all target blocks
                    all_patches.append(idx)

        target_patches.append(torch.tensor(patches))
    return target_patches, all_patches
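The index arithmetic in `get_target` relies on the row-major order produced by `PatchEmbed` (`flatten(2)` over the H/P × W/P grid): patch (row, col) gets index `row * patch_w + col`. A small standalone check, using a hypothetical helper `block_indices` and an illustrative 2×2 block on the 14×14 grid:

```python
import torch

def block_indices(top, left, block_h, block_w, grid_w):
    # row-major flattening: patch (row, col) of a grid with grid_w columns has index row * grid_w + col
    start = top * grid_w + left
    return torch.tensor([start + i * grid_w + j for i in range(block_h) for j in range(block_w)])

idx = block_indices(top=2, left=3, block_h=2, block_w=2, grid_w=14)
print(idx)  # tensor([31, 32, 45, 46])

# round-trip back to (row, col) coordinates
rows = torch.div(idx, 14, rounding_mode="floor")
cols = idx % 14
print(list(zip(rows.tolist(), cols.tolist())))  # [(2, 3), (2, 4), (3, 3), (3, 4)]
```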

The context, which is fed to the context encoder and from there to the predictor, is obtained by sampling a single block of a given height and width in the image. To avoid any overlap between the context and the target regions (which would make the prediction trivial), we additionally remove the overlapping patches from the context. This means that the context is not necessarily rectangular. We record the patch indices of the context.

In [None]:
def get_context(patch_dim, aspect_ratio, scale, target_patches):
    patch_h, patch_w = patch_dim
    block_h, block_w = get_mask(patch_h, patch_w, scale, aspect_ratio)

    # sample a random top-left corner for the context block
    start_patch_h = torch.randint(0, patch_h - block_h + 1, (1,)).item()
    start_patch_w = torch.randint(0, patch_w - block_w + 1, (1,)).item()
    start_patch = start_patch_h * patch_w + start_patch_w

    # collect the patches in the context block, skipping those that belong to a target block
    target_set = set(target_patches)  # set for O(1) membership checks
    patches = []
    for i in range(block_h):
        for j in range(block_w):
            idx = start_patch + i * patch_w + j
            if idx not in target_set:  # remove the target patches to avoid overlap
                patches.append(idx)
    context_patches = [torch.tensor(patches)]
    return context_patches
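A tiny example of why the context is generally not rectangular: on a hypothetical 4×4 grid, removing a central 2×2 target block from a full context block leaves a ring of 12 patches. The number of kept context patches therefore varies from call to call, which is fine here because a single context mask is shared by the whole batch.

```python
import torch

grid_w = 4
context_block = [r * grid_w + c for r in range(4) for c in range(4)]  # a full 4x4 context block
target_union = {5, 6, 9, 10}                                         # a 2x2 target block in the middle

# keep only context patches that are not targets (set lookup is O(1))
context = torch.tensor([p for p in context_block if p not in target_union])
print(context.numel())   # 12: the context is now a square ring, not a rectangle
print(context.tolist())  # [0, 1, 2, 3, 4, 7, 8, 11, 12, 13, 14, 15]
```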

#### Finally: Training Loop

Now that every component is ready, we can start writing the training loop.

Note that, unlike the context block, the target blocks are obtained by masking the *output* of the target encoder rather than its input: the target encoder always sees the full image. This distinction is crucial to ensure target representations of a high semantic level.

**TODO**: Finish the missing part of the forward function.

In [None]:
def forward(x, context_encoder, target_encoder, predictor, target_aspect_ratio, target_scale, context_scale, context_aspect_ratio=1, M=4):

    # sample the target aspect ratio and the target/context scales uniformly from the given ranges
    target_aspect_ratio = np.random.uniform(target_aspect_ratio[0], target_aspect_ratio[1])
    target_scale = np.random.uniform(target_scale[0], target_scale[1])
    context_scale = np.random.uniform(context_scale[0], context_scale[1])
    patch_dim = context_encoder.patch_embed.patch_shape

    # sample the target blocks: M tensors of L1 patch indices each, plus all_patches, the union
    # of all target indices (used to remove the overlap from the context block)
    target_patches, all_patches = get_target(patch_dim, target_aspect_ratio, target_scale, M)
    target_patches = [p.to(x.device) for p in target_patches]  # indices must be on the same device as x

    # get the target blocks by masking the *output* of the target encoder (no gradients needed)
    with torch.no_grad():
        target_x = target_encoder(x)  # target_x: [B, N, D]
        target_blocks = torch.stack([target_x[:, p, :] for p in target_patches])  # target_blocks: [M, B, L1, D]

    # get the context embedding by masking the *input* of the context encoder
    context_patches = get_context(patch_dim, context_aspect_ratio, context_scale, all_patches)  # list with one tensor of L2 indices
    context_patches = [p.to(x.device) for p in context_patches]
    context_x = context_encoder(x, masks=context_patches)  # context_x: [B, L2, D]

    # get the prediction blocks, predicting each target block separately
    prediction_blocks = predictor(context_x, masks_x=context_patches, masks=target_patches)  # prediction_blocks: [M*B, L1, D]

    # merge the block and batch dims; both tensors are ordered block-major
    target_blocks = target_blocks.flatten(0, 1)  # target_blocks: [M*B, L1, D]
    loss = loss_fn(prediction_blocks, target_blocks)

    return loss
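The loss only makes sense if row k of `prediction_blocks` and row k of the flattened `target_blocks` refer to the same (block, image) pair. Both follow a block-major layout: all B images for block 0, then all B images for block 1, and so on. `apply_masks` produces this order by concatenating one batch-sized chunk per mask, and stacking per-block targets into [M, B, ...] before merging the first two dims produces the same order. The standalone check below uses random tensors of arbitrary sizes and plain indexing in place of `torch.gather`.

```python
import torch

B, N, D, M, K = 3, 16, 5, 4, 6
feats = torch.randn(B, N, D)                        # e.g. the target-encoder output
masks = [torch.randperm(N)[:K] for _ in range(M)]   # M target blocks with K patches each

# layout 1: one batch-sized chunk per mask, concatenated along dim 0 (as in apply_masks)
cat_layout = torch.cat([feats[:, m, :] for m in masks], dim=0)             # [M*B, K, D]

# layout 2: stack per-block targets to [M, B, K, D], then merge the first two dims
stack_layout = torch.stack([feats[:, m, :] for m in masks]).flatten(0, 1)  # [M*B, K, D]

print(torch.equal(cat_layout, stack_layout))                      # True
print(torch.equal(cat_layout[1 * B + 2], feats[2, masks[1], :]))  # True: row m*B + b is block m of image b
```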

With the forward function, the training loop simply goes:

In [None]:
def train_ijepa(context_encoder, target_encoder, predictor, train_loader, num_epochs=10, lr=1e-4, ema_momentum=0.996, device='cuda'):
    context_encoder.to(device)
    target_encoder.to(device)
    predictor.to(device)
    
    # only context encoder and predictor have gradients
    target_encoder.eval()
    for param in target_encoder.parameters():
        param.requires_grad = False

    # set up optimizer, only optimize context encoder and predictor
    optimizer = torch.optim.AdamW(list(context_encoder.parameters()) + list(predictor.parameters()), lr=lr)

    for epoch in range(num_epochs):
        context_encoder.train()
        predictor.train()
        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)

            loss = forward(images, context_encoder, target_encoder, predictor,
                            target_aspect_ratio=(0.75, 1.5),
                            target_scale=(0.15, 0.2),
                            context_scale=(0.85, 1.0),
                            context_aspect_ratio=1,
                            M=4)

            # Gradient update for the context encoder and predictor
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # EMA update for target encoder
            ema_update(context_encoder, target_encoder, ema_momentum)

            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    return context_encoder, target_encoder, predictor

<a name="subtask_name_1_B"></a>
### Task 1.B: V-JEPA

<img src="https://scontent-zrh1-1.xx.fbcdn.net/v/t39.8562-6/427979095_923374739138637_7724069779118251294_n.png?_nc_cat=108&ccb=1-7&_nc_sid=f537c7&_nc_ohc=vdqIaE4r7s0Q7kNvwFZXN46&_nc_oc=AdkjBxb7JrcpOBuBTYwjozvDcxwj7Am03LkIj2znomp8wJcLp4uyrBfBWhnl5AWCIlJs6AYx97IKjvHgJQyA1tqM&_nc_zt=14&_nc_ht=scontent-zrh1-1.xx&_nc_gid=p_U2VSLILV3sF-6HdQcqdg&oh=00_Afgejhnv74y3SbrjXD-wntUvJqzYsXtR251GqtwCB2X24Q&oe=692CE30A" width="800">

Building on I-JEPA, V-JEPA trains on videos and treats them as 3D images (time × height × width). The key differences are therefore:
- 2D patch embedding $\rightarrow$ 3D patch embedding (`PatchEmbed` $\rightarrow$ `PatchEmbed3D`)
- 2D mask sampling $\rightarrow$ 3D mask sampling
- Other modifications to handle 3D input, including 3D positional embeddings (not covered here)

In [None]:
class PatchEmbed3D(nn.Module):
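    """
    Video to Patch Embedding: splits a clip into non-overlapping
    tubelet_size x patch_size x patch_size tubelets and projects each one to embed_dim
    """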

    def __init__(
        self,
        patch_size=16,
        tubelet_size=2,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size

        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=(tubelet_size, patch_size, patch_size),
            stride=(tubelet_size, patch_size, patch_size),
        )

    def forward(self, x, **kwargs):
        # [B, C, T, H, W] -> [B, embed_dim, T/t, H/P, W/P] -> [B, T/t * H/P * W/P, embed_dim]
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
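With a tubelet size of 2 and a patch size of 16, each token covers 2 frames × 16 × 16 pixels. For a 16-frame 224×224 clip (illustrative sizes) this gives 8 × 14 × 14 = 1568 tokens, eight times as many as for a single image. The flattening is time-major (index = t·196 + row·14 + col), which is the layout `get_masks` below assumes through its `frame_offset`. The embedding width of 192 is an arbitrary small value for the demo.

```python
import torch
import torch.nn as nn

T, H, W = 16, 224, 224          # frames and resolution (illustrative values)
tubelet, patch, D = 2, 16, 192  # D = 192 is an arbitrary small width for the demo

proj = nn.Conv3d(3, D, kernel_size=(tubelet, patch, patch), stride=(tubelet, patch, patch))
video = torch.randn(1, 3, T, H, W)          # [B, C, T, H, W]
feat = proj(video)                          # [B, D, T/2, H/16, W/16]
tokens = feat.flatten(2).transpose(1, 2)    # [B, N, D] with N = 8 * 14 * 14
print(feat.shape)    # torch.Size([1, 192, 8, 14, 14])
print(tokens.shape)  # torch.Size([1, 1568, 192])
```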

Besides making the masks 3D, the authors also change how the masks are sampled and how the context and target are defined:
- They use two types of masks:
  - Short-range masks: the union of 8 randomly sampled target blocks, each covering 15% of a frame.
  - Long-range masks: the union of 2 randomly sampled target blocks, each covering 70% of a frame.
  - In both cases, the aspect ratio of every sampled block is drawn at random from the range (0.75, 1.5).
  - In both cases, the same spatial mask is applied across the full temporal dimension.
- The sampled mask defines the target, and the context is simply its complement.

In [None]:
def get_masks(patch_dim, scale, mask_type, aspect_ratio_range=(0.75, 1.5)):
    """
    Generate V-JEPA target and context masks.
    
    Args:
        patch_dim: Tuple of (num_frames, patch_height, patch_width)
        scale: Fraction of each frame to mask (0.15 for short, 0.70 for long)
        mask_type: "short" for short-range (8 blocks, 15% each) or 
                   "long" for long-range (2 blocks, 70% each)
        aspect_ratio_range: Range for random aspect ratio sampling
    
    Returns:
        target_patches: List containing tensor of masked patch indices
        context_patches: List containing tensor of context patch indices
    """
    patch_t, patch_h, patch_w = patch_dim
    num_patches_per_frame = patch_h * patch_w
    total_patches = patch_t * num_patches_per_frame
    
    # Set parameters based on mask type
    if mask_type == "short":
        num_blocks = 8
        assert scale == 0.15, "For 'short' mask type, scale should be 0.15"
    elif mask_type == "long":
        num_blocks = 2
        assert scale == 0.70, "For 'long' mask type, scale should be 0.70"
    else:
        raise ValueError(f"mask_type must be 'short' or 'long', got {mask_type!r}")
    
    # Collect all target patches (union of blocks)
    target_set = set()
    
    for _ in range(num_blocks):
        # Sample aspect ratio randomly for each block (per paper)
        min_ar, max_ar = aspect_ratio_range
        aspect_ratio = min_ar + torch.rand(1).item() * (max_ar - min_ar)
        
        # Calculate block dimensions for a single frame
        num_patches_block = int(num_patches_per_frame * scale)
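        # note: aspect_ratio acts as height / width here, the reverse of the
        # width / height convention used by get_mask in the I-JEPA part
        block_h = int(round(math.sqrt(num_patches_block * aspect_ratio)))
        block_w = int(round(math.sqrt(num_patches_block / aspect_ratio)))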
        
        # Clamp to valid dimensions
        block_h = min(block_h, patch_h)
        block_w = min(block_w, patch_w)
        
        # Random starting position (same for all frames - creating a tube)
        start_h = torch.randint(0, max(1, patch_h - block_h + 1), (1,)).item()
        start_w = torch.randint(0, max(1, patch_w - block_w + 1), (1,)).item()
        
        # Add patches from all frames (tube spanning full temporal dimension)
        for t in range(patch_t):
            frame_offset = t * num_patches_per_frame
            for i in range(block_h):
                for j in range(block_w):
                    patch_idx = frame_offset + (start_h + i) * patch_w + (start_w + j)
                    target_set.add(patch_idx)
    
    # Context is the complement of target
    all_patches = set(range(total_patches))
    context_set = all_patches - target_set
    
    target_patches = [torch.tensor(sorted(target_set), dtype=torch.long)]
    context_patches = [torch.tensor(sorted(context_set), dtype=torch.long)]
    
    return target_patches, context_patches
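To see what a "tube" mask looks like in index space, here is the same construction as in `get_masks` on a toy grid: 2 temporal slots of 4×4 patches and a single hand-placed 2×2 block (all sizes are illustrative). Target and context are disjoint and together cover every token, so, unlike in I-JEPA, no separate context block needs to be sampled.

```python
T, H, W = 2, 4, 4                  # toy grid: 2 temporal slots of 4x4 patches
top, left, bh, bw = 0, 0, 2, 2     # one 2x2 spatial block in the top-left corner

# the same spatial block is repeated in every temporal slot, forming a "tube"
target = sorted(t * H * W + (top + i) * W + (left + j)
                for t in range(T) for i in range(bh) for j in range(bw))
context = sorted(set(range(T * H * W)) - set(target))  # context = complement of the target

print(target)        # [0, 1, 4, 5, 16, 17, 20, 21]
print(len(context))  # 24
```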

#### Final remarks 

That's all for this exercise! If you are interested in this topic, feel free to read V-JEPA 2 ([Link to paper](https://arxiv.org/pdf/2506.09985)) and think about what is new compared to V-JEPA ([Link to paper](https://arxiv.org/pdf/2404.08471)).