# DINO-Foresight Demo: Long Rollout + Language Swap

This notebook lets you:
1. Run long rollout predictions and decode future frames.
2. Swap the language condition to test whether the predictions change plausibly.

Notes:
- Fill in paths in the config cell.
- You can use a decoder trained from raw images (`train_rgb_decoder.py`) or from features (`train_rgb_decoder_from_feats.py`).


In [None]:
import os
from types import SimpleNamespace
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from src.data import OpenDV_VideoData
from src.dino_f import Dino_f
from save_predicted_dino_features import add_missing_args, denormalize_images
from train_rgb_decoder import DinoV2RGBDecoder
from train_rgb_decoder_from_feats import FeatureRgbDecoder


In [None]:
# ===== Config =====
# Checkpoints for the two predictors compared below: one is loaded as the
# language-conditioned model, the other as the no-language baseline.
DINO_F_CKPT_LANG = '/cpfs/pengyu/DINO-Foresight/dino-foresight/nbvdoouw/checkpoints/last.ckpt'
DINO_F_CKPT_NOLANG = '/cpfs/pengyu/DINO-Foresight/dino-foresight/klyq45xm/checkpoints/last.ckpt'
# RGB decoder checkpoint; DECODER_TYPE selects which decoder class loads it.
DECODER_CKPT = '/cpfs/pengyu/DINO-Foresight/dino-foresight/j5ludt8t/checkpoints/epoch=20-step=189609.ckpt'
DECODER_TYPE = 'from_dino'  # 'from_dino' or 'from_feats'
# Forwarded as `pca_ckpt` when both predictors are loaded.
PCA_CKPT = '/cpfs/pengyu/DINO-Foresight/dinov2_pca_224_l[2,_5,_8,_11]_1152.pth'

# Dataset locations handed to the OpenDV data module.
OPENDV_ROOT = '/pfs/pengyu/OpenDV-YouTube'
OPENDV_LANG_ROOT = '/pfs/pengyu/OpenDV-YouTube-Language'
LANG_CACHE = '/pfs/pengyu/OpenDV-YouTube-Language/mini_val_cache.json'
# Filename template; the placeholders are presumably filled in by the dataset
# code -- TODO confirm.
OPENDV_LANG_FEAT_NAME = 'lang_clip_{start}_{end}.pt'

SEQUENCE_LENGTH = 5
# Also used as the target `size` when 'from_dino' decoder outputs are resized.
IMG_SIZE = (196, 392)
ROLLOUT_STEPS = 6  # number of autoregressive prediction steps
BATCH_SIZE = 1
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Optional: use DINO hub cache offline (only applied if DINO_REPO is not already set)
os.environ.setdefault('DINO_REPO', '/cpfs/pengyu/.cache/torch/hub/facebookresearch_dinov2_main')

# Hugging Face cache used for CLIP; both values can be overridden via the environment.
HF_HOME = os.environ.get('HF_HOME', '/cpfs/pengyu/hfcaches')  # contains hub/
CLIP_CACHE_DIR = os.environ.get('CLIP_CACHE_DIR', HF_HOME)
# Passed as `local_files_only` when loading the CLIP tokenizer and text model.
CLIP_LOCAL_ONLY = True



In [None]:
# Load DINO-Foresight predictors (with / without language)
def _load_predictor(ckpt_path):
    """Load a Dino_f checkpoint in eval mode with its feature extractors on DEVICE.

    The checkpoint is loaded non-strictly with the shared PCA checkpoint, the
    feature extractor is initialised, and each optional backbone attribute
    that is not None is moved to DEVICE explicitly.
    """
    predictor = Dino_f.load_from_checkpoint(ckpt_path, strict=False, pca_ckpt=PCA_CKPT).to(DEVICE)
    predictor.eval()
    predictor._init_feature_extractor()
    # The backbones are moved one by one after initialisation (presumably
    # because they are created after the module-level `.to(DEVICE)` above).
    for backbone_name in ('dino_v2', 'eva2clip', 'sam'):
        backbone = getattr(predictor, backbone_name)
        if backbone is not None:
            setattr(predictor, backbone_name, backbone.to(DEVICE))
    return predictor


model_lang = _load_predictor(DINO_F_CKPT_LANG)
model_nolang = _load_predictor(DINO_F_CKPT_NOLANG)


