# V-JEPA 2 Demo Notebook

This tutorial provides an example of how to load the V-JEPA 2 model in vanilla PyTorch and HuggingFace, extract a video embedding, and then predict an action class. For more details about the paper and model weights, please see https://github.com/facebookresearch/vjepa2.

First, let's import the necessary libraries and load the necessary functions for this tutorial.

In [49]:
import json
import os
import sys
import subprocess

if os.getcwd().endswith('notebooks'):
    # 切换到项目根目录
    os.chdir('..')

import numpy as np
import torch
import torch.nn.functional as F
import cv2
from transformers import AutoVideoProcessor, AutoModel

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.attentive_pooler import AttentiveClassifier
from src.models.vision_transformer import vit_large_rope

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

def load_pretrained_vjepa_pt_weights(model, pretrained_weights):
    # Load weights of the VJEPA2 encoder
    # The PyTorch state_dict is already preprocessed to have the right key names
    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location="cpu")["encoder"]
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def load_pretrained_vjepa_classifier_weights(model, pretrained_weights):
    # Load weights of the VJEPA2 classifier
    # The PyTorch state_dict is already preprocessed to have the right key names
    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location="cpu")["classifiers"][0]
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def build_pt_video_transform(img_size):
    short_side_size = int(256.0 / 224 * img_size)
    # Eval transform has no random cropping nor flip
    eval_transform = video_transforms.Compose(
        [
            video_transforms.Resize(short_side_size, interpolation="bilinear"),
            video_transforms.CenterCrop(size=(img_size, img_size)),
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )
    return eval_transform


# def get_video():
#     vr = VideoReader("sample_video.mp4")
#     # choosing some frames here, you can define more complex sampling strategy
#     frame_idx = np.arange(0, 128, 2)
#     video = vr.get_batch(frame_idx).asnumpy()
#     return video

def get_video():
    """完全避开 torch.from_numpy()，使用 torch.tensor()"""
    cap = cv2.VideoCapture("sample_video.mp4")
    
    # 获取视频信息
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"Total frames in video: {total_frames}")
    
    frames = []
    frame_idx = np.arange(0, min(128, total_frames), 2)
    current_frame = 0
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if current_frame in frame_idx:
            # OpenCV读取的是BGR，需要转换为RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # 使用 torch.tensor() 代替 torch.from_numpy()
            frame_tensor = torch.tensor(frame_rgb, dtype=torch.uint8)
            frames.append(frame_tensor)
        current_frame += 1
        if len(frames) >= len(frame_idx):
            break
    
    cap.release()
    
    # 堆叠为 torch tensor
    video = torch.stack(frames, dim=0)  # (T, H, W, C)
    print(f"Video shape: {video.shape}")
    return video

def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):
    """修复版本"""
    with torch.inference_mode():
        # Read and pre-process the video
        video = get_video()  # T x H x W x C (torch tensor)
        
        # 确保数据类型正确
        if video.dtype == torch.uint8:
            video = video.float() / 255.0  # 归一化到 [0,1]
        
        video = video.permute(0, 3, 1, 2)  # T x C x H x W
        print(f"Video tensor shape after permute: {video.shape}")
        
        # PyTorch 处理
        x_pt = pt_transform(video).unsqueeze(0)
        print(f"PyTorch input shape: {x_pt.shape}")
        
        # HuggingFace 处理（使用相同数据）
        print("⚠️ 使用 PyTorch transform 代替 HuggingFace")
        x_hf = x_pt
        
        # 模型推理
        print("🔄 运行 PyTorch 模型...")
        out_patch_features_pt = model_pt(x_pt)
        
        print("🔄 运行 HuggingFace 模型（实际是同一个模型）...")
        out_patch_features_hf = model_pt(x_hf)
        
        print(f"PyTorch output shape: {out_patch_features_pt.shape}")
        print(f"HuggingFace output shape: {out_patch_features_hf.shape}")
        
        return out_patch_features_hf, out_patch_features_pt

# def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):
#     # Run a sample inference with VJEPA
#     with torch.inference_mode():
#         # Read and pre-process the image
#         video = get_video()  # T x H x W x C
#         video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
#         x_pt = pt_transform(video).cuda().unsqueeze(0)
#         x_hf = hf_transform(video, return_tensors="pt")["pixel_values_videos"].to("cuda")
#         # Extract the patch-wise features from the last layer
#         out_patch_features_pt = model_pt(x_pt)
#         out_patch_features_hf = model_hf.get_vision_features(x_hf)

#     return out_patch_features_hf, out_patch_features_pt


