# V-JEPA 2 Demo Notebook

This tutorial provides an example of how to load the V-JEPA 2 model in vanilla PyTorch and HuggingFace, extract a video embedding, and then predict an action class. For more details about the paper and model weights, please see https://github.com/facebookresearch/vjepa2.

In [2]:
import sys

# Add the parent directory to the import search path; presumably this is what
# lets the `src.*` imports in later cells resolve when the notebook runs from a subfolder.
sys.path.append("../")

In [None]:
# Install matplotlib, which the plotting helpers in later cells import as `plt`.
! pip install matplotlib

In [16]:
# MOT-format ground-truth files (gt.txt) for the training split.
# Presumably paired index-by-index with `training_videos`, so both lists must
# keep the same length and order.
# NOTE(review): the 0027 entry has a double underscore ("task__dji_...") —
# confirm it matches the directory name on disk.
training_gt = [
    '../data/data_sf_1/annotations/task_dji_0011_001.mp4-2023_06_02_12_02_56-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0011_002.mp4-2023_07_24_07_42_53-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0020_001.mp4-2023_07_24_07_07_22-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0020_003.mp4-2023_07_19_11_08_45-mot 1.1/gt/gt.txt',
    # The trailing comma here was missing, which made Python concatenate this
    # path with the next one and silently dropped an entry from the list.
    '../data/data_sf_1/annotations/task__dji_0027_001_cut.mp4-2023_03_16_13_41_41-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0069_001.mp4-2023_07_19_12_20_43-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0069_002.mp4-2023_07_19_13_50_09-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0071_001.mp4-2023_07_19_10_50_04-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0071_003.mp4-2023_06_28_09_54_05-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_gh020076.mp4-2023_04_06_13_05_55-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_gh030076.mp4-2023_04_06_11_46_50-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_gh040076.mp4-2023_08_08_14_20_35-mot 1.1/gt/gt.txt',
]


# MOT-format ground-truth files (gt.txt) for the held-out test split.
test_gt = [
    '../data/data_sf_1/annotations/task_dji_0011_003.mp4-2023_07_24_08_07_18-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0020_002.mp4-2023_07_19_10_17_29-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0027_002.mp4-2023_07_05_07_48_07-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_dji_0071_002.mp4-2023_06_05_12_42_01-mot 1.1/gt/gt.txt',
    '../data/data_sf_1/annotations/task_gh010076.mp4-2023_04_06_12_15_21-mot 1.1/gt/gt.txt',
]

In [17]:
# All clips live in the same folder and share the upper-case .MP4 extension,
# so each list is spelled out as file stems and expanded into full paths.
_VIDEO_DIR = '../data/data_sf_1/videos/'

# Videos of the training split.
training_videos = [
    _VIDEO_DIR + stem + '.MP4'
    for stem in (
        'DJI_0011_001',
        'DJI_0011_002',
        'DJI_0020_001',
        'DJI_0020_003',
        'DJI_0027_001_cut',
        'DJI_0069_001',
        'DJI_0069_002',
        'DJI_0071_001',
        'DJI_0071_003',
        'GH020076',
        'GH030076',
        'GH040076',
    )
]

# Videos of the held-out test split.
testing_videos = [
    _VIDEO_DIR + stem + '.MP4'
    for stem in (
        'DJI_0011_003',
        'DJI_0020_002',
        'DJI_0027_002',
        'DJI_0071_002',
        'GH010076',
    )
]

In [3]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from math import ceil, sqrt

