In [1]:
import sys
sys.path.insert(0, "..")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.nn import functional as F

from app.vjepa_droid.transforms import make_transforms
from utils.mpc_utils import (
    compute_new_pose,
    poses_to_diff
)
from transformers import UMT5EncoderModel, AutoTokenizer
import torch
from torch.nn import functional as F
from src.models.ac_predictor import vit_ac_predictor
from src.models.vision_transformer import vit_giant_xformers

In [None]:
# Load the UMT5 text encoder and its tokenizer from the Wan2.1 T2V repository.
# The encoder is kept in bfloat16 on the GPU and switched to eval mode (inference only).
wan_repo_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
text_encoder = UMT5EncoderModel.from_pretrained(
    wan_repo_id,
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
)
text_encoder = text_encoder.to("cuda").eval()
tokenizer = AutoTokenizer.from_pretrained(wan_repo_id, subfolder="tokenizer")

In [None]:
# Initialize VJEPA 2-AC model: a ViT-giant video encoder plus an action-conditioned,
# frame-causal predictor. Both share the same input geometry (256px, 16px patches,
# tubelets of 2 frames) and architectural switches, collected once below.
_shared_vit_kwargs = dict(
    img_size=256,
    patch_size=16,
    num_frames=512,
    tubelet_size=2,
    uniform_power=True,
    use_sdpa=True,
    use_silu=False,
    wide_silu=True,
    use_activation_checkpointing=False,
    use_rope=True,
)
encoder = vit_giant_xformers(**_shared_vit_kwargs).eval().to("cuda")
predictor = vit_ac_predictor(
    **_shared_vit_kwargs,
    embed_dim=encoder.embed_dim,  # predictor consumes encoder tokens of this width
    predictor_embed_dim=1024,
    depth=24,
    is_frame_causal=True,
    num_heads=16,
).eval().to("cuda")
# Checkpoint to evaluate.
# NOTE(review): this was labelled the "agibot" checkpoint, but the path lives under
# results/bridge — confirm which dataset this run was trained on.
# Alternative checkpoint (labelled "language table dataset"):
#   /mnt/weka/home/yi.gu/tokenizer/zh/vjepa2/results/bridge/4.8.vitg16-256px-8f_25_07_03_16_40_36/e375.pt
CHECKPOINT_PATH = "/mnt/weka/home/yi.gu/tokenizer/zh/vjepa2/results/bridge/4.8.vitg16-256px-8f_25_07_03_16_41_00/e275.pt"


