# DINO-Foresight Test Notebook (Trained Checkpoint)

This notebook is for quick qualitative testing of a trained checkpoint. It will:
- load the checkpoint
- fetch one OpenDV validation batch
- run a one-step prediction
- optionally run a multi-step rollout
- optionally decode predicted features to RGB for visualization

Update the paths in the config cell before running the other cells.

In [None]:
import os
from pathlib import Path
from types import SimpleNamespace

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from src.data import OpenDV_VideoData
from src.dino_f import Dino_f
from save_predicted_dino_features import add_missing_args, denormalize_images

# Use the GPU when one is visible; later cells move tensors and modules to DEVICE.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('DEVICE =', DEVICE)

In [None]:
# ===== Config =====
# 1) Your trained checkpoint
DINO_F_CKPT = '/cpfs/pengyu/DINO-Foresight/logs/dino_foresight_lowres_opendv_drift_noallgather/REPLACE_ME/checkpoints/last.ckpt'

# 2) Data
OPENDV_ROOT = '/pfs/pengyu/OpenDV-YouTube'
OPENDV_LANG_ROOT = '/pfs/pengyu/OpenDV-YouTube-Language'
LANG_CACHE_VAL = '/pfs/pengyu/OpenDV-YouTube-Language/mini_val_cache.json'
OPENDV_LANG_FEAT_NAME = 'lang_clip_{start}_{end}.pt'

# 3) Optional decoder (set DECODER_CKPT to None to skip RGB decoding)
DECODER_CKPT = None
DECODER_TYPE = 'from_dino'  # 'from_dino' or 'from_feats'

# 4) Eval settings
BATCH_SIZE = 1
NUM_WORKERS = 2
ROLLOUT_STEPS = 6
MAX_CLIPS = None  # forwarded as opendv_max_clips; presumably None = no cap (TODO confirm)

# Optional: offline DINO hub cache (setdefault keeps any value already in the environment)
os.environ.setdefault('DINO_REPO', '/cpfs/pengyu/.cache/torch/hub/facebookresearch_dinov2_main')

# Validate paths/options up front with explicit exceptions: `assert` statements
# are stripped when Python runs with -O, so they are not a reliable guard.
if not Path(DINO_F_CKPT).exists():
    raise FileNotFoundError(f'Checkpoint not found: {DINO_F_CKPT}')
if DECODER_CKPT is not None:
    if DECODER_TYPE not in ('from_dino', 'from_feats'):
        raise ValueError("DECODER_TYPE must be 'from_dino' or 'from_feats'")
    if not Path(DECODER_CKPT).exists():
        raise FileNotFoundError(f'Decoder checkpoint not found: {DECODER_CKPT}')

In [None]:
# Load model: weights are mapped to CPU first, then the whole module is moved to DEVICE.
model = Dino_f.load_from_checkpoint(DINO_F_CKPT, strict=False, map_location='cpu').to(DEVICE)

# (Re)create the feature extractor(s). They are built after the `.to(DEVICE)`
# above, so every extractor that exists is moved to DEVICE explicitly.
model._init_feature_extractor()
for _name in ('dino_v2', 'eva2clip', 'sam'):
    _extractor = getattr(model, _name, None)
    if _extractor is not None:
        _extractor = _extractor.to(DEVICE)
        if hasattr(_extractor, 'eval'):
            # Explicit: the extractor may not be reachable via model.eval()
            # if it is not a registered submodule.
            _extractor.eval()
        setattr(model, _name, _extractor)
del _name, _extractor

# Switch to eval mode only now. Modules created by _init_feature_extractor()
# after an earlier model.eval() call would otherwise stay in training mode.
model.eval()

print('feature_extractor:', model.args.feature_extractor)
print('img_size:', model.args.img_size)
print('sequence_length:', model.args.sequence_length)
print('use_language_condition:', getattr(model.args, 'use_language_condition', False))
print('use_precomputed_text:', getattr(model.args, 'use_precomputed_text', False))

In [None]:
# Build an OpenDV validation dataloader whose options are derived from the
# checkpoint's own args (language conditioning, sequence length, image size).
model_args = model.args
use_lang_cond = bool(getattr(model_args, 'use_language_condition', False))
use_precomputed_text = bool(getattr(model_args, 'use_precomputed_text', False))

