In [130]:
import torch
from einops import repeat
import matplotlib.pyplot as plt

from video_jepa.data import PendulumDataset
from video_jepa.world_model import WorldModel

In [None]:
# Pretrained V-JEPA 2 ViT-L video encoder from torch.hub (the second returned
# value is not used in this notebook).
video_encoder, _ = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_large')

# World-model configuration.
n_frames = 3               # latent steps used as context (num_hist) and predicted (num_pred)
input_size = (128, 128)    # (H, W) of the input frames
action_embed_dim = 96      # width of the action embedding appended to each token
patch_size = video_encoder.patch_size
tubelet_size = video_encoder.tubelet_size
model = WorldModel(
    num_hist=n_frames,
    num_pred=n_frames,
    video_encoder=video_encoder,
    input_size=input_size,
    action_dim=1,
    action_embed_dim=action_embed_dim,  # single source of truth for the embed width
)

# Load the finetuned sub-modules. Tensors are read onto the CPU first (the whole
# model is moved to the GPU below), and weights_only=True refuses to unpickle
# arbitrary Python objects from the checkpoint files.
model.latent_predictor.load_state_dict(
    torch.load("output/latent_predictor.pt", map_location="cpu", weights_only=True)
)
model.decoder.load_state_dict(
    torch.load("output/decoder.pt", map_location="cpu", weights_only=True)
)
model.action_encoder.load_state_dict(
    torch.load("output/action_emb.pt", map_location="cpu", weights_only=True)
)
model.cuda()
# Inference only: put any dropout / normalisation layers into evaluation mode.
model.eval()

# Inspect the latent predictor architecture.
model.latent_predictor

In [None]:
# Initialize the training set to visualize next-frame predictions on the data
# the predictor and decoder were finetuned on. Each clip must hold the context
# window plus the prediction horizon: 2 * n_frames latent steps of tubelet_size
# frames each (the decoder cell relies on this).
train_dataset = PendulumDataset(
    seq_len=2 * n_frames * tubelet_size,
    input_size=input_size,
    include_states=False,
    include_actions=True,
)

# Take one clip and add a batch dimension:
# video (T, C, H, W) -> x (1, C, T, H, W), the layout the encoder is fed below.
batch = train_dataset[1]
x = batch["video"].unsqueeze(0).moveaxis(1, 2).cuda()
actions = batch["actions"].unsqueeze(0).cuda()

x.shape

In [None]:
# Show every frame of the sampled clip, four frames per row.
vid = x[0].cpu().permute(1, 0, 2, 3)  # (C, T, H, W) -> (T, C, H, W)
n_vid_frames = vid.shape[0]
n_cols = 4
n_rows = -(-n_vid_frames // n_cols)  # ceil division: no frame is silently dropped
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(3 * n_cols, 10 * n_rows / 3), squeeze=False
)

for i, ax in enumerate(axes.flat):
    ax.axis("off")
    if i >= n_vid_frames:
        continue  # leftover grid cell in the last row
    # imshow clips float RGB to [0, 1] anyway; clamping avoids the warning.
    ax.imshow(vid[i].moveaxis(0, -1).clamp(0, 1).numpy())
    ax.set_title(f"t={i}", fontsize=10)

plt.tight_layout()
plt.show()

This is the dataset on which the predictor and decoder were finetuned. We feed the first 6 frames (`n_frames` latent steps × `tubelet_size` frames per tubelet) as context and let the model predict the following frames.

In [None]:
B, C, T, H, W = x.shape

# Spatial patch grid produced by the encoder.
patch_h = input_size[0] // patch_size
patch_w = input_size[1] // patch_size

# Encode only the context window: n_frames latent steps x tubelet_size frames each.
with torch.no_grad():
    z = model.encoder(x[:, :, :n_frames * tubelet_size, ...])

# The encoder returns a flat token sequence (B, N, D); regroup it as
# (B, patch_t, patch_h * patch_w, D), one row of patch tokens per latent step.
patch_t = z.shape[1] // (patch_h * patch_w)
assert z.shape[1] == patch_t * patch_h * patch_w, (
    f"encoder returned {z.shape[1]} tokens, which is not a multiple of the "
    f"{patch_h}x{patch_w} patch grid"
)
z = z.reshape(B, patch_t, patch_h * patch_w, z.shape[-1])
z_src = z[:, : n_frames, :, :]
z.shape

In [None]:
# Action encoding: embed the actions and append the embedding to every patch
# token of the matching context latent step.
with torch.no_grad():  # inference only, no autograd graph needed
    z_act = model.action_encoder(actions)
# NOTE(review): the first n_frames action entries are paired with the n_frames
# latent steps — confirm this matches the alignment used during training.
z_act = z_act[:, : n_frames].unsqueeze(2)  # presumably (B, n_frames, 1, action_embed_dim)
act_tiled = repeat(
    z_act,
    "b t 1 a -> b t f a",
    f=z.shape[2],  # one copy per patch token
)

# Latent Predictor input. Built from the encoder latents `z` rather than from
# `z_src` itself, so re-running this cell does not append the actions twice.
# (B, n_frames, num_patches, D + action_embed_dim) -> (B, n_frames * num_patches, D + action_embed_dim)
z_src = torch.cat([z[:, : n_frames], act_tiled], dim=3)
z_src = z_src.reshape(B, -1, z_src.shape[-1]).detach()

In [None]:
# Forward pass of the ViT latent predictor on the action-conditioned context.
# Input z_src: (B, n_frames * patch_h * patch_w, D + action_embed_dim).
# The flat output sequence is regrouped into (B, n_frames, patch_h * patch_w, dim),
# i.e. one row of patch tokens per predicted latent step.
with torch.no_grad():
    z_pred = model.latent_predictor(z_src).reshape(
        B, n_frames, patch_h * patch_w, -1
    )
z_pred.shape

In [None]:
# Decode the predicted latents back to pixels and compare them with the
# ground-truth future frames of the clip.
with torch.no_grad():
    # The predictor output carries the action embedding in its last
    # `action_embed_dim` channels; only the visual part is decoded.
    visual_pred, diff_pred = model.decoder(  # diff_pred is not used below
        z_pred[..., :-action_embed_dim],
        patch_h,
        patch_w,
        frames_per_latent=tubelet_size,
    )

# TODO: Currently the frames are still averaged as the VAE only outputs
# a single frame per latent.
# NOTE(review): the decoder is already called with frames_per_latent=tubelet_size,
# so this TODO may be stale — confirm.
k = n_frames * tubelet_size  # number of context frames fed to the encoder

# Ground truth = every frame after the context window, as (B, T_future, C, H, W).
# `reshape` rather than `view`: it also works for non-contiguous layouts and
# does not assume exactly k future frames.
visual_tgt = x[:, :, k:, ...].moveaxis(1, 2).reshape(B, -1, C, H, W)
visual_pred = visual_pred.reshape(B, -1, C, H, W)

# Compare only frames present in both, starting right after the context window.
n_show = min(visual_tgt.shape[1], visual_pred.shape[1])
assert n_show > 0, "nothing to plot: no future frames or no predicted frames"
tgt_frames = visual_tgt[0][:n_show]
pred_frames = visual_pred[0][:n_show]

# Absolute frame indices within the clip (the first future frame is frame k).
# The global `T` (total clip length) is deliberately not overwritten here.
times = range(k, k + n_show)

fig, axes = plt.subplots(2, n_show, figsize=(4 * n_show, 8), squeeze=False)
for col, (t, tgt, pred) in enumerate(zip(times, tgt_frames, pred_frames)):
    axes[0, col].imshow(tgt.moveaxis(0, -1).detach().clamp(0, 1).cpu().numpy())
    axes[0, col].set_title(f"t = {t}")
    axes[0, col].axis("off")

    axes[1, col].imshow(pred.moveaxis(0, -1).detach().clamp(0, 1).cpu().numpy())
    axes[1, col].set_title(f"pred t = {t}")
    axes[1, col].axis("off")

plt.tight_layout()
plt.show()

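We can also compare the latents directly: the encoder's latents of the true future frames (GT) versus the latents produced by the latent predictor. Both are reduced to three PCA components and min–max scaled, so each patch token can be shown as an RGB pixel. Note that PCA is fit separately on each set, so colours are only loosely comparable between the two rows.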

In [None]:
# Preparing everything for PCA: drop the trailing action-embedding channels of
# the predicted tokens and stack every (latent step, patch) token as one row,
# giving a (B * n_frames * patch_h * patch_w, embed_dim) matrix.
embed_dim = model.encoder.embed_dim
# This cell (and the PCA cell below) overwrite `z_pred`. Re-running them without
# re-running the latent predictor cell would otherwise fail with an opaque
# reshape/SVD error, so check for the predictor's 4-D output explicitly.
assert z_pred.dim() == 4, (
    "z_pred has already been flattened/PCA-projected; "
    "re-run the latent predictor cell before this one"
)
z_pred = z_pred[..., :-action_embed_dim].reshape(-1, embed_dim)

def pca(X, n_components=3):
    """Project the rows of ``X`` onto their leading principal components.

    Args:
        X: 2-D tensor of shape (n_samples, n_features). It is not modified.
        n_components: number of leading components to keep.

    Returns:
        Tensor of shape (n_samples, min(n_components, min(n_samples, n_features)))
        holding the mean-centred data projected onto the leading right-singular
        vectors. Each component's sign is chosen so that the largest-magnitude
        entry of the corresponding left-singular vector is positive, which
        removes the arbitrary sign ambiguity of the SVD.
    """
    centered = X - X.mean(0, keepdim=True)
    U, _, Vh = torch.linalg.svd(centered, full_matrices=False)

    # Per component: row of U with the largest |value|, and the sign found there.
    pivot_rows = U.abs().argmax(dim=0)
    flip = U[pivot_rows, torch.arange(U.shape[1], device=U.device)].sign()

    basis = Vh[:n_components] * flip[:n_components, None]
    return centered @ basis.T

def min_max(X, target_min=0.0, target_max=1.0):
    """Linearly rescale each column of ``X`` into [target_min, target_max].

    The column minimum maps to ``target_min`` and the column maximum to
    ``target_max``. A small epsilon in the denominator keeps constant columns
    finite; such columns map entirely to ``target_min``.
    """
    col_lo = X.min(0, True).values
    col_hi = X.max(0, True).values
    unit = (X - col_lo) / (col_hi - col_lo + 1e-8)
    return target_min + unit * (target_max - target_min)

# Reduce the predicted tokens to three principal components, then scale each
# component to [0, 1] so every token can be displayed as an RGB colour.
z_pred = pca(z_pred)
z_pred = min_max(z_pred)
z_pred.shape

In [None]:
# Encode the full clip to get the "ground-truth" latents of the future frames.
with torch.no_grad():
    z_target = model.encoder(x)

# Flat tokens (B, N, D) -> (B, patch_t, patch_h * patch_w, embed_dim).
patch_t = z_target.shape[1] // (patch_h * patch_w)
assert z_target.shape[1] == patch_t * patch_h * patch_w, (
    f"encoder returned {z_target.shape[1]} tokens, which is not a multiple of the "
    f"{patch_h}x{patch_w} patch grid"
)
z_target = z_target.reshape(B, patch_t, patch_h * patch_w, embed_dim)

# Keep exactly the n_frames latent steps the predictor forecasts (right after
# the context), so GT and prediction cover the same horizon even for longer clips.
assert patch_t >= 2 * n_frames, (
    f"clip has {patch_t} latent steps, need at least {2 * n_frames}"
)
z_target = z_target[:, n_frames : 2 * n_frames].reshape(-1, embed_dim)

# NOTE: PCA is fit on the targets alone (independently of the predictions), so
# colours are only loosely comparable between the GT and predicted plots.
z_target = min_max(pca(z_target))
z_target.shape

In [None]:
# Turn the per-token PCA colours back into images: one (patch_h, patch_w, 3)
# RGB map per latent step (3 = number of PCA components kept by `pca`).
pred = z_pred.reshape(-1, patch_h, patch_w, 3)
gt = z_target.reshape(-1, patch_h, patch_w, 3)
n_steps = min(pred.shape[0], gt.shape[0])

# squeeze=False keeps `axes` 2-D even when only one step is shown.
fig, axes = plt.subplots(2, n_steps, figsize=(4 * n_steps, 8), squeeze=False)
for t in range(n_steps):
    axes[0, t].imshow(gt[t].cpu().numpy())
    axes[0, t].set_title(f"GT  t={t}")
    axes[0, t].axis("off")

    axes[1, t].imshow(pred[t].cpu().numpy())
    axes[1, t].set_title(f"Pred t={t}")
    axes[1, t].axis("off")

plt.tight_layout()
plt.show()