# Visualize Drifting Model (Single Checkpoint)

This notebook visualizes a **single drifting model** (no baseline comparison):
- one-step predictions on the same input with multiple noise samples
- diversity maps (feature/RGB std)
- drift vector norm heatmap on one sample (if GT next frame is available)


In [None]:
import os
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from src.data import OpenDV_VideoData
from src.dino_f import Dino_f
from src.drifting_utils import compute_V, build_token_sample_ids
from save_predicted_dino_features import add_missing_args, denormalize_images

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('DEVICE =', DEVICE)

In [None]:
# ===== Config =====
# All user-editable settings for the notebook live in this cell.

# Checkpoint of the drifting model; loaded below via Dino_f.load_from_checkpoint.
DINO_F_CKPT = '/cpfs/pengyu/DINO-Foresight/logs/dino_foresight_lowres_opendv_drift_noallgather/20260215_011651/checkpoints/epoch=9-step=535547.ckpt'

# OpenDV dataset locations; forwarded unchanged to the OpenDV_VideoData arguments.
OPENDV_ROOT = '/pfs/pengyu/OpenDV-YouTube'
OPENDV_LANG_ROOT = '/pfs/pengyu/OpenDV-YouTube-Language'
LANG_CACHE_VAL = '/pfs/pengyu/OpenDV-YouTube-Language/mini_val_cache.json'
# Filename template with {start}/{end} placeholders; passed through as
# opendv_lang_feat_name (presumably filled in by the dataset -- TODO confirm).
OPENDV_LANG_FEAT_NAME = 'lang_clip_{start}_{end}.pt'

BATCH_SIZE = 1
NUM_WORKERS = 2
NUM_NOISE_SAMPLES = 6  # number of stochastic one-step predictions on the same input
ROLLOUT_STEPS = 6      # optional long rollout visualization
OPENDV_RETURN_FUTURE_GT = True
# The rollout cell compares against [gt next frame] + future GT frames, so only
# ROLLOUT_STEPS - 1 extra future frames are requested (never a negative count).
OPENDV_EVAL_FUTURE_STEPS = max(0, ROLLOUT_STEPS - 1)

# Optional decoder for RGB visualization (set to None to visualize feature maps only)
DECODER_CKPT = '/cpfs/pengyu/DINO-Foresight/dino-foresight/j5ludt8t/checkpoints/epoch=20-step=189609.ckpt'
DECODER_TYPE = 'from_dino'  # 'from_dino' or 'from_feats'

# Optional: local DINOv2 hub cache
# setdefault leaves any DINO_REPO value already present in the environment untouched.
os.environ.setdefault('DINO_REPO', '/cpfs/pengyu/.cache/torch/hub/facebookresearch_dinov2_main')

# Fail fast, before any model loading, if the drifting checkpoint path is wrong.
assert Path(DINO_F_CKPT).exists(), f'Missing checkpoint: {DINO_F_CKPT}'


In [None]:
# Load drifting model
# The checkpoint is restored onto the CPU first and only then moved to DEVICE.
model = Dino_f.load_from_checkpoint(DINO_F_CKPT, strict=False, map_location='cpu')
model = model.to(DEVICE)
model.eval()
model._init_feature_extractor()

# Any backbone attribute that is set after _init_feature_extractor() is moved
# to DEVICE explicitly; attributes left as None are skipped.
for _backbone_name in ('dino_v2', 'eva2clip', 'sam'):
    _backbone = getattr(model, _backbone_name)
    if _backbone is not None:
        setattr(model, _backbone_name, _backbone.to(DEVICE))

# Echo the drifting-related hyper-parameters stored in the checkpoint
# (each falls back to the given default when the checkpoint has no such field).
_optional_fields = (
    ('use_drifting_loss', False),
    ('drift_temperatures', None),
    ('drift_step_size', None),
    ('drift_anchor_weight', None),
)
for _field, _fallback in _optional_fields:
    print(f'{_field}:', getattr(model.args, _field, _fallback))
print('feature_extractor:', model.args.feature_extractor)
print('img_size:', model.args.img_size)

