In [6]:
import json
import os
import subprocess

import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader
from transformers import AutoVideoProcessor, AutoModel

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.attentive_pooler import AttentiveClassifier
from src.models.vision_transformer import vit_giant_xformers_rope

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

def load_pretrained_vjepa_pt_weights(model, pretrained_weights):
    # Load weights of the VJEPA2 encoder
    # The PyTorch state_dict is already preprocessed to have the right key names
    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location="cpu")["encoder"]
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def load_pretrained_vjepa_classifier_weights(model, pretrained_weights):
    # Load weights of the VJEPA2 classifier
    # The PyTorch state_dict is already preprocessed to have the right key names
    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location="cpu")["classifiers"][0]
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def build_pt_video_transform(img_size):
    short_side_size = int(256.0 / 224 * img_size)
    # Eval transform has no random cropping nor flip
    eval_transform = video_transforms.Compose(
        [
            video_transforms.Resize(short_side_size, interpolation="bilinear"),
            video_transforms.CenterCrop(size=(img_size, img_size)),
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )
    return eval_transform


def get_video():
    vr = VideoReader("sample/vjepa2_test1_1.mp4")
    # choosing some frames here, you can define more complex sampling strategy
    frame_idx = np.arange(0, 128, 2)
    video = vr.get_batch(frame_idx).asnumpy()
    return video


def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):
    # Run a sample inference with VJEPA
    with torch.inference_mode():
        # Read and pre-process the image
        video = get_video()  # T x H x W x C
        video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
        x_pt = pt_transform(video).cuda().unsqueeze(0)
        x_hf = hf_transform(video, return_tensors="pt")["pixel_values_videos"].to("cuda")
        # Extract the patch-wise features from the last layer
        out_patch_features_pt = model_pt(x_pt)
        out_patch_features_hf = model_hf.get_vision_features(x_hf)

    return out_patch_features_hf, out_patch_features_pt


def get_vjepa_video_classification_results(classifier, out_patch_features_pt):
    SOMETHING_SOMETHING_V2_CLASSES = json.load(open("ssv2_classes.json", "r"))

    with torch.inference_mode():
        out_classifier = classifier(out_patch_features_pt)

    print(f"Classifier output shape: {out_classifier.shape}")

    print("Top 5 predicted class names:")
    top5_indices = out_classifier.topk(100).indices[0]
    top5_probs = F.softmax(out_classifier.topk(100).values[0]) * 100.0  # convert to percentage
    for idx, prob in zip(top5_indices, top5_probs):
        str_idx = str(idx.item())
        print(f"{SOMETHING_SOMETHING_V2_CLASSES[str_idx]} ({prob}%)")

    return

# HuggingFace model repo name
hf_model_name = (
    "facebook/vjepa2-vitg-fpc64-384"  # Replace with your favored model, e.g. facebook/vjepa2-vitg-fpc64-384
)
# Path to local PyTorch weights
#pt_model_path = "YOUR_MODEL_PATH"

# Initialize the HuggingFace model, load pretrained weights
model_hf = AutoModel.from_pretrained(hf_model_name)
model_hf.cuda().eval()

# Build HuggingFace preprocessing transform
hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name)
img_size = hf_transform.crop_size["height"]  # E.g. 384, 256, etc.

# Initialize both models
model_pt = vit_giant_xformers_rope(img_size=(img_size, img_size), num_frames=64)
model_pt.cuda().eval()
pt_model_path = "vitg-384.pt"
load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)

pt_video_transform = build_pt_video_transform(img_size=img_size)
out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(model_hf, model_pt, hf_transform, pt_video_transform)

print(f"""
HuggingFace output shape: {out_patch_features_hf.shape}
PyTorch output shape: {out_patch_features_pt.shape}
Absolute difference sum: {torch.abs(out_patch_features_pt - out_patch_features_hf).sum():.6f}
Close: {torch.allclose(out_patch_features_pt, out_patch_features_hf, atol=1e-3, rtol=1e-3)}
""")

# Load classifier
classifier_model_path = "ssv2-vitg-384-64x2x3.pt"
assert os.path.exists(classifier_model_path), "Classifier model not found!"

classifier = AttentiveClassifier(embed_dim=model_hf.config.hidden_size, num_heads=16, depth=4, num_classes=174).cuda().eval()
pretrained_dict = torch.load(classifier_model_path, weights_only=True, map_location="cpu")["classifiers"][0]
pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
msg = classifier.load_state_dict(pretrained_dict, strict=False)
print("Pretrained weights found and loaded with msg: {}".format(msg))

get_vjepa_video_classification_results(classifier, out_patch_features_pt)


Pretrained weights found at vitg-384.pt and loaded with msg: <All keys matched successfully>

HuggingFace output shape: torch.Size([1, 18432, 1408])
PyTorch output shape: torch.Size([1, 18432, 1408])
Absolute difference sum: 11951387.000000
Close: False

Pretrained weights found and loaded with msg: <All keys matched successfully>
Classifier output shape: torch.Size([1, 174])
Top 5 predicted class names:


  top5_probs = F.softmax(out_classifier.topk(100).values[0]) * 100.0  # convert to percentage


Showing [something] to the camera (62.27323532104492%)
Spinning [something] so it continues spinning (15.901102066040039%)
Rolling [something] on a flat surface (2.63826060295105%)
Throwing [something] (2.478363513946533%)
Pushing [something] so it spins (1.733525276184082%)
Letting [something] roll along a flat surface (1.6069393157958984%)
Poking [something] so that it spins around (1.201206922531128%)
Twisting [something] (0.9632881879806519%)
Turning the camera left while filming [something] (0.863765299320221%)
Spreading [something] onto [something] (0.6216221451759338%)
Showing a photo of [something] to the camera (0.5682592391967773%)
Moving [part] of [something] (0.5400030612945557%)
Pretending to spread air onto [something] (0.40627753734588623%)
Moving [something] across a surface without it falling down (0.39038965106010437%)
Pretending to turn [something] upside down (0.3809271454811096%)
Moving [something] and [something] so they pass each other (0.35168275237083435%)
Movi

In [3]:
ls

 ドライブ D のボリューム ラベルは ボリューム です
 ボリューム シリアル番号は ECD2-8FC7 です

 D:\otake\gpu-service\users\vidzshan\notebooks\Finalvjepa2\vjepa2 のディレクトリ

2025/07/03  11:07    <DIR>          .
2025/07/03  10:09    <DIR>          ..
2025/07/03  10:09               120 .flake8
2025/07/03  10:09    <DIR>          .github
2025/07/03  10:09               499 .gitignore
2025/07/03  10:39    <DIR>          .ipynb_checkpoints
2025/07/03  10:09            11,349 APACHE-LICENSE
2025/07/03  10:09    <DIR>          app
2025/07/03  10:09    <DIR>          assets
2025/07/03  10:09                74 CHANGELOG.md
2025/07/03  10:09             3,535 CODE_OF_CONDUCT.md
2025/07/03  10:09    <DIR>          configs
2025/07/03  10:09             1,504 CONTRIBUTING.md
2025/07/03  10:09    <DIR>          evals
2025/07/03  10:09               429 hubconf.py
2025/07/03  10:09             1,087 LICENSE
2025/07/03  10:09    <DIR>          notebooks
2025/07/03  10:09                77 pyproject.toml
2025/07/03  10:09            17,245 