In [None]:
# Build OpenDV dataloader with language features
# Only the validation split is configured here (no training language cache).
args = SimpleNamespace(
    data_path=OPENDV_ROOT,
    opendv_root=OPENDV_ROOT,
    opendv_lang_root=OPENDV_LANG_ROOT,
    opendv_use_lang_annos=True,
    opendv_lang_cache_train=None,
    opendv_lang_cache_val=LANG_CACHE,
    opendv_use_lang_features=True,
    opendv_lang_feat_name=OPENDV_LANG_FEAT_NAME,
    opendv_video_dir=None,
    sequence_length=SEQUENCE_LENGTH,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=4,
    subset='val',
    eval_mode=True,
    eval_midterm=False,
    eval_modality=None,
    # Mirror the predictor's setting; defaults to False if the attribute is absent.
    use_language_condition=getattr(model_lang, 'use_language_condition', False),
)
# Presumably copies attributes that are not set above from the checkpoint's
# args -- TODO confirm add_missing_args semantics.
args = add_missing_args(args, model_lang.args)
# These are set after the merge so they always follow the loaded predictor.
args.feature_extractor = model_lang.args.feature_extractor
args.dinov2_variant = getattr(model_lang.args, 'dinov2_variant', 'vitb14_reg')
args.return_rgb_path = True

data = OpenDV_VideoData(arguments=args, subset='val', batch_size=BATCH_SIZE)
loader = data.val_dataloader()
# First validation batch; it is reused by the cells below.
batch = next(iter(loader))


In [None]:
def parse_batch(batch):
    """Split a dataloader batch into its recognised components.

    Returns a 5-tuple ``(frames, gt_img, text_tokens, text_mask, rgb_paths)``.
    A batch that is not a list or tuple is returned as ``frames`` with every
    other slot set to None. Otherwise the first element is ``frames`` and
    the remaining elements are classified as follows:

    - a 4-D tensor whose second dimension is 3 -> ``gt_img``
    - a 3-D tensor -> ``text_tokens``
    - a 2-D tensor -> ``text_mask``
    - a non-empty list/tuple starting with a str or Path -> ``rgb_paths``
      (every entry converted to str)

    Anything else is ignored; when several elements match the same slot the
    last one wins.
    """
    if not isinstance(batch, (list, tuple)):
        return batch, None, None, None, None

    frames = batch[0]
    slots = {'gt_img': None, 'text_tokens': None, 'text_mask': None, 'rgb_paths': None}
    slot_by_rank = {3: 'text_tokens', 2: 'text_mask'}

    for extra in batch[1:]:
        if torch.is_tensor(extra):
            if extra.ndim == 4 and extra.shape[1] == 3:
                slots['gt_img'] = extra
            elif extra.ndim in slot_by_rank:
                slots[slot_by_rank[extra.ndim]] = extra
        elif isinstance(extra, (list, tuple)) and extra and isinstance(extra[0], (str, Path)):
            slots['rgb_paths'] = [str(entry) for entry in extra]

    return frames, slots['gt_img'], slots['text_tokens'], slots['text_mask'], slots['rgb_paths']

# Unpack the first batch and move the tensors the predictor consumes to DEVICE.
frames, gt_img, text_tokens, text_mask, rgb_paths = parse_batch(batch)
frames = frames.to(DEVICE)
text_tokens = text_tokens.to(DEVICE) if text_tokens is not None else None
text_mask = text_mask.to(DEVICE) if text_mask is not None else None

# Report shapes; language inputs print as None when the batch carried none.
for _label, _tensor in (('frames', frames), ('text_tokens', text_tokens), ('text_mask', text_mask)):
    print(_label, None if _tensor is None else tuple(_tensor.shape))


In [None]:
# Show GT image (if available) and last context frame
# Undo the feature extractor's input normalisation for the first sample's frames.
ctx = denormalize_images(frames[0], model_lang.args.feature_extractor).cpu()
last_ctx = ctx[-1]
plt.figure(figsize=(6, 3))
# Left panel: most recent context frame (channels moved last for imshow).
plt.subplot(1, 2, 1)
plt.imshow(last_ctx.permute(1, 2, 0))
plt.title('last context frame')
plt.axis('off')