In [None]:
# Build OpenDV val loader aligned to this checkpoint
ckpt_args = model.args
use_lang_cond = bool(getattr(ckpt_args, 'use_language_condition', False))
use_precomputed_text = bool(getattr(ckpt_args, 'use_precomputed_text', False))

# Language inputs come either as precomputed features or as raw language,
# never both; which one is requested follows the checkpoint's own flags.
loader_cfg = {
    'data_path': OPENDV_ROOT,
    'opendv_root': OPENDV_ROOT,
    'opendv_lang_root': OPENDV_LANG_ROOT,
    'opendv_use_lang_annos': bool(OPENDV_LANG_ROOT),
    'opendv_lang_cache_train': None,
    'opendv_lang_cache_val': LANG_CACHE_VAL,
    'opendv_use_lang_features': bool(use_lang_cond and use_precomputed_text),
    'opendv_return_language': bool(use_lang_cond and (not use_precomputed_text)),
    'opendv_lang_feat_name': OPENDV_LANG_FEAT_NAME,
    'opendv_video_dir': None,
    'opendv_max_clips': None,
    'sequence_length': getattr(ckpt_args, 'sequence_length', 5),
    'img_size': tuple(getattr(ckpt_args, 'img_size', (224, 448))),
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_workers_val': NUM_WORKERS,
    'eval_mode': True,
    'eval_midterm': False,
    'eval_modality': None,
    'use_language_condition': use_lang_cond,
    'use_precomputed_text': use_precomputed_text,
    'opendv_return_future_gt': OPENDV_RETURN_FUTURE_GT,
    'opendv_eval_future_steps': OPENDV_EVAL_FUTURE_STEPS,
}
args = add_missing_args(SimpleNamespace(**loader_cfg), ckpt_args)

# These fields are set after add_missing_args, so they always take the values below.
args.feature_extractor = ckpt_args.feature_extractor
args.dinov2_variant = getattr(ckpt_args, 'dinov2_variant', 'vitb14_reg')
args.return_rgb_path = True

# Pull exactly one validation batch; every later cell works on it.
data = OpenDV_VideoData(arguments=args, subset='val', batch_size=BATCH_SIZE)
loader = data.val_dataloader()
batch = next(iter(loader))
print('Loaded one val batch')

In [None]:
def parse_batch(batch):
    """Split a loader batch into its named parts.

    A batch that is not a list/tuple is returned as the frames with every
    other slot set to None. Otherwise the first element is the frames and the
    remaining elements are recognised purely by type and shape:

    - 4-D tensor whose dim 1 has size 3 -> gt_img
    - 3-D tensor                        -> text_tokens
    - 2-D tensor                        -> text_mask
    - 5-D tensor whose dim 2 has size 3 -> future_gt
    - non-empty list/tuple whose first entry is a str or Path -> rgb_paths
      (every entry converted to str)

    Anything else is ignored; if two elements match the same slot, the later
    one wins.

    Returns:
        (frames, gt_img, text_tokens, text_mask, future_gt, rgb_paths)
    """
    if not isinstance(batch, (list, tuple)):
        return batch, None, None, None, None, None

    frames = batch[0]
    gt_img = text_tokens = text_mask = future_gt = rgb_paths = None

    for extra in batch[1:]:
        if not torch.is_tensor(extra):
            # The only non-tensor payload recognised is a sequence of paths.
            is_path_list = (
                isinstance(extra, (list, tuple))
                and len(extra) > 0
                and isinstance(extra[0], (str, Path))
            )
            if is_path_list:
                rgb_paths = [str(entry) for entry in extra]
            continue

        rank = extra.ndim
        if rank == 4:
            if extra.shape[1] == 3:
                gt_img = extra
        elif rank == 3:
            text_tokens = extra
        elif rank == 2:
            text_mask = extra
        elif rank == 5 and extra.shape[2] == 3:
            future_gt = extra

    return frames, gt_img, text_tokens, text_mask, future_gt, rgb_paths