def show_feature_grid(frames, cmap='viridis', figsize=(12, 12), ncols=None):
    """
    Plot a stack of scalar maps as heatmaps arranged in a grid.

    All panels share one colour scale spanning the global minimum and maximum
    of ``frames``, so intensities are comparable between panels.

    Args:
        frames (np.ndarray): (T, H, W) array of scalar frames.
        cmap (str): Matplotlib colormap name (e.g., 'viridis', 'plasma').
        figsize (tuple): Size of the whole figure.
        ncols (int): Number of grid columns. A falsy value (None or 0) selects
            ceil(sqrt(T)).
    """
    num_frames = frames.shape[0]
    if not ncols:
        ncols = ceil(sqrt(num_frames))
    nrows = ceil(num_frames / ncols)

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    # plt.subplots may hand back a single Axes, a 1-D or a 2-D array; flatten
    # everything to one sequence of panels.
    panels = np.array(axes).reshape(-1)

    lowest, highest = frames.min(), frames.max()

    for idx, panel in enumerate(panels):
        if idx < num_frames:
            panel.imshow(frames[idx], cmap=cmap, vmin=lowest, vmax=highest)
            panel.set_title(f"Frame {idx}")
        # Panels past the last frame stay empty but still lose their axes.
        panel.axis("off")

    plt.tight_layout()
    plt.show()


def interpolate_temporal_features_to_frame_space(features, patch_size=16, output_T=64, method='norm'):
    """
    Turn a grid of ViT patch features into per-frame scalar maps.

    The channel axis is first reduced to one value per patch, the patch grid is
    then enlarged bilinearly by ``patch_size`` along both spatial axes, and
    finally the temporal axis is resampled linearly to ``output_T`` steps.

    Args:
        features (np.ndarray): shape (T, H_p, W_p, C)
        patch_size (int): Spatial upsampling factor (default 16).
        output_T (int): Number of time steps in the result (e.g., 64).
        method (str): Channel reduction: 'mean' (average over channels),
            'norm' (``torch.norm`` over channels) or 'first' (channel 0 only).

    Returns:
        np.ndarray: float32 array of shape
            (output_T, H_p * patch_size, W_p * patch_size).

    Raises:
        ValueError: If ``method`` is not one of the three supported names.
    """
    num_steps, grid_h, grid_w, _ = features.shape
    channels_first = torch.from_numpy(features).permute(0, 3, 1, 2).float()  # (T, C, H_p, W_p)

    # Reject unknown reductions before doing any further work.
    if method not in ("mean", "norm", "first"):
        raise ValueError("method must be 'mean', 'norm', or 'first'")

    # Each reducer maps (T, C, H_p, W_p) -> (T, 1, H_p, W_p).
    reducers = {
        "mean": lambda t: t.mean(dim=1, keepdim=True),
        "norm": lambda t: torch.norm(t, dim=1, keepdim=True),
        "first": lambda t: t[:, 0:1, :, :],
    }
    scalar_maps = reducers[method](channels_first)

    # Spatial step: treat the T maps as channels of a single image so that one
    # bilinear call resizes all of them.
    out_h = grid_h * patch_size
    out_w = grid_w * patch_size
    single_batch = scalar_maps.permute(1, 0, 2, 3)  # (1, T, H_p, W_p)
    spatial = F.interpolate(
        single_batch,
        size=(out_h, out_w),
        mode="bilinear",
        align_corners=False,
    )
    spatial = spatial.squeeze(0)  # (T, H, W)

    # Temporal step: every pixel becomes its own 1-D signal over time.
    pixel_series = spatial.permute(1, 2, 0).reshape(-1, num_steps).unsqueeze(1)  # (H*W, 1, T)
    resampled = F.interpolate(pixel_series, size=output_T, mode="linear", align_corners=False)

    # Back to a time-major volume.
    volume = resampled.squeeze(1).reshape(out_h, out_w, output_T).permute(2, 0, 1)  # (output_T, H, W)
    return volume.numpy()


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import math