def _strip_module_prefix(module_state_dict, prefix="module."):
    """Return a copy of ``module_state_dict`` with a leading ``prefix`` removed from keys.

    Checkpoints saved from a wrapped model (e.g. DistributedDataParallel) store
    parameters as ``module.<name>``; the bare models built above expect ``<name>``.
    Keys that do not start with ``prefix`` are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in module_state_dict.items()
    }


# map_location="cpu": without it, tensors saved from a GPU are restored onto that
# same GPU (creating a second full copy next to the already-built CUDA modules, or
# failing if that device index does not exist here). load_state_dict copies the CPU
# tensors into the CUDA parameters anyway.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
renamed_encoder_state_dict = _strip_module_prefix(state_dict["encoder"])
renamed_predictor_state_dict = _strip_module_prefix(state_dict["predictor"])

encoder.load_state_dict(renamed_encoder_state_dict)
predictor.load_state_dict(renamed_predictor_state_dict)

# Evaluation transform: every random augmentation is switched off, so the
# preprocessing should be deterministic (NOTE(review): exact resize/crop/normalize
# steps live in app.vjepa_droid.transforms.make_transforms — confirm there).
crop_size = 256
# Spatial patch tokens per frame: (256 // 16) ** 2 = 256 with the encoder built above.
patches_per_side = crop_size // encoder.patch_size
tokens_per_frame = int(patches_per_side * patches_per_side)
transform = make_transforms(
    crop_size=crop_size,
    random_horizontal_flip=False,
    random_resize_aspect_ratio=(1.0, 1.0),
    random_resize_scale=(1.0, 1.0),
    reprob=0.0,
    auto_augment=False,
    motion_shift=False,
)

In [5]:
def step_predictor(_z, _t):
    """Run one predictor pass and layer-normalize the result over its last dimension.

    Args:
        _z: context token sequence fed to the global ``predictor``.
        _t: conditioning tensor passed as the predictor's second argument
            (the instruction embeddings in this notebook).

    Returns:
        The predictor output with affine-free layer norm (no weight/bias) applied
        over the feature (last) dimension.
    """
    predicted = predictor(_z, _t)
    feature_dim = predicted.shape[-1]
    return F.layer_norm(predicted, (feature_dim,))

In [6]:
import json
from decord import VideoReader

In [None]:
# Only a first image and a text instruction are needed for inference; a Bridge
# evaluation sample is used here as a convenient source of both.
data = []
with open(
    "/mnt/weka/home/yi.gu/world-model/evaluation/bridge/output_video0622/index.jsonl",
    "r",
    encoding="utf-8",
) as f:
    for line in f:
        # Skip blank lines (e.g. a trailing empty line), which json.loads rejects.
        if line.strip():
            data.append(json.loads(line))
assert data, "index.jsonl contains no records"
sample = data[0]
sample.keys()

In [None]:
# Tokenize the instruction to exactly 32 tokens (padded / truncated) and embed it
# with the frozen UMT5 encoder.
text_instruction = tokenizer(
    sample["instruction"],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=32,
)
# A single string with return_tensors="pt" already yields a batch of one: shape (1, 32).
text_input_ids = text_instruction.input_ids.to("cuda")
mask = text_instruction.attention_mask.to("cuda")
with torch.no_grad():
    encoded_text = text_encoder(text_input_ids, attention_mask=mask).last_hidden_state
# NOTE(review): `mask` is not forwarded together with `encoded_text` later, so padded
# positions reach the predictor — confirm this matches how it was trained.

In [15]:
# Decode the sample video's first frame and run it through the evaluation transform.
# A leading length-1 axis turns the single frame into a one-frame clip
# (NOTE(review): assumes decord frames are (H, W, C) and the transform expects
# (T, H, W, C) — confirm).
video_dir = "/mnt/weka/home/yi.gu/world-model/evaluation/bridge/output_video0622/"
video_path = video_dir + sample['video']
loaded_video_clip = VideoReader(video_path)
first_frame_hwc = loaded_video_clip[0].asnumpy()
first_image = transform(first_frame_hwc[np.newaxis, ...])

In [16]:
@torch.no_grad()
def forward_target(c):
    """Encode every frame independently and return one flat, layer-normalized token sequence.

    Each frame is repeated twice along a new time axis to fill one 2-frame tubelet,
    encoded on its own, and the per-frame token sets are concatenated per batch item.
    Runs without autograd.

    Args:
        c: 5-D tensor; the permute below treats it as (batch, channels, time, H, W)
           — NOTE(review): layout inferred from the permute, confirm against the transform.

    Returns:
        Tensor of shape (batch, time * tokens, dim), assuming the encoder returns
        (frames, tokens, dim); normalized over the last dimension.
    """
    num_clips = c.shape[0]
    # (B, C, T, H, W) -> (B*T, C, H, W): one entry per frame.
    frames = c.permute(0, 2, 1, 3, 4).flatten(0, 1)
    # (B*T, C, H, W) -> (B*T, C, 2, H, W): duplicate each frame into a 2-frame tubelet.
    tubelets = frames.unsqueeze(2).repeat(1, 1, 2, 1, 1)
    tokens = encoder(tubelets)
    # Regroup the per-frame token sets into one sequence per batch item.
    tokens = tokens.view(num_clips, -1, tokens.size(-1))
    return F.layer_norm(tokens, (tokens.size(-1),))

In [None]:
# Add a batch axis, move to the GPU and encode the preprocessed frame into latent tokens.
# NOTE: this rebinds `first_image` from pixels to latents; re-running this cell without
# first re-running the frame-loading cell feeds a 4-D latent into the 5-D permute and raises.
first_image = forward_target(first_image.to("cuda").unsqueeze(0))

In [None]:
# Autoregressive latent rollout: starting from the encoded first frame, predict one
# frame of tokens at a time conditioned on the instruction embedding, and append it
# to the context. The predictor's parameters still require grad (eval() does not
# change that), so without no_grad every step would record an autograd graph and
# keep all intermediate activations alive across the whole rollout.
num_rollout_steps = 10
_z = first_image
encoded_text = encoded_text.to(torch.float32)
with torch.no_grad():
    for _ in range(num_rollout_steps):
        # The last frame-sized slice of the prediction is taken as the next frame.
        _z_nxt = step_predictor(_z, encoded_text)[:, -tokens_per_frame:]
        _z = torch.cat([_z, _z_nxt], dim=1)

In [None]:
# View the rolled-out token sequence as (batch, frames, tokens_per_frame, dim).
# The frame count is derived from the sequence length (context frame + rollout
# steps) rather than hard-coded, so it stays correct if the rollout length changes.
n = tokens_per_frame
assert _z.size(1) % n == 0, (
    f"sequence length {_z.size(1)} is not a multiple of {n} tokens per frame"
)
t = _z.size(1) // n
frame_by_frame_representation = _z.reshape(_z.size(0), t, n, -1)
frame_by_frame_representation.shape