def prepare_text_tokens(model, text_tokens, text_mask, device):
    """Move text conditioning to `device` and project it if the model needs that.

    Returns (None, None) when the model has no truthy `use_language_condition`
    attribute or when `text_tokens` is None. Otherwise the tokens (and the
    mask, if given) are moved to `device`. If the model has a `text_proj`
    layer, tokens whose last dim equals its `in_features` are cast to the
    layer's weight dtype and projected; tokens whose last dim already equals
    `out_features` are passed through untouched.

    Raises:
        ValueError: the model has `text_proj` and the last token dim matches
            neither its input nor its output size.
    """
    language_enabled = getattr(model, 'use_language_condition', False)
    if not language_enabled or text_tokens is None:
        return None, None

    tokens = text_tokens.to(device)
    mask = None if text_mask is None else text_mask.to(device)

    if not hasattr(model, 'text_proj'):
        return tokens, mask

    proj = model.text_proj
    in_dim, out_dim = proj.in_features, proj.out_features
    token_dim = tokens.shape[-1]
    if token_dim == in_dim:
        tokens = proj(tokens.to(dtype=proj.weight.dtype))
    elif token_dim != out_dim:
        raise ValueError(f'Unexpected text token dim {token_dim} (expected {in_dim} or {out_dim})')

    return tokens, mask


@torch.no_grad()
def one_step_pred_feats(model, frames, text_tokens=None, text_mask=None):
    """Run one masked forward pass and return the prediction for the last frame.

    The frames go through `model.preprocess`, a 'full_mask' over one frame is
    requested from `model.get_mask_tokens`, and `model.forward` is called with
    the (optional) text conditioning. The prediction is mapped back with
    `model.postprocess` before the last time step is sliced out.

    Returns:
        (pred_last, x, tt, tm): the last-step slice of the post-processed
        prediction ([B,Hf,Wf,C] per the original author's note), the
        preprocessed input, and the prepared text tokens / mask.
    """
    latents = model.preprocess(frames)
    cond_tokens, cond_mask = prepare_text_tokens(model, text_tokens, text_mask, latents.device)

    masked_latents, token_mask = model.get_mask_tokens(latents, mode='full_mask', mask_frames=1)
    token_mask = token_mask.to(latents.device)

    # With vis_attn enabled, forward returns a third value that is not needed here.
    returns_attn = model.args.vis_attn
    outputs = model.forward(
        latents, masked_latents, token_mask, text_tokens=cond_tokens, text_mask=cond_mask
    )
    if returns_attn:
        _, pred_latents, _ = outputs
    else:
        _, pred_latents = outputs

    pred_feats = model.postprocess(pred_latents)
    return pred_feats[:, -1], latents, cond_tokens, cond_mask  # [B,Hf,Wf,C], plus internals


@torch.no_grad()
def rollout_pred_feats(model, frames, steps, text_tokens=None, text_mask=None):
    """Autoregressively predict `steps` future frames in feature space.

    Each iteration runs one masked forward pass, shifts the context window one
    frame to the left in place, and writes the new prediction into the last
    slot. The window is kept in the space produced by `model.preprocess`; only
    the exported copy of each prediction goes through `model.postprocess`.

    Returns:
        CPU tensor with the per-step last-frame predictions stacked along
        dim 1. `steps` must be at least 1, because `torch.stack` rejects an
        empty list.
    """
    window = model.preprocess(frames)
    cond_tokens, cond_mask = prepare_text_tokens(model, text_tokens, text_mask, window.device)

    exported = []
    for _step in range(steps):
        masked_window, token_mask = model.get_mask_tokens(window, mode='full_mask', mask_frames=1)
        token_mask = token_mask.to(window.device)

        # With vis_attn enabled, forward returns a third value that is not needed here.
        returns_attn = model.args.vis_attn
        outputs = model.forward(
            window, masked_window, token_mask, text_tokens=cond_tokens, text_mask=cond_mask
        )
        if returns_attn:
            _, pred_latent, _ = outputs
        else:
            _, pred_latent = outputs

        # Recurrence must stay in preprocess(latent) space: slide the window
        # (clone avoids overlapping read/write) and append the new prediction.
        window[:, :-1] = window[:, 1:].clone()
        window[:, -1] = pred_latent[:, -1]

        # Export in original feature space for decoder/visualization.
        pred_post = model.postprocess(pred_latent)
        exported.append(pred_post[:, -1].detach().cpu())

    return torch.stack(exported, dim=1)