def visualize_frame_grid(frames, grid_rows=None, grid_cols=None, figsize=(12, 12), title="Frame Grid"):
    """
    Visualize a stack of frames (as a numpy array) in a grid.

    Frames fill the grid row by row. Cells beyond the last frame are hidden;
    if the grid has fewer cells than frames, the surplus frames are not drawn.

    Args:
        frames (np.ndarray): Array of shape (N, H, W, 3). Values are cast to
            uint8 for display.
        grid_rows (int): Number of rows in the grid. If None, derived from N
            (and from ``grid_cols`` when that is given).
        grid_cols (int): Number of cols in the grid. If None, derived from N
            (and from ``grid_rows`` when that is given).
        figsize (tuple): Size of the matplotlib figure.
        title (str): Title of the plot.

    Raises:
        ValueError: If ``frames`` holds no frames.
    """
    N, H, W, C = frames.shape
    assert C == 3, "Frames must be RGB (HxWx3)"
    if N == 0:
        # Previously surfaced as a ZeroDivisionError from the layout maths.
        raise ValueError("frames must contain at least one frame")

    if grid_rows is None and grid_cols is None:
        grid_cols = int(math.sqrt(N))
        grid_rows = math.ceil(N / grid_cols)
    elif grid_cols is None:
        # Honour the dimension the caller fixed instead of discarding it.
        grid_cols = math.ceil(N / grid_rows)
    elif grid_rows is None:
        grid_rows = math.ceil(N / grid_cols)

    # squeeze=False keeps `axes` two-dimensional for every layout. Without it a
    # 1x1 grid yields a bare Axes and an Nx1 grid a 1-D array, both of which
    # broke the indexing for N = 1, 2 and 3.
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=figsize, squeeze=False)
    fig.suptitle(title, fontsize=16)

    # `axes.flat` walks the grid in row-major order.
    for i, ax in enumerate(axes.flat):
        ax.axis('off')

        if i < N:
            ax.imshow(frames[i].astype(np.uint8))
        else:
            ax.set_visible(False)  # Hide empty subplots

    plt.tight_layout()
    plt.show()

First, let's import the necessary libraries and define the helper functions used in this tutorial.

In [None]:
import json
import os
import subprocess

import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader
from transformers import AutoVideoProcessor, AutoModel

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.attentive_pooler import AttentiveClassifier, AttentivePooler
from src.models.vision_transformer import vit_giant_xformers_rope, vit_huge_rope

# Normalization statistics (three values each, presumably one per RGB channel)
# passed as `mean` / `std` to video_transforms.Normalize in build_pt_video_transform.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

def load_pretrained_vjepa_pt_weights(model, pretrained_weights):
    """
    Copy V-JEPA 2 encoder weights from a checkpoint file into ``model``.

    The checkpoint is read onto the CPU and its ``"encoder"`` entry is used.
    Every occurrence of ``"module."`` and ``"backbone."`` is stripped from the
    parameter names before loading. Loading is non-strict, so missing or
    unexpected keys are only reported in the printed message.

    Args:
        model: Module exposing ``load_state_dict``.
        pretrained_weights: Path to the checkpoint file.
    """
    checkpoint = torch.load(pretrained_weights, weights_only=True, map_location="cpu")

    encoder_state = {}
    for name, tensor in checkpoint["encoder"].items():
        cleaned_name = name.replace("module.", "").replace("backbone.", "")
        encoder_state[cleaned_name] = tensor

    load_report = model.load_state_dict(encoder_state, strict=False)
    print(f"Pretrained weights found at {pretrained_weights} and loaded with msg: {load_report}")


def load_pretrained_vjepa_classifier_weights(model, pretrained_weights):
    """
    Copy attentive-probe classifier weights from a checkpoint file into ``model``.

    The checkpoint is read onto the CPU and the first element of its
    ``"classifiers"`` list is used. Every occurrence of ``"module."`` is
    stripped from the parameter names before loading. Loading is non-strict,
    so missing or unexpected keys are only reported in the printed message.

    Args:
        model: Module exposing ``load_state_dict``.
        pretrained_weights: Path to the checkpoint file.
    """
    checkpoint = torch.load(pretrained_weights, weights_only=True, map_location="cpu")
    first_probe = checkpoint["classifiers"][0]

    probe_state = {}
    for name, tensor in first_probe.items():
        probe_state[name.replace("module.", "")] = tensor

    load_report = model.load_state_dict(probe_state, strict=False)
    print(f"Pretrained weights found at {pretrained_weights} and loaded with msg: {load_report}")