def get_vjepa_video_classification_results(classifier, out_patch_features_pt):
    SOMETHING_SOMETHING_V2_CLASSES = json.load(open("ssv2_classes.json", "r"))

    with torch.inference_mode():
        out_classifier = classifier(out_patch_features_pt)

    print(f"Classifier output shape: {out_classifier.shape}")

    print("Top 5 predicted class names:")
    top5_indices = out_classifier.topk(5).indices[0]
    top5_probs = F.softmax(out_classifier.topk(5).values[0]) * 100.0  # convert to percentage
    for idx, prob in zip(top5_indices, top5_probs):
        str_idx = str(idx.item())
        print(f"{SOMETHING_SOMETHING_V2_CLASSES[str_idx]} ({prob}%)")

    return

Next, let's download a sample video to the local repository. If the video is already downloaded, the code will skip this step. Likewise, let's download a mapping for the action recognition classes used in Something-Something V2, so we can interpret the predicted action class from our model.

In [41]:
import os
import urllib.request

# Download the video if not yet downloaded to local path
sample_video_path = "sample_video.mp4"

# 删除可能存在的空文件
if os.path.exists(sample_video_path) and os.path.getsize(sample_video_path) == 0:
    os.remove(sample_video_path)
    print("删除了空的视频文件")

if not os.path.exists(sample_video_path):
    video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
    print("Downloading video...")
    
    try:
        with urllib.request.urlopen(video_url, timeout=60) as response:
            if response.status == 200:
                with open(sample_video_path, 'wb') as f:
                    # 分块下载，显示进度
                    total_size = int(response.headers.get('content-length', 0))
                    downloaded = 0
                    chunk_size = 8192
                    
                    while True:
                        chunk = response.read(chunk_size)
                        if not chunk:
                            break
                        f.write(chunk)
                        downloaded += len(chunk)
                        
                        if total_size > 0:
                            progress = downloaded / total_size * 100
                            print(f"\r下载进度: {progress:.1f}%", end='')
                    
                    print()  # 换行
                
                final_size = os.path.getsize(sample_video_path)
                if final_size > 0:
                    print(f"✅ 视频下载完成，文件大小: {final_size/1024/1024:.2f} MB")
                else:
                    raise Exception("下载的文件为空")
            else:
                raise Exception(f"HTTP 错误: {response.status}")
                
    except Exception as e:
        print(f"❌ 视频下载失败: {e}")
        if os.path.exists(sample_video_path):
            os.remove(sample_video_path)
        exit(1)  # 如果下载失败就退出
else:
    print(f"视频文件已存在，大小: {os.path.getsize(sample_video_path)/1024/1024:.2f} MB")

# Download SSV2 classes if not already present
ssv2_classes_path = "ssv2_classes.json"

# 删除可能存在的空文件
if os.path.exists(ssv2_classes_path) and os.path.getsize(ssv2_classes_path) == 0:
    os.remove(ssv2_classes_path)
    print("删除了空的 JSON 文件")

if not os.path.exists(ssv2_classes_path):
    json_url = "https://huggingface.co/datasets/huggingface/label-files/resolve/d79675f2d50a7b1ecf98923d42c30526a51818e2/something-something-v2-id2label.json"
    print("Downloading SSV2 classes...")
    
    try:
        with urllib.request.urlopen(json_url, timeout=30) as response:
            if response.status == 200:
                with open(ssv2_classes_path, 'wb') as f:
                    content = response.read()
                    f.write(content)
                
                final_size = os.path.getsize(ssv2_classes_path)
                if final_size > 0:
                    print(f"✅ JSON 文件下载完成，文件大小: {final_size/1024:.1f} KB")
                else:
                    raise Exception("下载的 JSON 文件为空")
            else:
                raise Exception(f"HTTP 错误: {response.status}")
                
    except Exception as e:
        print(f"❌ JSON 文件下载失败: {e}")
        if os.path.exists(ssv2_classes_path):
            os.remove(ssv2_classes_path)
        # JSON 文件下载失败不退出程序，因为可能不是必需的
        print("⚠️  继续运行，但可能会影响后续功能")
else:
    print(f"JSON 文件已存在，大小: {os.path.getsize(ssv2_classes_path)/1024:.1f} KB")

视频文件已存在，大小: 0.42 MB
删除了空的 JSON 文件
Downloading SSV2 classes...
✅ JSON 文件下载完成，文件大小: 9.9 KB


Now, let's load the models in both vanilla Pytorch as well as through the HuggingFace API. Note that HuggingFace API will automatically load the weights through `from_pretrained()`, so there is no additional download required for HuggingFace.