@torch.no_grad()
def gt_next_feats(model, gt_img, h, w):
    """Extract features for `gt_img` and lay them out as [B, h, w, C].

    Runs `model.extract_features` without gradient tracking and reshapes the
    result so that dim 0 is kept and the rest is split into an (h, w) grid
    with the remaining size as the last dim.
    """
    flat_feats = model.extract_features(gt_img)
    batch_size = flat_feats.shape[0]
    return flat_feats.reshape(batch_size, h, w, -1)

In [None]:
# Optional decoder
# decode_feats stays None when no decoder checkpoint is configured; later cells
# use that to fall back to feature-space visualizations.
decode_feats = None
if DECODER_CKPT is not None:
    if DECODER_TYPE == 'from_feats':
        from train_rgb_decoder_from_feats import FeatureRgbDecoder
        decoder = FeatureRgbDecoder.load_from_checkpoint(DECODER_CKPT, strict=False).to(DEVICE)
        decoder.eval()

        # Inference only: no_grad avoids building an autograd graph per call,
        # matching the other prediction helpers in this notebook.
        @torch.no_grad()
        def decode_feats(feats_bhwc):
            """Decode a [B,H,W,C] feature tensor to RGB and return it on the CPU."""
            return decoder(feats_bhwc.to(DEVICE)).detach().cpu()
    elif DECODER_TYPE == 'from_dino':
        from train_rgb_decoder import DinoV2RGBDecoder
        decoder = DinoV2RGBDecoder.load_from_checkpoint(DECODER_CKPT, strict=False, lpips_weight=0).to(DEVICE)
        decoder.eval()
        dpt_head = decoder.decoder
        feat_dim = decoder.emb_dim

        # Inference only: no_grad avoids building an autograd graph per call,
        # matching the other prediction helpers in this notebook.
        @torch.no_grad()
        def decode_feats(feats_bhwc):
            """Decode a [B,H,W,C] feature tensor to RGB in [0, 1] on the CPU.

            The channel dim is split into chunks of `feat_dim`. The head is fed
            four token lists, so a 2-chunk input has each chunk duplicated.

            Raises:
                ValueError: C is not a multiple of `feat_dim`, or the split
                    yields a chunk count other than 2 or 4.
            """
            feats_bhwc = feats_bhwc.to(DEVICE)
            b, h, w, c = feats_bhwc.shape
            if c % feat_dim != 0:
                raise ValueError(f'Feature dim mismatch: {c} not divisible by {feat_dim}')
            parts = torch.split(feats_bhwc, feat_dim, dim=-1)
            if len(parts) == 2:
                parts = [parts[0], parts[0], parts[1], parts[1]]
            elif len(parts) != 4:
                raise ValueError(f'Expected 2 or 4 feature chunks, got {len(parts)}')
            # Flatten the spatial grid into a token axis for the head.
            feat_list = [p.reshape(b, h * w, feat_dim) for p in parts]
            pred = dpt_head(feat_list, h, w)
            # Resize to the model's image size before squashing to [0, 1].
            pred = F.interpolate(pred, size=tuple(model.args.img_size), mode='bicubic', align_corners=False)
            pred = torch.sigmoid(pred)
            return pred.detach().cpu()
    else:
        raise ValueError("DECODER_TYPE must be 'from_dino' or 'from_feats'")

print('decoder enabled:', decode_feats is not None)

In [None]:
# Run multiple one-step stochastic samples on the same input
frames, gt_img, text_tokens, text_mask, future_gt, rgb_paths = parse_batch(batch)
frames = frames.to(DEVICE)
gt_img = None if gt_img is None else gt_img.to(DEVICE)
if future_gt is not None:
    future_gt = future_gt.to(DEVICE)
    print('future_gt shape:', tuple(future_gt.shape))