def build_pt_video_transform(img_size):
    """
    Assemble the deterministic evaluation-time preprocessing pipeline.

    The pipeline resizes to ``int(256 / 224 * img_size)``, converts the clip
    to a tensor and normalizes it with the IMAGENET_DEFAULT_* statistics.
    There is no random crop or flip, and the center crop is left disabled.

    Args:
        img_size (int): Nominal model input size; only used to derive the
            resize target.

    Returns:
        The composed ``video_transforms.Compose`` pipeline.
    """
    resize_target = int(256.0 / 224 * img_size)
    pipeline_steps = [
        video_transforms.Resize(resize_target, interpolation="bilinear"),
        # video_transforms.CenterCrop(size=(img_size, img_size)),
        volume_transforms.ClipToTensor(),
        video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    return video_transforms.Compose(pipeline_steps)


def get_video():
    """
    Read every second frame among the first 128 of ``sample_video.mp4``.

    Returns:
        The 64 selected frames as returned by ``get_batch(...).asnumpy()``
        (presumably a T x H x W x C array).
    """
    # Fixed sampling pattern: indices 0, 2, ..., 126. Replace with a more
    # elaborate strategy if needed.
    sampled_indices = np.arange(0, 128, 2)
    reader = VideoReader("sample_video.mp4")
    return reader.get_batch(sampled_indices).asnumpy()


def get_video_path(video_path, start_frame=30 * 360, num_frames=64, frame_step=2):
    """
    Read a strided window of frames from the video at ``video_path``.

    The defaults reproduce the previously hard-coded window: 64 frames taken
    every 2nd frame, starting at frame 30 * 360 (= 10800).

    Args:
        video_path (str): Path of the video file to open.
        start_frame (int): Index of the first frame to read.
        num_frames (int): Number of frames to read.
        frame_step (int): Distance between consecutive sampled frame indices.

    Returns:
        The selected frames as returned by ``get_batch(...).asnumpy()``
        (presumably a T x H x W x C array).
    """
    vr = VideoReader(video_path)
    # Indices start_frame, start_frame + frame_step, ... (num_frames of them).
    # NOTE(review): no bounds check is done here; a video shorter than the
    # requested window is left to the reader to reject.
    frame_idx = np.arange(start_frame, start_frame + num_frames * frame_step, frame_step)
    video = vr.get_batch(frame_idx).asnumpy()
    return video


def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):
    """
    Encode the sample video with both the HuggingFace and the PyTorch model.

    The clip comes from ``get_video()``; each model gets the clip preprocessed
    by its own transform, moved to the GPU. Everything runs under
    ``torch.inference_mode()``.

    Args:
        model_hf: HuggingFace model exposing ``get_vision_features``.
        model_pt: PyTorch encoder, called directly on the input batch.
        hf_transform: HuggingFace processor; its ``"pixel_values_videos"``
            output is fed to ``model_hf``.
        pt_transform: Transform applied to the T x C x H x W clip for ``model_pt``.

    Returns:
        Tuple ``(hf_features, pt_features)`` of patch-wise features.
    """
    with torch.inference_mode():
        # Load the clip (T x H x W x C) and switch to channels-first.
        clip = torch.from_numpy(get_video()).permute(0, 3, 1, 2)  # T x C x H x W

        # PyTorch path: transform, move to GPU, add a leading batch axis.
        pt_input = pt_transform(clip).cuda().unsqueeze(0)

        # HuggingFace path: the processor returns a mapping of tensors.
        hf_batch = hf_transform(clip, return_tensors="pt")
        hf_input = hf_batch["pixel_values_videos"].to("cuda")

        # Patch-wise features from each encoder.
        pt_features = model_pt(pt_input)
        hf_features = model_hf.get_vision_features(hf_input)

    return hf_features, pt_features