To download the PyTorch model weights, use wget and specify your preferred target path. See the README for the model weight URLs.
E.g. 
```
wget https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt -P YOUR_DIR
```
Then update `pt_model_path` with `YOUR_DIR/vitg-384.pt`. Also note that you have the option to use `torch.hub.load`.

In [46]:
# HuggingFace model repo name
hf_model_name = "facebook/vjepa2-vitl-fpc64-256"
# Path to local PyTorch weights  
pt_model_path = "./models/vitl.pt"

# 尝试 HuggingFace 加载
try:
    model_hf = AutoModel.from_pretrained(hf_model_name, trust_remote_code=True)
    hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name, trust_remote_code=True)
    img_size = hf_transform.crop_size["height"]
    model_hf.eval()
    print("✅ HuggingFace 加载成功")
except:
    print("⚠️ HuggingFace 失败，使用 PyTorch 版本")
    # 关键：让两个变量指向同一个模型
    img_size = 256  # 从模型名推断
    model_pt = vit_large_rope(img_size=(img_size, img_size), num_frames=64)
    model_pt.eval()
    load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)
    
    # 重点：让 model_hf 也指向同一个模型
    model_hf = model_pt
    
    # 创建简单的占位符变换器
    class SimpleTransform:
        crop_size = {"height": img_size}
    hf_transform = SimpleTransform()

# 确保 PyTorch 版本存在
if 'model_pt' not in locals():
    model_pt = model_hf

# Build PyTorch preprocessing transform
pt_video_transform = build_pt_video_transform(img_size=img_size)

⚠️ HuggingFace 失败，使用 PyTorch 版本
Pretrained weights found at ./models/vitl.pt and loaded with msg: <All keys matched successfully>


Now we can run the encoder on the video to get the patch-wise features from the last layer of the encoder. To verify that the HuggingFace and PyTorch models are equivalent, we will compare the values of the features.

In [50]:
# Inference on video to get the patch-wise features
out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(
    model_hf, model_pt, hf_transform, pt_video_transform
)

print(
    f"""
    Inference results on video:
    HuggingFace output shape: {out_patch_features_hf.shape}
    PyTorch output shape:     {out_patch_features_pt.shape}
    Absolute difference sum:  {torch.abs(out_patch_features_pt - out_patch_features_hf).sum():.6f}
    Close: {torch.allclose(out_patch_features_pt, out_patch_features_hf, atol=1e-3, rtol=1e-3)}
    """
)

Total frames in video: 150
Video shape: torch.Size([64, 270, 480, 3])
Video tensor shape after permute: torch.Size([64, 3, 270, 480])
PyTorch input shape: torch.Size([1, 3, 64, 256, 256])
⚠️ 使用 PyTorch transform 代替 HuggingFace
🔄 运行 PyTorch 模型...
🔄 运行 HuggingFace 模型（实际是同一个模型）...
PyTorch output shape: torch.Size([1, 8192, 1024])
HuggingFace output shape: torch.Size([1, 8192, 1024])

    Inference results on video:
    HuggingFace output shape: torch.Size([1, 8192, 1024])
    PyTorch output shape:     torch.Size([1, 8192, 1024])
    Absolute difference sum:  0.000000
    Close: True
    


Great! Now we know that the features from both models are equivalent. Now let's run a pretrained attentive probe classifier on top of the extracted features, to predict an action class for the video. Let's use the Something-Something V2 probe. Note that the repository also includes attentive probe weights for other evaluations such as EPIC-KITCHENS-100 and Diving48.

To download the attentive probe weights, use wget and specify your preferred target path. E.g. `wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt -P YOUR_DIR`

Then update `classifier_model_path` with `YOUR_DIR/ssv2-vitg-384-64x2x3.pt`.

In [52]:
# Initialize the classifier
classifier_model_path = "./models/ssv2-vitl-16x2x3.pt"
classifier = (
    AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).eval()
)
load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path)

# Get classification results
get_vjepa_video_classification_results(classifier, out_patch_features_pt)

Pretrained weights found at ./models/ssv2-vitl-16x2x3.pt and loaded with msg: <All keys matched successfully>
Classifier output shape: torch.Size([1, 174])
Top 5 predicted class names:
Closing [something] (35.74630355834961%)
Moving [something] and [something] closer to each other (27.593738555908203%)
Moving [something] down (12.300285339355469%)
Moving [something] and [something] so they collide with each other (12.270605087280273%)
Pushing [something] from right to left (12.089075088500977%)


  top5_probs = F.softmax(out_classifier.topk(5).values[0]) * 100.0  # convert to percentage


The video features a man putting a bowling ball into a tube, so the predicted action of "Putting [something] into [something]" makes sense!

This concludes the tutorial. Please see the README and paper for full details on the capabilities of V-JEPA 2 :)