# Right panel: ground-truth image, or a placeholder message when the batch has none.
plt.subplot(1, 2, 2)
if gt_img is not None:
    gt_rgb = denormalize_images(gt_img, model_lang.args.feature_extractor).cpu()
    plt.imshow(gt_rgb[0].permute(1, 2, 0))
    plt.title('gt image')
else:
    plt.text(0.1, 0.5, 'gt_img is None', fontsize=12)
    plt.title('gt image')
plt.axis('off')
plt.tight_layout()


In [None]:
def rollout_predictions(model, frames, unroll_steps, text_tokens=None, text_mask=None):
    """Autoregressively predict ``unroll_steps`` future feature maps.

    The context window is rolled after each step: the oldest frame is dropped
    and the newest prediction is appended, so every step conditions on the
    previous predictions.

    Args:
        model: predictor exposing ``preprocess``, ``get_mask_tokens``,
            ``forward`` and ``postprocess``.
        frames: input context frames, passed to ``model.preprocess``.
        unroll_steps: number of prediction steps; must be at least 1.
        text_tokens: optional language tokens. Ignored when the model has no
            truthy ``use_language_condition``. Tokens whose last dimension
            equals ``model.text_proj.in_features`` are projected first.
        text_mask: optional mask forwarded with ``text_tokens``.

    Returns:
        A CPU tensor with the per-step predictions stacked along dim 1.

    Raises:
        ValueError: if ``unroll_steps`` is below 1, or if the last dimension
            of ``text_tokens`` matches neither side of ``model.text_proj``.
    """
    if unroll_steps < 1:
        # Fail before the expensive preprocessing instead of inside torch.stack.
        raise ValueError(f'unroll_steps must be >= 1, got {unroll_steps}.')

    preds = []
    with torch.no_grad():
        x = model.preprocess(frames)
        if not getattr(model, 'use_language_condition', False):
            text_tokens = None
            text_mask = None
        if text_tokens is not None:
            text_tokens = text_tokens.to(x.device)
            if hasattr(model, 'text_proj'):
                in_dim = model.text_proj.in_features
                out_dim = model.text_proj.out_features
                if text_tokens.shape[-1] == in_dim:
                    text_tokens = text_tokens.to(dtype=model.text_proj.weight.dtype)
                    text_tokens = model.text_proj(text_tokens)
                elif text_tokens.shape[-1] != out_dim:
                    raise ValueError(
                        f'Unexpected text token dim {text_tokens.shape[-1]} (expected {in_dim} or {out_dim}).'
                    )
        if text_mask is not None:
            text_mask = text_mask.to(x.device)
        for _ in range(unroll_steps):
            masked_x, mask = model.get_mask_tokens(x, mode='full_mask', mask_frames=1)
            mask = mask.to(x.device)
            outputs = model.forward(
                x, masked_x, mask, text_tokens=text_tokens, text_mask=text_mask
            )
            # The predicted tokens are the second output whether or not the
            # model also returns attention maps.
            final_tokens = outputs[1]
            # Roll the context: drop oldest frame, append new prediction
            x[:, 0:-1] = x[:, 1:].clone()
            x[:, -1] = final_tokens[:, -1]
            pred_feats = model.postprocess(x)[:, -1]
            # Force a copy: on CPU, `.cpu()` returns the same storage, so a
            # prediction that is a view of `x` would be overwritten by the
            # in-place roll on the next step.
            preds.append(pred_feats.detach().to('cpu', copy=True))
    return torch.stack(preds, dim=1)


In [None]:
# Load decoder
# Either branch defines `decode_feats(feats_bhwc)`, which returns a detached
# CPU tensor. Decoding runs under `torch.no_grad()` so no autograd graph is
# built during visualisation.
if DECODER_TYPE == 'from_feats':
    decoder = FeatureRgbDecoder.load_from_checkpoint(DECODER_CKPT, strict=False).to(DEVICE)
    decoder.eval()

    def decode_feats(feats_bhwc):
        """Decode a feature batch with the feature-trained decoder."""
        with torch.no_grad():
            return decoder(feats_bhwc.to(DEVICE)).detach().cpu()