# With language conditioning on, either precomputed text features are loaded
# (use_precomputed_text) or raw language is returned; both are off otherwise.
args = SimpleNamespace(
    data_path=OPENDV_ROOT,
    opendv_root=OPENDV_ROOT,
    opendv_lang_root=OPENDV_LANG_ROOT,
    opendv_use_lang_annos=bool(OPENDV_LANG_ROOT),
    opendv_lang_cache_train=None,
    opendv_lang_cache_val=LANG_CACHE_VAL,
    opendv_use_lang_features=use_lang_cond and use_precomputed_text,
    opendv_return_language=use_lang_cond and not use_precomputed_text,
    opendv_lang_feat_name=OPENDV_LANG_FEAT_NAME,
    opendv_video_dir=None,
    opendv_max_clips=MAX_CLIPS,
    sequence_length=getattr(model_args, 'sequence_length', 5),
    img_size=tuple(getattr(model_args, 'img_size', (224, 448))),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    num_workers_val=NUM_WORKERS,
    eval_mode=True,
    eval_midterm=False,
    eval_modality=None,
    use_language_condition=use_lang_cond,
    use_precomputed_text=use_precomputed_text,
)
args = add_missing_args(args, model_args)

# Assigned after add_missing_args so these explicit values take precedence over
# whatever it filled in (presumably copied from model.args -- TODO confirm).
args.feature_extractor = model_args.feature_extractor
args.dinov2_variant = getattr(model_args, 'dinov2_variant', 'vitb14_reg')
args.return_rgb_path = True

data = OpenDV_VideoData(arguments=args, subset='val', batch_size=BATCH_SIZE)
loader = data.val_dataloader()
batch = next(iter(loader))
print('Loaded one val batch.')

In [None]:
def parse_batch(batch):
    """Split a dataloader batch into its components.

    The first element is taken as the frames. The remaining elements are
    classified by type and rank; when several match the same slot, the last
    one wins:
      * 4-D tensor with 3 channels at dim 1 -> ground-truth image
      * 3-D tensor                          -> text tokens
      * 2-D tensor                          -> text mask
      * non-empty list/tuple whose first item is a str/Path -> RGB paths (as str)

    A batch that is not a list/tuple is returned as the frames unchanged.

    Returns:
        (frames, gt_img, text_tokens, text_mask, rgb_paths); missing parts are None.
    """
    if not isinstance(batch, (list, tuple)):
        return batch, None, None, None, None

    frames = batch[0]
    slots = {'gt': None, 'tokens': None, 'mask': None, 'paths': None}

    for entry in batch[1:]:
        if torch.is_tensor(entry):
            rank = entry.ndim
            if rank == 4:
                # Only 3-channel 4-D tensors count as images; others are ignored.
                if entry.shape[1] == 3:
                    slots['gt'] = entry
            elif rank == 3:
                slots['tokens'] = entry
            elif rank == 2:
                slots['mask'] = entry
            continue

        looks_like_paths = (
            isinstance(entry, (list, tuple))
            and len(entry) > 0
            and isinstance(entry[0], (str, Path))
        )
        if looks_like_paths:
            slots['paths'] = list(map(str, entry))

    return frames, slots['gt'], slots['tokens'], slots['mask'], slots['paths']


def _prepare_text_tokens_for_model(model, text_tokens, text_mask, device):
    if not getattr(model, 'use_language_condition', False):
        return None, None
    if text_tokens is None:
        return None, None

    text_tokens = text_tokens.to(device)
    if hasattr(model, 'text_proj'):
        in_dim = model.text_proj.in_features
        out_dim = model.text_proj.out_features
        if text_tokens.shape[-1] == in_dim:
            text_tokens = text_tokens.to(dtype=model.text_proj.weight.dtype)
            text_tokens = model.text_proj(text_tokens)
        elif text_tokens.shape[-1] != out_dim:
            raise ValueError(
                f'Unexpected text token dim {text_tokens.shape[-1]} (expected {in_dim} or {out_dim}).'
            )
    if text_mask is not None:
        text_mask = text_mask.to(device)
    return text_tokens, text_mask


@torch.no_grad()
def one_step_predict_feats(model, frames, text_tokens=None, text_mask=None):
    """Predict the features of the next frame from a context clip.

    The clip is preprocessed and masked with ``mode='full_mask', mask_frames=1``
    (presumably masking the final frame -- TODO confirm). The model fills it in,
    and the postprocessed prediction at the last time index is returned.

    Returns:
        ``model.postprocess(prediction)[:, -1]``; per the original shape note,
        [B, Hf, Wf, C].
    """
    tokens = model.preprocess(frames)
    dev = tokens.device
    text_tokens, text_mask = _prepare_text_tokens_for_model(model, text_tokens, text_mask, dev)

    masked_tokens, mask = model.get_mask_tokens(tokens, mode='full_mask', mask_frames=1)
    mask = mask.to(dev)

    # With vis_attn the forward pass also returns attention maps as a 3rd output.
    wants_attn = model.args.vis_attn
    outputs = model.forward(tokens, masked_tokens, mask, text_tokens=text_tokens, text_mask=text_mask)
    if wants_attn:
        _, predicted, _ = outputs
    else:
        _, predicted = outputs

    return model.postprocess(predicted)[:, -1]  # [B, Hf, Wf, C]