# Call the model NUM_NOISE_SAMPLES times on identical inputs; predictions are
# kept on the CPU, the internals of the last call stay available afterwards.
x_internal = None
pred_list = []
for k in range(NUM_NOISE_SAMPLES):
    sample_out = one_step_pred_feats(model, frames, text_tokens=text_tokens, text_mask=text_mask)
    pred_k, x_internal, tt_internal, tm_internal = sample_out
    pred_list.append(pred_k.detach().cpu())

pred_stack = torch.stack(pred_list, dim=1)  # [B, K, Hf, Wf, C]
print('pred_stack shape:', tuple(pred_stack.shape))

# Diversity in feature space: std over the K samples of batch item 0,
# collapsed to one value per location with a norm over channels.
per_channel_std = pred_stack[0].std(dim=0)
feat_std_map = per_channel_std.norm(dim=-1)  # [Hf, Wf]
print('feature std mean:', feat_std_map.mean().item())

In [None]:
# Visualize: context/gt + sampled predictions + diversity map
extractor_name = model.args.feature_extractor
ctx = denormalize_images(frames[0].detach().cpu(), extractor_name).cpu()
ctx_last = ctx[-1]

if decode_feats is None:
    # No decoder: show feature-space maps only
    plt.figure(figsize=(10, 4))

    plt.subplot(1, 2, 1)
    plt.imshow(ctx_last.permute(1, 2, 0))
    plt.title('context last')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(feat_std_map, cmap='magma')
    plt.title('feature std map across noise samples')
    plt.axis('off')

    plt.tight_layout()
    plt.show()
else:
    # Decode every noise sample of batch item 0 to RGB.
    n_samples = pred_stack.shape[1]
    pred_rgbs = [decode_feats(pred_stack[:, k])[0].clamp(0, 1) for k in range(n_samples)]

    # One row of panels: last context frame, GT next frame (if any), then the samples.
    panels = [('context last', ctx_last)]
    if gt_img is not None:
        gt_rgb = denormalize_images(gt_img.detach().cpu(), extractor_name).cpu()[0]
        panels.append(('gt next', gt_rgb))
    panels += [(f'sample {idx}', rgb) for idx, rgb in enumerate(pred_rgbs, start=1)]

    n_panels = len(panels)
    plt.figure(figsize=(3 * n_panels, 3.5))
    for col, (title, image_chw) in enumerate(panels, start=1):
        ax = plt.subplot(1, n_panels, col)
        # imshow expects channels last.
        ax.imshow(image_chw.permute(1, 2, 0))
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

    # Pixel-wise std over the samples, averaged over the colour channels.
    rgb_stack = torch.stack(pred_rgbs, dim=0)  # [K,3,H,W]
    rgb_std = rgb_stack.std(dim=0).mean(dim=0)  # [H,W]
    plt.figure(figsize=(4, 4))
    plt.imshow(rgb_std, cmap='magma')
    plt.title('RGB std map across noise samples')
    plt.axis('off')
    plt.show()

In [None]:
# Image reconstruction metrics (one-step, across noise samples)
# Needs both a decoder (to get RGB predictions) and the GT next frame.
if decode_feats is None:
    print('decoder is None: cannot compute image reconstruction metrics')
elif gt_img is None:
    print('gt_img is None: cannot compute image reconstruction metrics')