elif DECODER_TYPE == 'from_dino':
    # Uses the DPT head trained with online DINO features.
    decoder = DinoV2RGBDecoder.load_from_checkpoint(DECODER_CKPT, strict=False, lpips_weight=0).to(DEVICE)
    decoder.eval()
    dpt_head = decoder.decoder
    feat_dim = decoder.emb_dim

    def decode_feats(feats_bhwc):
        """Decode channel-concatenated layer features with the DPT head.

        The last dimension is split into chunks of ``feat_dim``. Two chunks
        are each duplicated to give the four inputs passed to the head; four
        chunks are used as they are. The output is resized to IMG_SIZE and
        passed through a sigmoid.

        Raises:
            ValueError: if the channel count is not a multiple of
                ``feat_dim`` or does not give 2 or 4 chunks.
        """
        feats_bhwc = feats_bhwc.to(DEVICE)
        b, h, w, c = feats_bhwc.shape
        if c % feat_dim != 0:
            raise ValueError(f'Feature dim mismatch: {c} not divisible by {feat_dim}')
        parts = torch.split(feats_bhwc, feat_dim, dim=-1)
        if len(parts) == 2:
            parts = [parts[0], parts[0], parts[1], parts[1]]
        elif len(parts) != 4:
            raise ValueError(f'Expected 2 or 4 feature chunks, got {len(parts)}')
        # Flatten the spatial grid so each chunk becomes a token sequence.
        feat_list = [p.reshape(b, h * w, feat_dim) for p in parts]
        with torch.no_grad():
            pred = dpt_head(feat_list, h, w)
            pred = F.interpolate(pred, size=IMG_SIZE, mode='bicubic', align_corners=False)
            pred = torch.sigmoid(pred)
        return pred.detach().cpu()
else:
    raise ValueError('DECODER_TYPE must be from_feats or from_dino')


In [None]:
# Long rollout
preds = rollout_predictions(model_lang, frames, ROLLOUT_STEPS, text_tokens=text_tokens, text_mask=text_mask)
# preds: [B, S, H, W, C]; decode each step and keep the first sample's image.
pred_rgb = [decode_feats(preds[:, step])[0] for step in range(preds.shape[1])]

# Visualize context frames + rollout
ctx = denormalize_images(frames[0], model_lang.args.feature_extractor).cpu()  # [T,3,H,W]
num_ctx = ctx.shape[0]
fig_cols = max(num_ctx + len(pred_rgb), 1)
plt.figure(figsize=(3 * fig_cols, 3))
# One row: context frames first, predicted frames after them.
_panels = [(f'ctx {i}', ctx[i]) for i in range(num_ctx)]
_panels += [(f'pred {j+1}', image) for j, image in enumerate(pred_rgb)]
for _slot, (_title, _image) in enumerate(_panels, start=1):
    plt.subplot(1, fig_cols, _slot)
    plt.imshow(_image.permute(1, 2, 0))
    plt.title(_title)
    plt.axis('off')
plt.tight_layout()


In [None]:
# Compare rollout WITH vs WITHOUT language condition
preds_lang = rollout_predictions(model_lang, frames, ROLLOUT_STEPS, text_tokens=text_tokens, text_mask=text_mask)
preds_nolang = rollout_predictions(model_nolang, frames, ROLLOUT_STEPS, text_tokens=None, text_mask=None)

# Decode only the final rollout step of the first sample for each model.
rgb_lang = decode_feats(preds_lang[:, -1])[0]
rgb_nolang = decode_feats(preds_nolang[:, -1])[0]

plt.figure(figsize=(8, 4))
_side_by_side = (
    ('with language (last step)', rgb_lang),
    ('no language (last step)', rgb_nolang),
)
for _slot, (_title, _image) in enumerate(_side_by_side, start=1):
    plt.subplot(1, 2, _slot)
    plt.imshow(_image.permute(1, 2, 0))
    plt.title(_title)
    plt.axis('off')
plt.tight_layout()


In [None]:
# Quality metrics per step + average (if gt_img is available)
def _psnr(pred, gt):
    """Return PSNR in dB, treating 1.0 as the peak signal value."""
    # Clamp the MSE so identical images give a large finite value, not inf.
    mse = torch.mean((pred - gt) ** 2).clamp_min(1e-12)
    return (10.0 * torch.log10(1.0 / mse)).item()

