In [1]:
import sys
from pathlib import Path

# Make the repo root importable so the local `utils` package resolves, whether
# the kernel was started from the repo root or from its `notebooks/` subfolder.
root = Path.cwd()
if root.name == "notebooks":
    root = root.parent

# Put exactly one copy of the root at the front of sys.path. A plain
# `insert(0, ...)` would prepend another duplicate every time this cell is
# re-run; slice-assignment keeps the same list object that `import` consults.
sys.path[:] = [str(root)] + [p for p in sys.path if p != str(root)]


In [2]:
import torch
from IPython.display import display
from utils.filesystem import get_files_dir_from_example_dir
from utils.video import visualise_video_pt

import os

# Dump directory of the example to inspect. Set FUTURELATENTS_EXAMPLE_DIR to
# point the notebook at another run or machine without editing this cell;
# the original hardcoded path remains the default.
example_dir = os.environ.get(
    "FUTURELATENTS_EXAMPLE_DIR",
    "/private/home/francoisporcher/FutureLatents/experiment/debug/dinov3_kinetics_400_deterministic_cross_attention/dump/example_00",
)
files_dir = get_files_dir_from_example_dir(example_dir)

# map_location="cpu": a tensor saved from a GPU run would otherwise fail to
# load on a CPU-only machine.
video_pt = torch.load(files_dir["video_pt"], map_location="cpu")  # (T, C, H, W)

display(visualise_video_pt(video_pt, fps=10, loop=True))


In [3]:
# Project-local PCA helper, imported at the top of the cell rather than halfway
# through it (it needs the sys.path setup from the first cell).
from utils.pca import pca_latents_to_video_tensors

# Entries of `files_dir` for each latent kind. Despite the `dir_` prefix these
# are handed straight to `torch.load`, i.e. they are file paths.
dir_context_latents = files_dir['context_latents']
dir_prediction_latents = files_dir['prediction_latents']
dir_target_latents = files_dir['target_latents']

# Load onto the CPU so latents dumped from a GPU run open on any machine.
context_latents = torch.load(dir_context_latents, map_location="cpu")
prediction_latents = torch.load(dir_prediction_latents, map_location="cpu")
target_latents = torch.load(dir_target_latents, map_location="cpu")

# print shapes
print("context_latents", context_latents.shape)  # (N, D)
print("prediction_latents", prediction_latents.shape)  # (N, D)
print("target_latents", target_latents.shape)  # (N, D)

# Latent grid size. Only H and W are used below (to split the token axis into
# temporal steps); T is kept for reference and is not read in this cell.
T, H, W = 16, 16, 16


def _n_temporal_latents(latents: torch.Tensor, h: int, w: int, name: str) -> int:
    """Return how many h*w spatial grids fit in the first (token) axis of `latents`.

    Raises:
        ValueError: if the token count is not a multiple of h*w. Plain floor
            division would silently drop the trailing tokens in that case.
    """
    n_tokens = latents.shape[0]
    tokens_per_step = h * w
    if n_tokens % tokens_per_step != 0:
        raise ValueError(
            f"{name} has {n_tokens} tokens, which is not a multiple of "
            f"H*W={h}*{w}={tokens_per_step}; check H and W"
        )
    return n_tokens // tokens_per_step


# derive temporal lengths from latent counts
n_ctx_lat = _n_temporal_latents(context_latents, H, W, "context_latents")
n_tgt_lat = _n_temporal_latents(target_latents, H, W, "target_latents")
# The PCA helper takes no separate prediction length, so only check that the
# prediction latents also split evenly onto the H x W grid.
_n_temporal_latents(prediction_latents, H, W, "prediction_latents")

# PCA + reshape via utils (basis fitted on the context latents only)
context_latents_pca_reshaped, target_latents_pca_reshaped, prediction_latents_pca_reshaped = pca_latents_to_video_tensors(
    context_latents, target_latents, prediction_latents, n_ctx_lat=n_ctx_lat, n_tgt_lat=n_tgt_lat, H=H, W=W, n_components=3, fit_on='context'
)

# visualize PCA in video
display(visualise_video_pt(context_latents_pca_reshaped, fps=10, loop=True))


context_latents torch.Size([4096, 1024])
prediction_latents torch.Size([4096, 1024])
target_latents torch.Size([4096, 1024])


In [4]:
# Target latents after the PCA projection from the previous cell, rendered as
# a looping clip at 10 fps.
display(
    visualise_video_pt(
        target_latents_pca_reshaped,
        loop=True,
        fps=10,
    )
)

In [5]:
# Predicted latents after the PCA projection from the previous cell, rendered
# as a looping clip at 10 fps for side-by-side comparison with the target.
display(
    visualise_video_pt(
        prediction_latents_pca_reshaped,
        loop=True,
        fps=10,
    )
)