else:
    gt_rgb = denormalize_images(gt_img.detach().cpu(), model.args.feature_extractor).cpu()[0].clamp(0, 1)

    # Decode every noise sample of batch item 0 and compare it to the same GT image.
    recon_list = [
        decode_feats(pred_stack[:, i])[0].clamp(0, 1).cpu()
        for i in range(pred_stack.shape[1])
    ]
    recon_stack = torch.stack(recon_list, dim=0)  # [K,3,H,W]

    diff = recon_stack - gt_rgb.unsqueeze(0)
    mse = (diff ** 2).mean(dim=(1, 2, 3))
    l1 = diff.abs().mean(dim=(1, 2, 3))
    # Images are in [0, 1], so the peak value is 1; the clamp avoids log10(inf).
    psnr = 10.0 * torch.log10(1.0 / mse.clamp_min(1e-12))

    # SSIM is best-effort. A missing torchmetrics install and a failure while
    # computing are reported separately, so the printed reason is accurate.
    ssim_scores = None
    ssim_skip_reason = None
    try:
        from torchmetrics.functional import structural_similarity_index_measure as ssim
    except ImportError:
        ssim_skip_reason = 'torchmetrics missing'
    else:
        try:
            ssim_scores = torch.tensor([
                ssim(recon_stack[i:i+1], gt_rgb.unsqueeze(0), data_range=1.0).item()
                for i in range(recon_stack.shape[0])
            ])
        except Exception as exc:
            # Keep the other metrics; just say why SSIM is not shown.
            ssim_skip_reason = f'{type(exc).__name__}: {exc}'

    print('Per-sample one-step reconstruction metrics:')
    for i in range(recon_stack.shape[0]):
        msg = f'sample {i+1:02d} | MSE={mse[i].item():.6f} | L1={l1[i].item():.6f} | PSNR={psnr[i].item():.3f} dB'
        if ssim_scores is not None:
            msg += f' | SSIM={ssim_scores[i].item():.4f}'
        print(msg)

    # Population std (unbiased=False) so a single sample yields 0 rather than NaN.
    print('--- summary over noise samples ---')
    print(f'MSE  mean={mse.mean().item():.6f} std={mse.std(unbiased=False).item():.6f}')
    print(f'L1   mean={l1.mean().item():.6f} std={l1.std(unbiased=False).item():.6f}')
    print(f'PSNR mean={psnr.mean().item():.3f} std={psnr.std(unbiased=False).item():.3f} dB')
    if ssim_scores is not None:
        print(f'SSIM mean={ssim_scores.mean().item():.4f} std={ssim_scores.std(unbiased=False).item():.4f}')
    else:
        print(f'SSIM unavailable ({ssim_skip_reason})')


In [None]:
# Drift vector norm heatmap on one sample (requires gt_img)
if gt_img is not None:
    # use the first sampled pred as x
    x_pred = pred_stack[:, 0].to(DEVICE)  # [B,Hf,Wf,C]
    B, Hf, Wf, C = x_pred.shape
    n_tok = Hf * Wf

    # gt feature, laid out on the same Hf x Wf grid as the prediction
    gt_f = gt_next_feats(model, gt_img, Hf, Wf).to(DEVICE)

    # Flatten both to one row per token, in float32.
    x_tokens = x_pred.reshape(B * n_tok, C).to(torch.float32)
    y_tokens = gt_f.reshape(B * n_tok, C).to(torch.float32)

    # The same ids are used for x, the positives and the negatives; the
    # predicted tokens double as the negative set.
    sample_ids = build_token_sample_ids(B, n_tok, x_tokens.device)
    # NOTE(review): temperatures are read from `model` here, while the load
    # cell prints them from `model.args`; confirm the model exposes
    # `drift_temperatures` directly, otherwise this default is always used.
    drift_temps = getattr(model, 'drift_temperatures', (0.02, 0.05, 0.2))
    V = compute_V(
        x=x_tokens,
        y_pos=y_tokens,
        y_neg=x_tokens,
        temperatures=drift_temps,
        x_sample_ids=sample_ids,
        y_pos_sample_ids=sample_ids,
        y_neg_sample_ids=sample_ids,
    )

    # Per-token vector length, back on the grid; only batch item 0 is plotted.
    v_len = V.norm(dim=-1)
    vnorm = v_len.reshape(B, Hf, Wf)[0].detach().cpu()

    plt.figure(figsize=(4, 4))
    plt.imshow(vnorm, cmap='magma')
    plt.title('||V|| heatmap (token space)')
    plt.axis('off')
    plt.show()
else:
    print('gt_img is None: skip V heatmap')

In [None]:
# Optional rollout visualization (single model) + GT comparison
pred_roll = rollout_pred_feats(model, frames, ROLLOUT_STEPS, text_tokens=text_tokens, text_mask=text_mask)
print('pred_roll shape:', tuple(pred_roll.shape))