def forward_vjepa_video_pt(model_pt, pt_transform, video):
    """
    Encode a clip with the PyTorch model only.

    Args:
        model_pt: PyTorch encoder, called directly on the input batch.
        pt_transform: Transform applied to the T x C x H x W clip.
        video (np.ndarray): Clip laid out as T x H x W x C.

    Returns:
        Patch-wise features returned by ``model_pt``, computed under
        ``torch.inference_mode()``.
    """
    with torch.inference_mode():
        channels_first = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
        # Transform, move to the GPU, then add a leading batch axis.
        batch = pt_transform(channels_first).cuda().unsqueeze(0)
        patch_features = model_pt(batch)

    return patch_features


def get_vjepa_video_classification_results(classifier, out_patch_features_pt):
    """
    Run the classifier on encoder features and print the top predictions.

    Class names are read from ``ssv2_classes.json`` in the working directory,
    a mapping from the stringified class index to its label. Only the first
    sample of the batch is reported.

    The printed percentages are softmax probabilities over *all* classes, so
    they no longer sum to 100% across the listed entries.

    Args:
        classifier: Callable mapping the features to logits; the output is
            assumed to be (batch, num_classes).
        out_patch_features_pt: Features passed to ``classifier`` unchanged.

    Returns:
        None. Results are only printed.
    """
    # Context manager so the label file is closed instead of leaking the handle.
    with open("ssv2_classes.json", "r", encoding="utf-8") as f:
        SOMETHING_SOMETHING_V2_CLASSES = json.load(f)

    with torch.inference_mode():
        out_classifier = classifier(out_patch_features_pt)

    print(f"Classifier output shape: {out_classifier.shape}")

    # Normalize over the whole class axis first, then pick the best entries.
    # Softmax over just the top-5 logits would inflate every percentage.
    probs = F.softmax(out_classifier[0].float(), dim=-1) * 100.0  # convert to percentage
    # Guard against heads with fewer than five classes.
    top = probs.topk(min(5, probs.shape[-1]))

    print("Top 5 predicted class names:")
    for idx, prob in zip(top.indices.tolist(), top.values.tolist()):
        print(f"{SOMETHING_SOMETHING_V2_CLASSES[str(idx)]} ({prob:.2f}%)")

    return

In [1]:
import os
import numpy as np
import decord
from decord import VideoReader
from torch.utils.data import Dataset