@torch.no_grad()
def rollout_predict_feats(model, frames, steps, text_tokens=None, text_mask=None):
    """Autoregressively predict `steps` future feature maps.

    Each step does three things:
      1. Mask and predict the last context position.
      2. Shift the context window left by one.
      3. Write the prediction into the freed last slot.

    Args:
        model: trained predictor (expected to be in eval mode).
        frames: input clip given to ``model.preprocess``.
        steps: number of rollout steps; must be >= 1.
        text_tokens, text_mask: optional language conditioning, prepared by
            ``_prepare_text_tokens_for_model``.

    Returns:
        CPU tensor stacking, along dim 1, the postprocessed last-position
        prediction of every step.

    Raises:
        ValueError: if ``steps`` < 1 (torch.stack would fail on an empty list).
    """
    if steps < 1:
        raise ValueError(f'steps must be >= 1, got {steps}')

    preds = []
    x = model.preprocess(frames)
    text_tokens, text_mask = _prepare_text_tokens_for_model(model, text_tokens, text_mask, x.device)

    for _ in range(steps):
        masked_x, mask = model.get_mask_tokens(x, mode='full_mask', mask_frames=1)
        mask = mask.to(x.device)

        if model.args.vis_attn:
            _, x_pred, _ = model.forward(x, masked_x, mask, text_tokens=text_tokens, text_mask=text_mask)
        else:
            _, x_pred = model.forward(x, masked_x, mask, text_tokens=text_tokens, text_mask=text_mask)

        # forward() works on tokens in the space produced by preprocess(); the
        # output is only postprocessed afterwards (as in one_step_predict_feats).
        # The context must therefore be fed the RAW prediction. Feeding back the
        # postprocessed one mixes spaces whenever postprocess is not the identity
        # (e.g. it undoes a normalization/PCA -- TODO confirm), and it breaks
        # outright if the channel dims differ.
        next_raw = x_pred[:, -1].clone()  # clone: postprocess may work in place
        preds.append(model.postprocess(x_pred)[:, -1].detach().cpu())

        # Shift the context and append the latest raw prediction.
        x[:, :-1] = x[:, 1:].clone()
        x[:, -1] = next_raw

    return torch.stack(preds, dim=1)  # [B, S, Hf, Wf, C]

In [None]:
# Parse the batch and run a single prediction step.
frames, gt_img, text_tokens, text_mask, rgb_paths = parse_batch(batch)
frames = frames.to(DEVICE)
gt_img = gt_img.to(DEVICE) if gt_img is not None else None

pred_feats = one_step_predict_feats(model, frames, text_tokens=text_tokens, text_mask=text_mask)
print('pred_feats shape:', tuple(pred_feats.shape))

# Quick sanity metric: MSE in feature space between the prediction and the
# features extracted from the ground-truth next frame (when the batch has one).
gt_feats = None
if gt_img is None:
    print('gt_img is None, skip feature MSE')