def _try_ssim(pred, gt):
    """Return SSIM for one image pair, or None when it cannot be computed.

    None is returned when torchmetrics is not importable. A failure inside
    the SSIM computation itself also returns None, but it is printed so it
    is not mistaken for a missing dependency.
    """
    try:
        from torchmetrics.functional import structural_similarity_index_measure as ssim
    except ImportError:
        return None
    try:
        return ssim(pred.unsqueeze(0), gt.unsqueeze(0), data_range=1.0).item()
    except Exception as err:  # best-effort metric: report the failure and carry on
        print(f'SSIM computation failed: {err}')
        return None

if gt_img is None:
    print('gt_img is None; skip metrics')
else:
    # Every rollout step is scored against the same GT image (first sample).
    gt_rgb = denormalize_images(gt_img, model_lang.args.feature_extractor).cpu()[0]
    psnr_lang_list, psnr_nolang_list = [], []
    ssim_lang_list, ssim_nolang_list = [], []

    for t in range(ROLLOUT_STEPS):
        pred_lang_t = decode_feats(preds_lang[:, t])[0].clamp(0, 1).cpu()
        pred_nolang_t = decode_feats(preds_nolang[:, t])[0].clamp(0, 1).cpu()

        psnr_lang = _psnr(pred_lang_t, gt_rgb)
        psnr_nolang = _psnr(pred_nolang_t, gt_rgb)
        psnr_lang_list.append(psnr_lang)
        psnr_nolang_list.append(psnr_nolang)

        ssim_lang = _try_ssim(pred_lang_t, gt_rgb)
        ssim_nolang = _try_ssim(pred_nolang_t, gt_rgb)
        # Record SSIM only when both models have a value, so the averages
        # cover the same steps.
        have_ssim = ssim_lang is not None and ssim_nolang is not None
        if have_ssim:
            ssim_lang_list.append(ssim_lang)
            ssim_nolang_list.append(ssim_nolang)

        print(f'step {t+1:02d} | PSNR lang {psnr_lang:.2f} dB | PSNR nolang {psnr_nolang:.2f} dB')
        if have_ssim:
            print(f'          | SSIM lang {ssim_lang:.4f}    | SSIM nolang {ssim_nolang:.4f}')

    print('--- average over steps ---')
    print(f'PSNR lang   : {sum(psnr_lang_list)/len(psnr_lang_list):.2f} dB')
    print(f'PSNR nolang : {sum(psnr_nolang_list)/len(psnr_nolang_list):.2f} dB')
    if ssim_lang_list and ssim_nolang_list:
        print(f'SSIM lang   : {sum(ssim_lang_list)/len(ssim_lang_list):.4f}')
        print(f'SSIM nolang : {sum(ssim_nolang_list)/len(ssim_nolang_list):.4f}')
    else:
        print('SSIM not available (torchmetrics not installed or SSIM computation failed)')


In [None]:
# Compare full rollout (all steps) between lang and no-lang models
preds_lang = rollout_predictions(model_lang, frames, ROLLOUT_STEPS, text_tokens=text_tokens, text_mask=text_mask)
preds_nolang = rollout_predictions(model_nolang, frames, ROLLOUT_STEPS, text_tokens=None, text_mask=None)

# Decode every step along dim 1 and keep the first sample's image for each.
rgb_lang = [decode_feats(step_feats)[0] for step_feats in preds_lang.unbind(dim=1)]
rgb_nolang = [decode_feats(step_feats)[0] for step_feats in preds_nolang.unbind(dim=1)]

cols = ROLLOUT_STEPS
plt.figure(figsize=(3 * cols, 6))
# Top row: language-conditioned model; bottom row: no-language model.
for _row, (_tag, _images) in enumerate((('lang', rgb_lang), ('nolang', rgb_nolang))):
    for _step in range(ROLLOUT_STEPS):
        plt.subplot(2, cols, _row * cols + _step + 1)
        plt.imshow(_images[_step].permute(1, 2, 0))
        plt.title(f'{_tag} t{_step+1}')
        plt.axis('off')

plt.tight_layout()