class MultiVideoClipWithMOTDataset(Dataset):
    """
    Dataset of fixed-length, temporally subsampled clips drawn from several videos.

    Each item is ``(frames, num_objects)``: the decoded frames of one clip and
    the number of distinct MOT object ids annotated on the sampled frames.
    Only clip positions where at least one sampled frame has an annotation are
    indexed.
    """

    def __init__(self, video_paths, mot_paths, clip_len=64, frame_rate=5, original_fps=30):
        """
        Args:
            video_paths (list): List of video file paths.
            mot_paths (list): List of corresponding MOT .txt files.
            clip_len (int): Number of frames sampled per clip.
            frame_rate (int): Target sampling rate in frames per second.
            original_fps (int): Frame rate of the source videos. Every
                ``original_fps // frame_rate``-th frame is sampled.

        Raises:
            ValueError: If ``frame_rate`` exceeds ``original_fps`` so that the
                sampling stride would be zero.
        """
        assert len(video_paths) == len(mot_paths), "Mismatch in number of videos and MOT files."

        self.clip_len = clip_len
        self.frame_rate = frame_rate
        self.step = original_fps // frame_rate
        if self.step < 1:
            # A zero stride would otherwise fail later inside range() with an
            # unrelated-looking message.
            raise ValueError(
                f"frame_rate ({frame_rate}) must not exceed original_fps ({original_fps})"
            )

        self.video_readers = []
        self.mot_annotations = []
        self.index_map = []  # global_idx -> (video_idx, clip_start)

        for vid_idx, (vpath, mpath) in enumerate(zip(video_paths, mot_paths)):
            vr = VideoReader(vpath)
            mot_data = self._load_mot(mpath)

            valid_starts = self._compute_valid_clips(len(vr), mot_data)
            self.index_map.extend((vid_idx, start_idx) for start_idx in valid_starts)

            self.video_readers.append(vr)
            self.mot_annotations.append(mot_data)

    def _load_mot(self, mot_path):
        """Parse a MOT text file into ``{frame_id: set(object_ids)}``.

        Only the first two comma-separated fields of each line are used.
        Blank lines are skipped.
        """
        mot = {}
        with open(mot_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate empty lines instead of failing on int('').
                    continue
                items = line.split(',')
                frame_id = int(items[0])
                obj_id = int(items[1])
                mot.setdefault(frame_id, set()).add(obj_id)
        return mot

    def _frame_indices(self, start):
        """Return the 0-based video frame indices of the clip beginning at ``start``."""
        return [start + i * self.step for i in range(self.clip_len)]

    def _compute_valid_clips(self, total_frames, mot_data):
        """Return clip start frames whose sampled frames include an annotated one.

        Start positions advance by ``self.step`` and stop at
        ``total_frames - clip_len * step``.
        """
        max_start = total_frames - self.clip_len * self.step
        return [
            start
            for start in range(0, max_start + 1, self.step)
            # MOT uses 1-based frame indices
            if any((f + 1) in mot_data for f in self._frame_indices(start))
        ]

    def __len__(self):
        """Number of indexed clips across all videos."""
        return len(self.index_map)

    def __getitem__(self, idx):
        """Return ``(frames, num_unique_object_ids)`` for the clip at ``idx``."""
        video_idx, start = self.index_map[idx]
        vr = self.video_readers[video_idx]
        mot_data = self.mot_annotations[video_idx]

        frame_indices = self._frame_indices(start)
        frames = vr.get_batch(frame_indices).asnumpy()  # presumably (clip_len, H, W, 3)

        unique_ids = set()
        for f in frame_indices:
            unique_ids.update(mot_data.get(f + 1, ()))  # +1 for 1-based MOT

        return frames, len(unique_ids)


In [None]:
class AttentiveRegressor(torch.nn.Module):
    """
    Attentive regressor: an attentive pooler followed by a small MLP head.

    The pooler uses a single query, and its output is fed through
    ``Linear(embed_dim, 256) -> ReLU -> Linear(256, 1)`` to produce one scalar
    per sample.

    ``torch.nn`` is referenced through the already-imported ``torch`` package
    because no ``nn`` alias is imported in this notebook.

    Args:
        embed_dim (int): Feature dimension of the incoming tokens.
        num_heads, mlp_ratio, depth, norm_layer, init_std, qkv_bias,
        complete_block, use_activation_checkpointing: Forwarded unchanged to
            ``AttentivePooler``.
    """

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        depth=1,
        norm_layer=torch.nn.LayerNorm,
        init_std=0.02,
        qkv_bias=True,
        complete_block=True,
        use_activation_checkpointing=False,
    ):
        super().__init__()
        self.pooler = AttentivePooler(
            num_queries=1,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            depth=depth,
            norm_layer=norm_layer,
            init_std=init_std,
            qkv_bias=qkv_bias,
            complete_block=complete_block,
            use_activation_checkpointing=use_activation_checkpointing,
        )
        # The head is sized from the `embed_dim` argument; this module has no
        # `backbone` attribute to read it from.
        self.regressor = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1),
        )

    def forward(self, x):
        """Pool the tokens in ``x`` and regress one value per sample.

        Returns the head output; presumably of shape (batch, 1) — confirm
        against the pooler's output layout.
        """
        # Drop the singleton query axis left by the single-query pooler.
        x = self.pooler(x).squeeze(1)
        x = self.regressor(x)
        return x