if decode_feats is None:
    print('Set DECODER_CKPT to visualize rollout RGBs.')
else:
    # Decode each rollout step of batch item 0 to RGB.
    pred_roll_rgb = [decode_feats(pred_roll[:, t])[0].clamp(0, 1) for t in range(ROLLOUT_STEPS)]

    if gt_img is None:
        print('gt_img is None: show rollout only (no GT comparison).')
        plt.figure(figsize=(3 * ROLLOUT_STEPS, 3))
        for t in range(ROLLOUT_STEPS):
            ax = plt.subplot(1, ROLLOUT_STEPS, t + 1)
            ax.imshow(pred_roll_rgb[t].permute(1, 2, 0))
            ax.set_title(f'pred t+{t+1}')
            ax.axis('off')
        plt.tight_layout()
        plt.show()
    else:
        # Build GT rollout list: [gt_next] + future_gt[:]
        gt_roll_rgb = [denormalize_images(gt_img.detach().cpu(), model.args.feature_extractor).cpu()[0].clamp(0, 1)]
        if future_gt is not None and future_gt.shape[1] > 0:
            future_rgb = denormalize_images(future_gt[0].detach().cpu(), model.args.feature_extractor).cpu().clamp(0, 1)
            gt_roll_rgb.extend(future_rgb[i] for i in range(future_rgb.shape[0]))

        # Metrics can only cover steps that have a GT frame.
        n_compare = min(ROLLOUT_STEPS, len(gt_roll_rgb))
        if n_compare < ROLLOUT_STEPS:
            print(f'Only {n_compare} GT steps available for comparison.')

        # Top row: GT (or an 'N/A' placeholder), bottom row: predictions.
        plt.figure(figsize=(3 * ROLLOUT_STEPS, 6))
        for t in range(ROLLOUT_STEPS):
            ax1 = plt.subplot(2, ROLLOUT_STEPS, t + 1)
            if t < len(gt_roll_rgb):
                ax1.imshow(gt_roll_rgb[t].permute(1, 2, 0))
            else:
                ax1.text(0.5, 0.5, 'N/A', ha='center', va='center')
            ax1.set_title(f'GT t+{t+1}')
            ax1.axis('off')

            ax2 = plt.subplot(2, ROLLOUT_STEPS, ROLLOUT_STEPS + t + 1)
            ax2.imshow(pred_roll_rgb[t].permute(1, 2, 0))
            ax2.set_title(f'pred t+{t+1}')
            ax2.axis('off')
        plt.tight_layout()
        plt.show()

        # SSIM is optional. Resolve the import once, not once per step, and
        # treat only a missing package as "unavailable".
        try:
            from torchmetrics.functional import structural_similarity_index_measure as ssim_fn
        except ImportError:
            ssim_fn = None
            print('SSIM unavailable (torchmetrics missing)')

        # Quantitative metrics for available GT steps.
        print('Per-step rollout metrics:')
        for t in range(n_compare):
            pred_t = pred_roll_rgb[t]
            gt_t = gt_roll_rgb[t]
            mse = torch.mean((pred_t - gt_t) ** 2).item()
            l1 = torch.mean(torch.abs(pred_t - gt_t)).item()
            # Images are in [0, 1]; the floor on mse avoids log10(inf).
            psnr = 10.0 * torch.log10(1.0 / torch.tensor(max(mse, 1e-12))).item()

            ssim_val = None
            if ssim_fn is not None:
                try:
                    ssim_val = ssim_fn(pred_t.unsqueeze(0), gt_t.unsqueeze(0), data_range=1.0).item()
                except Exception as exc:
                    # Best-effort: keep the other metrics, but say why SSIM is missing.
                    print(f't+{t+1} | SSIM failed: {type(exc).__name__}: {exc}')

            msg = f't+{t+1} | MSE={mse:.6f} | L1={l1:.6f} | PSNR={psnr:.3f} dB'
            if ssim_val is not None:
                msg += f' | SSIM={ssim_val:.4f}'
            print(msg)