In [None]:
# Per-step comparison with GT (lang + nolang)
def _hide_axes_keep_labels(ax):
    """Hide ticks and spines while leaving the axis labels visible.

    `ax.axis('off')` would also suppress the y-label used as the row title.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


if gt_img is None:
    print('gt_img is None; skip GT comparison')
else:
    gt_rgb = denormalize_images(gt_img, model_lang.args.feature_extractor).cpu()[0]
    cols = ROLLOUT_STEPS
    plt.figure(figsize=(3 * cols, 9))
    # Images are decoded straight from the rollout tensors: `rgb_lang` and
    # `rgb_nolang` hold a single image in one earlier cell and a list in
    # another, so indexing them by step is unreliable.
    rows = (
        ('GT', lambda step: gt_rgb),  # the same GT image is shown for every step
        ('lang', lambda step: decode_feats(preds_lang[:, step])[0]),
        ('nolang', lambda step: decode_feats(preds_nolang[:, step])[0]),
    )
    for t in range(ROLLOUT_STEPS):
        for row_idx, (row_label, image_for_step) in enumerate(rows):
            ax = plt.subplot(3, cols, row_idx * cols + t + 1)
            ax.imshow(image_for_step(t).permute(1, 2, 0))
            if t == 0:
                ax.set_ylabel(row_label, rotation=0, labelpad=40, fontsize=10)
            if row_idx == 0:
                ax.set_title(f't{t+1}')
            _hide_axes_keep_labels(ax)

    plt.tight_layout()


In [None]:
# Language condition swap
# Take another batch and swap text tokens to see effect
# A fresh iterator starts from the beginning, so with a non-shuffled loader
# `next(iter(loader))` would return the same first batch used everywhere else
# in this notebook. Skip that one to get different language tokens.
_loader_iter = iter(loader)
next(_loader_iter, None)
batch_b = next(_loader_iter, None)
if batch_b is None:
    raise RuntimeError('Language swap needs at least two batches in the loader.')
frames_b, gt_b, text_b, mask_b, _ = parse_batch(batch_b)
frames_b = frames_b.to(DEVICE)
if text_b is not None:
    text_b = text_b.to(DEVICE)
if mask_b is not None:
    mask_b = mask_b.to(DEVICE)

# Same context frames in both runs; only the language condition differs.
preds_a = rollout_predictions(model_lang, frames, ROLLOUT_STEPS, text_tokens=text_tokens, text_mask=text_mask)
preds_swap = rollout_predictions(model_lang, frames, ROLLOUT_STEPS, text_tokens=text_b, text_mask=mask_b)

rgb_a = decode_feats(preds_a[:, -1])[0]
rgb_swap = decode_feats(preds_swap[:, -1])[0]

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.imshow(rgb_a.permute(1, 2, 0))
plt.title('lang A (last step)')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(rgb_swap.permute(1, 2, 0))
plt.title('lang B (last step)')
plt.axis('off')
plt.tight_layout()


In [None]:
# Prompt-based rollout (online CLIP embeddings)
from transformers import CLIPTokenizer, CLIPTextModel
import os.path as osp

def _resolve_clip_sources(model_id, cache_dir, local_only):
    if osp.isdir(model_id):
        return model_id, model_id
    offline = local_only or os.environ.get('HF_HUB_OFFLINE') or os.environ.get('TRANSFORMERS_OFFLINE')
    if not offline:
        return model_id, model_id
    cache_root = cache_dir or os.environ.get('HF_HOME')
    if cache_root and osp.basename(cache_root.rstrip('/')) != 'hub':
        if osp.isdir(osp.join(cache_root, 'hub')):
            cache_root = osp.join(cache_root, 'hub')
    if not cache_root:
        return model_id, model_id
    model_dir = osp.join(cache_root, 'models--' + model_id.replace('/', '--'))
    snapshots_dir = osp.join(model_dir, 'snapshots')
    if not osp.isdir(snapshots_dir):
        return model_id, model_id
    snapshot_paths = [
        osp.join(snapshots_dir, name)
        for name in os.listdir(snapshots_dir)
        if osp.isdir(osp.join(snapshots_dir, name))
    ]
    tokenizer_path = None
    text_model_path = None
    for snap in snapshot_paths:
        if tokenizer_path is None and (
            osp.isfile(osp.join(snap, 'tokenizer.json'))
            or osp.isfile(osp.join(snap, 'vocab.json'))
        ):
            tokenizer_path = snap
        if text_model_path is None and osp.isfile(osp.join(snap, 'model.safetensors')):
            text_model_path = snap
    return tokenizer_path or model_id, text_model_path or model_id

# CLIP settings come from the predictor's args when present; otherwise the
# defaults below are used.
clip_model_name = getattr(model_lang.args, 'clip_model_name', 'openai/clip-vit-base-patch32')
clip_cache_dir = CLIP_CACHE_DIR
clip_local_only = CLIP_LOCAL_ONLY
clip_max_length = getattr(model_lang.args, 'clip_max_length', 77)

tok_src, txt_src = _resolve_clip_sources(clip_model_name, clip_cache_dir, clip_local_only)
tokenizer = CLIPTokenizer.from_pretrained(tok_src, cache_dir=clip_cache_dir, local_files_only=clip_local_only)
text_model = CLIPTextModel.from_pretrained(txt_src, use_safetensors=True, cache_dir=clip_cache_dir, local_files_only=clip_local_only)
text_model.eval().to(DEVICE)

# Driving commands that can be used as prompts.
plain_caption_dict = {
    0: 'Go straight.',
    1: 'Pass the intersection.',
    2: 'Turn left.',
    3: 'Turn right.',
    4: 'Change to the left lane.',
    5: 'Change to the right lane.',
    6: 'Go to the left lane branch.',
    7: 'Go to the right lane branch.',
    8: 'Pass the crosswalk.',
    9: 'Pass the railroad.',
    10: 'Merge.',
    11: 'Make a U-turn.',
    12: 'Stop.',
    13: 'Deviate.',
}

SELECTED_CMDS = [0, 2, 11]  # choose from keys above
PROMPTS = [plain_caption_dict[i] for i in SELECTED_CMDS]

# Encode the prompts; the per-token hidden states serve as language tokens.
tokens = tokenizer(PROMPTS, padding='max_length', truncation=True, max_length=clip_max_length, return_tensors='pt')
input_ids = tokens['input_ids'].to(DEVICE)
attention_mask = tokens['attention_mask'].to(DEVICE)
attention_mask_bool = attention_mask.to(dtype=torch.bool)
with torch.no_grad():
    outputs = text_model(input_ids=input_ids, attention_mask=attention_mask)
text_tokens_prompt = outputs.last_hidden_state

# Repeat the first sample's frames once per prompt. Slicing with [:1] keeps
# the frame and prompt batch sizes equal even when BATCH_SIZE > 1.
frames_single = frames[:1]
frames_prompt = frames_single.repeat(len(PROMPTS), 1, 1, 1, 1)
preds_prompt = rollout_predictions(
    model_lang,
    frames_prompt,
    ROLLOUT_STEPS,
    text_tokens=text_tokens_prompt,
    text_mask=attention_mask_bool,
)


def _hide_axes_keep_labels(ax):
    """Hide ticks and spines while leaving the axis labels visible.

    `ax.axis('off')` would also suppress the y-label that shows the prompt.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