In [None]:
# Load the frame window selected by get_video_path from one of the dataset videos.
# NOTE(review): this path starts with "../data_sf_1/" whereas the training/testing
# lists use "../data/data_sf_1/" — confirm which location is correct.
video = get_video_path("../data_sf_1/videos/DJI_0011_002.MP4")

In [None]:
# Display the loaded frames in a grid as a visual sanity check.
visualize_frame_grid(video)

Next, let's download a sample video to the local repository. If the video is already downloaded, the code will skip this step. Likewise, let's download a mapping for the action recognition classes used in Something-Something V2, so we can interpret the predicted action class from our model.

In [None]:
def _download_if_missing(url, destination, label):
    """Fetch ``url`` to ``destination`` with wget unless the file already exists.

    Raises:
        subprocess.CalledProcessError: If wget exits with a non-zero status.
        OSError: If wget cannot be started.
    """
    if os.path.exists(destination):
        return

    # Announce before the (possibly slow) transfer rather than after it.
    print(f"Downloading {label}")
    try:
        subprocess.run(["wget", url, "-O", destination], check=True)
    except (subprocess.CalledProcessError, OSError):
        # `wget -O` creates the output file up front. Remove the partial file
        # so the exists-check above does not skip the retry on the next run.
        if os.path.exists(destination):
            os.remove(destination)
        raise


# Download the video if not yet downloaded to local path
sample_video_path = "sample_video.mp4"
_download_if_missing(
    "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4",
    sample_video_path,
    "video",
)

# Download SSV2 classes if not already present
ssv2_classes_path = "ssv2_classes.json"
_download_if_missing(
    "https://huggingface.co/datasets/huggingface/label-files/resolve/d79675f2d50a7b1ecf98923d42c30526a51818e2/"
    "something-something-v2-id2label.json",
    ssv2_classes_path,
    "SSV2 classes",
)

In [None]:
!wget https://dl.fbaipublicfiles.com/vjepa2/vith.pt ./

Now, let's load the models in both vanilla PyTorch and through the HuggingFace API. Note that the HuggingFace API will automatically load the weights through `from_pretrained()`, so there is no additional download required for HuggingFace.

To download the PyTorch model weights, use wget and specify your preferred target path. See the README for the model weight URLs.
E.g. 
```
wget https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt -P YOUR_DIR
```
Then update `pt_model_path` with `YOUR_DIR/vitg-384.pt`. Also note that you have the option to use `torch.hub.load`.

In [None]:
# HuggingFace model repo name
# NOTE(review): this names the ViT-g/384 checkpoint while the PyTorch path below
# builds a ViT-H at 256 px; the two would need to match for a feature comparison.
hf_model_name = (
    "facebook/vjepa2-vitg-fpc64-384"  # Replace with your favored model, e.g. facebook/vjepa2-vitg-fpc64-384
)
# Path to local PyTorch weights
pt_model_path = "./vith.pt"

# Initialize the HuggingFace model, load pretrained weights
# (Disabled: while these lines stay commented out, `model_hf` and `hf_transform`
# are not defined, so cells that use them cannot run.)
# model_hf = AutoModel.from_pretrained(hf_model_name)
# model_hf.cuda().eval()

# Build HuggingFace preprocessing transform
# hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name)
# img_size = hf_transform.crop_size["height"]  # E.g. 384, 256, etc.

# Initialize the PyTorch model, load pretrained weights
# img_size is set by hand here because the HuggingFace processor above is disabled.
img_size = 256
model_pt = vit_huge_rope(img_size=(img_size, img_size), num_frames=64)
# Move to the GPU and switch to eval mode before the checkpoint is copied in.
model_pt.cuda().eval()
load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)