else:
    with torch.no_grad():
        gt_feats = model.extract_features(gt_img)
        # The token grid follows from the input resolution divided by the patch size.
        h, w = (side // model.patch_size for side in frames.shape[-2:])
        gt_feats = gt_feats.reshape(gt_feats.shape[0], h, w, -1)
        feat_mse = F.mse_loss(pred_feats, gt_feats)
    print('feature MSE (pred vs gt):', float(feat_mse))

In [None]:
# Optional decoder for RGB visualization.
# decode_feats stays None when no decoder checkpoint is configured. Otherwise it
# maps a [B, Hf, Wf, C] feature tensor to a CPU RGB tensor.
decode_feats = None
if DECODER_CKPT is not None:
    if DECODER_TYPE == 'from_feats':
        from train_rgb_decoder_from_feats import FeatureRgbDecoder
        # map_location='cpu' mirrors the Dino_f load: GPU-saved checkpoints
        # then also load on CPU-only machines before moving to DEVICE.
        decoder = FeatureRgbDecoder.load_from_checkpoint(
            DECODER_CKPT, strict=False, map_location='cpu'
        ).to(DEVICE)
        decoder.eval()

        # no_grad: this is called outside the prediction helpers, so without it
        # every decode would build an autograd graph.
        @torch.no_grad()
        def decode_feats(feats_bhwc):
            return decoder(feats_bhwc.to(DEVICE)).detach().cpu()
    elif DECODER_TYPE == 'from_dino':
        from train_rgb_decoder import DinoV2RGBDecoder
        decoder = DinoV2RGBDecoder.load_from_checkpoint(
            DECODER_CKPT, strict=False, lpips_weight=0, map_location='cpu'
        ).to(DEVICE)
        decoder.eval()
        dpt_head = decoder.decoder
        feat_dim = decoder.emb_dim

        @torch.no_grad()
        def decode_feats(feats_bhwc):
            feats_bhwc = feats_bhwc.to(DEVICE)
            b, h, w, c = feats_bhwc.shape
            if c % feat_dim != 0:
                raise ValueError(f'Feature dim mismatch: {c} not divisible by {feat_dim}')
            parts = torch.split(feats_bhwc, feat_dim, dim=-1)
            # The head is fed four feature chunks; with only two, each is reused twice.
            if len(parts) == 2:
                parts = [parts[0], parts[0], parts[1], parts[1]]
            elif len(parts) != 4:
                raise ValueError(f'Expected 2 or 4 feature chunks, got {len(parts)}')
            feat_list = [p.reshape(b, h * w, feat_dim) for p in parts]
            pred = dpt_head(feat_list, h, w)
            pred = F.interpolate(pred, size=tuple(model.args.img_size), mode='bicubic', align_corners=False)
            # sigmoid maps the decoder logits into [0, 1] for display.
            return torch.sigmoid(pred).detach().cpu()
    else:
        raise ValueError("DECODER_TYPE must be 'from_dino' or 'from_feats'")

print('decoder ready:', decode_feats is not None)

In [None]:
# Visualize context + one-step prediction (+ gt if available).
ctx = denormalize_images(frames[0].detach().cpu(), model.args.feature_extractor).cpu()  # [T, 3, H, W]

# Every displayed RGB tensor is clamped to [0, 1], matching the decoded
# prediction below. imshow would otherwise clip out-of-range float RGB itself
# and emit a warning.
show_imgs = [ctx[-1].clamp(0, 1)]
show_titles = ['context last']

if gt_img is not None:
    gt_rgb = denormalize_images(gt_img.detach().cpu(), model.args.feature_extractor).cpu()[0].clamp(0, 1)
    show_imgs.append(gt_rgb)
    show_titles.append('gt next')

if decode_feats is not None:
    pred_rgb = decode_feats(pred_feats)[0].clamp(0, 1)
    show_imgs.append(pred_rgb)
    show_titles.append('pred next')
else:
    # Fallback without a decoder: show the per-position L2 norm of the predicted features.
    feat_norm = pred_feats[0].norm(dim=-1).detach().cpu()
    plt.figure(figsize=(4, 3))
    plt.imshow(feat_norm)
    plt.title('pred feature norm map')
    plt.axis('off')
    plt.show()

# show_imgs always holds at least the last context frame, so no emptiness guard is needed.
plt.figure(figsize=(4 * len(show_imgs), 4))
for i, (img, title) in enumerate(zip(show_imgs, show_titles), start=1):
    plt.subplot(1, len(show_imgs), i)
    plt.imshow(img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    plt.title(title)
    plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Optional: autoregressive rollout, decoded to RGB when a decoder is available.
preds_rollout = rollout_predict_feats(model, frames, ROLLOUT_STEPS, text_tokens=text_tokens, text_mask=text_mask)
print('preds_rollout shape:', tuple(preds_rollout.shape))

if decode_feats is not None:
    # One panel per rollout step, laid out in a single row.
    cols = ROLLOUT_STEPS
    plt.figure(figsize=(3 * cols, 3))
    for step_idx in range(cols):
        step_rgb = decode_feats(preds_rollout[:, step_idx])[0].clamp(0, 1)
        panel = plt.subplot(1, cols, step_idx + 1)
        panel.imshow(step_rgb.permute(1, 2, 0))  # CHW -> HWC
        panel.set_title(f'pred t+{step_idx + 1}')
        panel.axis('off')
    plt.tight_layout()
    plt.show()
else:
    print('Set DECODER_CKPT to visualize RGB rollout.')