# Visualize multi-step outputs per prompt
rows = len(PROMPTS)
cols = ROLLOUT_STEPS
plt.figure(figsize=(3 * cols, 3 * rows))
for i, prompt in enumerate(PROMPTS):
    for t in range(ROLLOUT_STEPS):
        rgb = decode_feats(preds_prompt[i:i+1, t])[0]
        ax = plt.subplot(rows, cols, i * cols + t + 1)
        ax.imshow(rgb.permute(1, 2, 0))
        if t == 0:
            ax.set_ylabel(prompt, rotation=0, labelpad=40, fontsize=10)
        ax.set_title(f't{t+1}')
        _hide_axes_keep_labels(ax)
plt.tight_layout()

# Quick feature-diff sanity check between first two prompts
if len(PROMPTS) >= 2:
    preds_a = rollout_predictions(model_lang, frames_single, 1, text_tokens=text_tokens_prompt[:1], text_mask=attention_mask_bool[:1])
    preds_b = rollout_predictions(model_lang, frames_single, 1, text_tokens=text_tokens_prompt[1:2], text_mask=attention_mask_bool[1:2])
    print('feat diff (step1):', (preds_a - preds_b).abs().mean().item())
else:
    print('Need >=2 prompts for feature-diff check')


## Notes / Troubleshooting
- If you see GitHub access errors for DINOv2, set `DINO_REPO` to a local cache path.
- If `lpips` is not installed, pass `lpips_weight=0` when loading the decoder.
- If `text_tokens` is None, ensure `opendv_use_lang_annos` and `opendv_use_lang_features` are enabled and your language features are available.