### Can also use torch.hub to load the model
# model_pt, _ = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant_384')
# model_pt.cuda().eval()

# Build PyTorch preprocessing transform
pt_video_transform = build_pt_video_transform(img_size=img_size)

In [None]:
# Encode the clip loaded earlier into `video` with the PyTorch encoder.
features = forward_vjepa_video_pt(model_pt, pt_video_transform, video)

In [None]:
# NOTE(review): in the visible cells `x_pt` is only assigned inside the
# forward_vjepa_video* helpers, so this presumably fails with NameError unless it
# is defined elsewhere — confirm, or inspect `features.shape` instead.
x_pt.shape

In [None]:
# Drop the leading axis (presumably the batch axis of size 1).
features = features.squeeze(0)
# Rearrange the remaining axes into a (time, height, width, channel) patch grid.
# 64 is the number of input frames and 16 the assumed patch size.
# 292 equals int(256.0 / 224 * 256), the resize target build_pt_video_transform
# computes for img_size=256; 519 is presumably the resized width of this
# particular video — TODO confirm, these values are specific to the input clip.
reshaped = features.reshape((64 // model_pt.tubelet_size, 292 // 16, 519 // 16, -1))

In [None]:
# Upsample the per-patch feature norms to 64 frame-sized maps and plot them as heatmaps.
show_feature_grid(interpolate_temporal_features_to_frame_space(reshaped.cpu().numpy(), patch_size=16, output_T=64, method="norm"))

Now we can run the encoder on the video to get the patch-wise features from the last layer of the encoder. To verify that the HuggingFace and PyTorch models are equivalent, we will compare the values of the features.

In [None]:
# Inference on video to get the patch-wise features
# NOTE(review): `model_hf` and `hf_transform` are only created by lines that are
# commented out in the model-loading cell; re-enable them before running this.
out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(
    model_hf, model_pt, hf_transform, pt_video_transform
)

# Report both output shapes, the summed absolute difference, and whether the
# two feature tensors agree within atol=1e-3 / rtol=1e-3.
print(
    f"""
    Inference results on video:
    HuggingFace output shape: {out_patch_features_hf.shape}
    PyTorch output shape:     {out_patch_features_pt.shape}
    Absolute difference sum:  {torch.abs(out_patch_features_pt - out_patch_features_hf).sum():.6f}
    Close: {torch.allclose(out_patch_features_pt, out_patch_features_hf, atol=1e-3, rtol=1e-3)}
    """
)

Great! Now we know that the features from both models are equivalent. Now let's run a pretrained attentive probe classifier on top of the extracted features, to predict an action class for the video. Let's use the Something-Something V2 probe. Note that the repository also includes attentive probe weights for other evaluations such as EPIC-KITCHENS-100 and Diving48.

To download the attentive probe weights, use wget and specify your preferred target path. E.g. `wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt -P YOUR_DIR`

Then update `classifier_model_path` with `YOUR_DIR/ssv2-vitg-384-64x2x3.pt`.

In [None]:
# Initialize the classifier
# Placeholder: replace with the path of the downloaded attentive-probe checkpoint.
classifier_model_path = "YOUR_ATTENTIVE_PROBE_PATH"
# The probe is sized to the encoder's embedding width and has 174 output classes.
# NOTE(review): the probe checkpoint must match the encoder loaded into
# `model_pt` (variant and resolution) — confirm before running.
classifier = (
    AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).cuda().eval()
)
load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path)

# Get classification results
# Uses `out_patch_features_pt`, which is assigned by the HuggingFace/PyTorch comparison cell.
get_vjepa_video_classification_results(classifier, out_patch_features_pt)

The video features a man putting a bowling ball into a tube, so the predicted action of "Putting [something] into [something]" makes sense!

This concludes the tutorial. Please see the README and paper for full details on the capabilities of V-JEPA 2 :)