# Drifting Effect Test (A/B Comparison)

Compare **drifting model** vs **non-drifting baseline** on the same OpenDV validation batches.

Outputs:
- one-step feature MSE (pred feature vs GT feature)
- one-step feature cosine similarity
- optional RGB PSNR/SSIM (if decoder is provided)
- qualitative side-by-side visualization


In [None]:
import os
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from src.data import OpenDV_VideoData
from src.dino_f import Dino_f
from save_predicted_dino_features import add_missing_args, denormalize_images

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('DEVICE =', DEVICE)

In [None]:
# ===== Config =====
# Required: two checkpoints trained with the same backbone/data settings.
# The REPLACE_ME path component must be edited to point at a real run directory.
CKPT_DRIFT = '/cpfs/pengyu/DINO-Foresight/logs/dino_foresight_lowres_opendv_drift_noallgather/REPLACE_ME/checkpoints/last.ckpt'
CKPT_BASELINE = '/cpfs/pengyu/DINO-Foresight/logs/dino_foresight_lowres_opendv_baseline/REPLACE_ME/checkpoints/last.ckpt'

# Data
OPENDV_ROOT = '/pfs/pengyu/OpenDV-YouTube'
OPENDV_LANG_ROOT = '/pfs/pengyu/OpenDV-YouTube-Language'
LANG_CACHE_VAL = '/pfs/pengyu/OpenDV-YouTube-Language/mini_val_cache.json'
OPENDV_LANG_FEAT_NAME = 'lang_clip_{start}_{end}.pt'

# Eval setup
BATCH_SIZE = 1
NUM_WORKERS = 2
NUM_EVAL_BATCHES = 50
ROLLOUT_STEPS = 6  # qualitative only

# Optional decoder for RGB metrics/visualization
DECODER_CKPT = None
DECODER_TYPE = 'from_dino'  # 'from_dino' or 'from_feats'
SAVE_VIS_DIR = None  # e.g. './tmp_decode_vis' to save decoded RGB images

# Optional: offline DINOv2 hub path (only applied when DINO_REPO is not already set)
os.environ.setdefault('DINO_REPO', '/cpfs/pengyu/.cache/torch/hub/facebookresearch_dinov2_main')

# Fail fast on a missing checkpoint. An explicit raise is used instead of
# `assert` because assertions are stripped when Python runs with -O.
for _ckpt_name, _ckpt_path in (('CKPT_DRIFT', CKPT_DRIFT), ('CKPT_BASELINE', CKPT_BASELINE)):
    if not Path(_ckpt_path).exists():
        raise FileNotFoundError(f'Missing {_ckpt_name}: {_ckpt_path}')


In [None]:
def load_model(ckpt_path):
    """Load a ``Dino_f`` checkpoint for inference on ``DEVICE``.

    The checkpoint is deserialized on the CPU with ``strict=False``, moved to
    ``DEVICE`` and put in eval mode. The feature extractor is then
    (re)initialized via ``_init_feature_extractor`` and whichever of the
    ``dino_v2`` / ``eva2clip`` / ``sam`` attributes is not ``None`` is moved
    to ``DEVICE`` as well.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        The loaded model, in eval mode.
    """
    net = Dino_f.load_from_checkpoint(ckpt_path, strict=False, map_location='cpu')
    net = net.to(DEVICE)
    net.eval()
    net._init_feature_extractor()
    # The extractor is created after the `.to(DEVICE)` above, so it is moved
    # explicitly here.
    for attr_name in ('dino_v2', 'eva2clip', 'sam'):
        extractor = getattr(net, attr_name)
        if extractor is not None:
            setattr(net, attr_name, extractor.to(DEVICE))
    return net

# Load both sides of the A/B comparison with the same procedure.
model_drift = load_model(CKPT_DRIFT)
model_base = load_model(CKPT_BASELINE)

print('drift use_drifting_loss:', getattr(model_drift.args, 'use_drifting_loss', False))
print('base  use_drifting_loss:', getattr(model_base.args, 'use_drifting_loss', False))
print('drift feature_extractor:', model_drift.args.feature_extractor)
print('base  feature_extractor:', model_base.args.feature_extractor)
print('drift img_size:', model_drift.args.img_size)
print('base  img_size:', model_base.args.img_size)

# The evaluation builds its data loader and GT features from the drift model
# only, so a baseline trained with a different backbone or resolution would be
# compared against targets from the wrong feature space. Surface that here.
if model_drift.args.feature_extractor != model_base.args.feature_extractor:
    print('WARNING: feature_extractor differs between drift and baseline; feature-space metrics are not comparable.')
if tuple(model_drift.args.img_size) != tuple(model_base.args.img_size):
    print('WARNING: img_size differs between drift and baseline; feature-space metrics are not comparable.')

In [None]:
# Build one shared validation loader (use drift model args as source of truth)
m = model_drift
use_lang_cond = bool(getattr(m.args, 'use_language_condition', False))
use_precomputed_text = bool(getattr(m.args, 'use_precomputed_text', False))

# Gather every loader option in a plain mapping first, then expose it as an
# attribute namespace, which is the form the data module receives below.
loader_settings = {
    # dataset locations
    'data_path': OPENDV_ROOT,
    'opendv_root': OPENDV_ROOT,
    'opendv_lang_root': OPENDV_LANG_ROOT,
    'opendv_use_lang_annos': bool(OPENDV_LANG_ROOT),
    'opendv_lang_cache_train': None,
    'opendv_lang_cache_val': LANG_CACHE_VAL,
    # language inputs: precomputed features XOR raw language, and only when
    # the model is language-conditioned at all
    'opendv_use_lang_features': use_lang_cond and use_precomputed_text,
    'opendv_return_language': use_lang_cond and not use_precomputed_text,
    'opendv_lang_feat_name': OPENDV_LANG_FEAT_NAME,
    'opendv_video_dir': None,
    'opendv_max_clips': None,
    # clip geometry, taken from the drift model when available
    'sequence_length': getattr(m.args, 'sequence_length', 5),
    'img_size': tuple(getattr(m.args, 'img_size', (224, 448))),
    # loader / evaluation switches
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_workers_val': NUM_WORKERS,
    'eval_mode': True,
    'eval_midterm': False,
    'eval_modality': None,
    'use_language_condition': use_lang_cond,
    'use_precomputed_text': use_precomputed_text,
}
args = SimpleNamespace(**loader_settings)
# NOTE(review): presumably copies over any option present on the model args
# but absent above - confirm in save_predicted_dino_features.add_missing_args.
args = add_missing_args(args, m.args)
# These three are set after the merge so they always win.
args.feature_extractor = m.args.feature_extractor
args.dinov2_variant = getattr(m.args, 'dinov2_variant', 'vitb14_reg')
args.return_rgb_path = True

data = OpenDV_VideoData(arguments=args, subset='val', batch_size=BATCH_SIZE)
loader = data.val_dataloader()
print('loader ready')

In [None]:
def parse_batch(batch):
    """Split a loader batch into its named components.

    A batch that is not a list/tuple is treated as the frames tensor alone.
    Otherwise element 0 is the frames and every remaining element is
    classified by type and rank:

    * 4-D tensor whose dim 1 has size 3 -> ground-truth image
    * 3-D tensor -> text tokens
    * 2-D tensor -> text mask
    * non-empty list/tuple whose first element is ``str``/``Path`` -> RGB paths

    Anything else is ignored; when several elements match the same slot the
    last one wins.

    Returns:
        ``(frames, gt_img, text_tokens, text_mask, rgb_paths)`` with ``None``
        for every component that was not found.
    """
    if not isinstance(batch, (list, tuple)):
        return batch, None, None, None, None

    found = {'gt_img': None, 'text_tokens': None, 'text_mask': None, 'rgb_paths': None}

    for extra in batch[1:]:
        if not torch.is_tensor(extra):
            nonempty_seq = isinstance(extra, (list, tuple)) and len(extra) > 0
            if nonempty_seq and isinstance(extra[0], (str, Path)):
                found['rgb_paths'] = list(map(str, extra))
            continue

        rank = extra.ndim
        if rank == 4 and extra.shape[1] == 3:
            found['gt_img'] = extra
        elif rank == 3:
            found['text_tokens'] = extra
        elif rank == 2:
            found['text_mask'] = extra

    return (
        batch[0],
        found['gt_img'],
        found['text_tokens'],
        found['text_mask'],
        found['rgb_paths'],
    )


def prepare_text_for_model(model, text_tokens, text_mask, device):
    """Move text conditioning to ``device`` and project it if required.

    Returns ``(None, None)`` when the model has no truthy
    ``use_language_condition`` attribute or when no tokens are given.

    If the model exposes a ``text_proj`` layer, the tokens' last dimension
    decides what happens:

    * equal to ``text_proj.in_features``: cast to the projection weight's
      dtype and project;
    * equal to ``text_proj.out_features``: passed through unchanged;
    * anything else: ``ValueError``.

    Args:
        model: Model providing ``use_language_condition`` and, optionally,
            ``text_proj``.
        text_tokens: Token tensor or ``None``.
        text_mask: Mask tensor or ``None``.
        device: Target device for both tensors.

    Returns:
        ``(tokens, mask)`` on ``device``; ``mask`` is ``None`` if none was given.

    Raises:
        ValueError: If the token width matches neither side of ``text_proj``.
    """
    language_enabled = getattr(model, 'use_language_condition', False)
    if not language_enabled or text_tokens is None:
        return None, None

    tokens = text_tokens.to(device)
    mask = None if text_mask is None else text_mask.to(device)

    if not hasattr(model, 'text_proj'):
        return tokens, mask

    projection = model.text_proj
    in_dim, out_dim = projection.in_features, projection.out_features
    width = tokens.shape[-1]
    if width == in_dim:
        # Raw embeddings: match the layer's dtype before projecting.
        tokens = projection(tokens.to(dtype=projection.weight.dtype))
    elif width != out_dim:
        raise ValueError(f'Unexpected text token dim {tokens.shape[-1]} (expected {in_dim} or {out_dim})')
    return tokens, mask


@torch.no_grad()
def one_step_pred_feats(model, frames, text_tokens=None, text_mask=None):
    """Predict the features of the last frame in a single forward pass.

    The frames are preprocessed by the model, masked with
    ``get_mask_tokens(mode='full_mask', mask_frames=1)``, run through
    ``model.forward`` with the optional text conditioning, and the prediction
    is mapped back through ``model.postprocess``.

    Args:
        model: Prediction model.
        frames: Input frames, passed as-is to ``model.preprocess``.
        text_tokens: Optional text tokens (see ``prepare_text_for_model``).
        text_mask: Optional text mask.

    Returns:
        The last entry along dim 1 of the postprocessed prediction.
    """
    feats = model.preprocess(frames)
    cond_tokens, cond_mask = prepare_text_for_model(model, text_tokens, text_mask, feats.device)

    masked_feats, token_mask = model.get_mask_tokens(feats, mode='full_mask', mask_frames=1)
    token_mask = token_mask.to(feats.device)

    text_kwargs = dict(text_tokens=cond_tokens, text_mask=cond_mask)
    # forward() yields one extra trailing value when attention visualization
    # is enabled; only the prediction (second value) is used here.
    if model.args.vis_attn:
        _, pred, _ = model.forward(feats, masked_feats, token_mask, **text_kwargs)
    else:
        _, pred = model.forward(feats, masked_feats, token_mask, **text_kwargs)

    return model.postprocess(pred)[:, -1]


@torch.no_grad()
def gt_next_feats(model, gt_img, h, w):
    """Extract ground-truth features and lay them out on an ``h`` x ``w`` grid.

    Args:
        model: Model providing ``extract_features``.
        gt_img: Ground-truth image batch passed to ``extract_features``.
        h: Number of grid rows.
        w: Number of grid columns.

    Returns:
        The extracted features reshaped to ``(batch, h, w, -1)``, computed
        without gradient tracking.
    """
    flat = model.extract_features(gt_img)
    batch_size = flat.shape[0]
    return flat.reshape(batch_size, h, w, -1)


def psnr(pred, gt):
    """Return the PSNR in dB between two tensors, for a peak value of 1.0.

    The mean squared error is floored at 1e-12, so identical inputs give a
    finite value (about 120 dB) instead of infinity.
    """
    squared_error = (pred - gt).pow(2)
    mse = squared_error.mean().clamp_min(1e-12)
    return (10.0 * torch.log10(1.0 / mse)).item()


def try_ssim(pred, gt):
    """Best-effort SSIM between two single images, for a data range of 1.0.

    Both inputs get a leading batch dimension added before being handed to
    torchmetrics' ``structural_similarity_index_measure``.

    Args:
        pred: Predicted image tensor (unbatched).
        gt: Ground-truth image tensor (unbatched).

    Returns:
        The SSIM as a float, or ``None`` when torchmetrics is not installed
        or the computation fails. A computation failure emits a
        ``RuntimeWarning`` so that a genuine error (for example mismatched
        shapes) is not mistaken for a missing optional dependency.
    """
    try:
        from torchmetrics.functional import structural_similarity_index_measure as ssim
    except ImportError:
        # torchmetrics is optional; callers skip SSIM when it is unavailable.
        return None

    try:
        return ssim(pred.unsqueeze(0), gt.unsqueeze(0), data_range=1.0).item()
    except Exception as exc:  # stay best-effort, but do not fail silently
        import warnings

        warnings.warn(f'SSIM computation failed and will be skipped: {exc!r}', RuntimeWarning, stacklevel=2)
        return None

In [None]:
# Optional decoder: defines `decode_feats(feats_bhwc)` returning a detached CPU
# tensor, or leaves it as None when no decoder checkpoint is configured.
decode_feats = None
if DECODER_CKPT is not None:
    if DECODER_TYPE == 'from_feats':
        from train_rgb_decoder_from_feats import FeatureRgbDecoder
        decoder = FeatureRgbDecoder.load_from_checkpoint(DECODER_CKPT, strict=False).to(DEVICE)
        decoder.eval()

        # Inference only: skip autograd bookkeeping instead of building a
        # graph and detaching it afterwards.
        @torch.no_grad()
        def decode_feats(feats_bhwc):
            """Decode features with the feature-to-RGB decoder; returns a CPU tensor."""
            return decoder(feats_bhwc.to(DEVICE)).detach().cpu()
    elif DECODER_TYPE == 'from_dino':
        from train_rgb_decoder import DinoV2RGBDecoder
        decoder = DinoV2RGBDecoder.load_from_checkpoint(DECODER_CKPT, strict=False, lpips_weight=0).to(DEVICE)
        decoder.eval()
        dpt_head = decoder.decoder
        feat_dim = decoder.emb_dim

        # Inference only: skip autograd bookkeeping instead of building a
        # graph and detaching it afterwards.
        @torch.no_grad()
        def decode_feats(feats_bhwc):
            """Decode a (b, h, w, c) feature map to RGB in [0, 1] on the CPU.

            The channel axis is split into chunks of ``feat_dim``. Two chunks
            are each duplicated to form the four inputs passed to the DPT
            head; four chunks are used as-is.

            Raises:
                ValueError: If ``c`` is not a multiple of ``feat_dim`` or the
                    number of chunks is neither 2 nor 4.
            """
            feats_bhwc = feats_bhwc.to(DEVICE)
            b, h, w, c = feats_bhwc.shape
            if c % feat_dim != 0:
                raise ValueError(f'Feature dim mismatch: {c} not divisible by {feat_dim}')
            parts = torch.split(feats_bhwc, feat_dim, dim=-1)
            if len(parts) == 2:
                parts = [parts[0], parts[0], parts[1], parts[1]]
            elif len(parts) != 4:
                raise ValueError(f'Expected 2 or 4 feature chunks, got {len(parts)}')
            # Flatten the spatial grid into a token axis for the head.
            feat_list = [p.reshape(b, h * w, feat_dim) for p in parts]
            pred = dpt_head(feat_list, h, w)
            # Resize to the drift model's image size, then squash to [0, 1].
            pred = F.interpolate(pred, size=tuple(model_drift.args.img_size), mode='bicubic', align_corners=False)
            pred = torch.sigmoid(pred)
            return pred.detach().cpu()
    else:
        raise ValueError("DECODER_TYPE must be 'from_dino' or 'from_feats'")

print('decoder enabled:', decode_feats is not None)

In [None]:
# Quantitative A/B test on NUM_EVAL_BATCHES
# Per-batch metric accumulators, one list per model.
mse_drift, mse_base = [], []
cos_drift, cos_base = [], []
psnr_drift, psnr_base = [], []
ssim_drift, ssim_base = [], []

# First batch that is actually evaluated, kept on the CPU for the qualitative cells.
last_sample = None

for i, batch in enumerate(loader):
    # The budget counts loader batches, including ones skipped below.
    if i >= NUM_EVAL_BATCHES:
        break

    frames, gt_img, text_tokens, text_mask, rgb_paths = parse_batch(batch)
    frames = frames.to(DEVICE)
    if gt_img is None:
        # No ground-truth next frame in this batch: nothing to score against.
        continue
    gt_img = gt_img.to(DEVICE)

    pred_d = one_step_pred_feats(model_drift, frames, text_tokens=text_tokens, text_mask=text_mask)
    pred_b = one_step_pred_feats(model_base, frames, text_tokens=text_tokens, text_mask=text_mask)

    # Patch-grid size; GT features are extracted once, with the drift model,
    # and used as the target for both predictions.
    h = frames.shape[-2] // model_drift.patch_size
    w = frames.shape[-1] // model_drift.patch_size
    gt_f = gt_next_feats(model_drift, gt_img, h, w)

    mse_drift.append(F.mse_loss(pred_d, gt_f).item())
    mse_base.append(F.mse_loss(pred_b, gt_f).item())

    # Cosine similarity over each sample's fully flattened feature map, averaged over the batch.
    cos_drift.append(F.cosine_similarity(pred_d.reshape(pred_d.shape[0], -1), gt_f.reshape(gt_f.shape[0], -1), dim=-1).mean().item())
    cos_base.append(F.cosine_similarity(pred_b.reshape(pred_b.shape[0], -1), gt_f.reshape(gt_f.shape[0], -1), dim=-1).mean().item())

    if decode_feats is not None:
        # RGB metrics use only the first sample of the batch.
        gt_rgb = denormalize_images(gt_img.detach().cpu(), model_drift.args.feature_extractor).cpu()[0].clamp(0, 1)
        pd_rgb = decode_feats(pred_d)[0].clamp(0, 1)
        pb_rgb = decode_feats(pred_b)[0].clamp(0, 1)

        psnr_drift.append(psnr(pd_rgb, gt_rgb))
        psnr_base.append(psnr(pb_rgb, gt_rgb))

        sd = try_ssim(pd_rgb, gt_rgb)
        sb = try_ssim(pb_rgb, gt_rgb)
        # Record SSIM only when both sides are available so the lists stay paired.
        if sd is not None and sb is not None:
            ssim_drift.append(sd)
            ssim_base.append(sb)

    # Cache the first batch that was actually evaluated. Keying this on
    # `i == 0` would cache nothing whenever the first loader batch is skipped.
    if last_sample is None:
        last_sample = (frames.detach().cpu(), gt_img.detach().cpu(), pred_d.detach().cpu(), pred_b.detach().cpu())

print('evaluated batches:', len(mse_drift))

if len(mse_drift) == 0:
    raise RuntimeError('No valid batches evaluated. Check val loader output format.')

print('=== Feature-space metrics (lower MSE / higher cosine better) ===')
print(f'MSE   drift: {np.mean(mse_drift):.6f} | base: {np.mean(mse_base):.6f} | delta(base-drift): {np.mean(mse_base)-np.mean(mse_drift):.6f}')
print(f'Cos   drift: {np.mean(cos_drift):.6f} | base: {np.mean(cos_base):.6f} | delta(drift-base): {np.mean(cos_drift)-np.mean(cos_base):.6f}')

if len(psnr_drift) > 0:
    print('=== RGB metrics (higher better) ===')
    print(f'PSNR  drift: {np.mean(psnr_drift):.3f} | base: {np.mean(psnr_base):.3f} | delta(drift-base): {np.mean(psnr_drift)-np.mean(psnr_base):.3f}')
if len(ssim_drift) > 0:
    print(f'SSIM  drift: {np.mean(ssim_drift):.4f} | base: {np.mean(ssim_base):.4f} | delta(drift-base): {np.mean(ssim_drift)-np.mean(ssim_base):.4f}')

In [None]:
# Qualitative one-step visualization of the sample cached in `last_sample`.
# Shows the last context frame and the GT next frame; with a decoder it adds
# both decoded predictions plus their absolute-difference map, without one it
# falls back to per-location feature-norm maps.
if last_sample is None:
    print('No sample cached for visualization.')
else:
    # Cached tuple layout: (frames, gt image, drift prediction, baseline prediction).
    frames_cpu, gt_img_cpu, pred_d_cpu, pred_b_cpu = last_sample
    # First sample of the batch; [-1] picks the last entry of the denormalized
    # sequence (shown below as 'context last').
    ctx_last = denormalize_images(frames_cpu[0], model_drift.args.feature_extractor).cpu()[-1]
    gt_rgb = denormalize_images(gt_img_cpu, model_drift.args.feature_extractor).cpu()[0]

    # Panels for the final side-by-side figure; extended below when a decoder exists.
    imgs = [ctx_last, gt_rgb]
    titles = ['context last', 'gt next']

    if decode_feats is not None:
        # Decode both predictions to RGB and keep the first sample only.
        pd_rgb = decode_feats(pred_d_cpu.to(DEVICE))[0].clamp(0, 1)
        pb_rgb = decode_feats(pred_b_cpu.to(DEVICE))[0].clamp(0, 1)
        imgs += [pd_rgb, pb_rgb]
        titles += ['pred drift', 'pred baseline']

        # Absolute difference averaged over dim 0 (the channel axis of a CHW image).
        diff = (pd_rgb - pb_rgb).abs().mean(dim=0)
        plt.figure(figsize=(4, 4))
        plt.imshow(diff, cmap='magma')
        plt.title('abs diff map: drift vs baseline')
        plt.axis('off')
        plt.show()
    else:
        # Fallback without a decoder: norm over the last (feature) axis of
        # each prediction, shown as two maps.
        d_norm = pred_d_cpu[0].norm(dim=-1)
        b_norm = pred_b_cpu[0].norm(dim=-1)
        plt.figure(figsize=(8, 3))
        plt.subplot(1, 2, 1)
        plt.imshow(d_norm)
        plt.title('drift pred feat norm')
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(b_norm)
        plt.title('base pred feat norm')
        plt.axis('off')
        plt.tight_layout()
        plt.show()



    # Saving needs the decoded images (pd_rgb, pb_rgb, diff), which exist only
    # in the decoder branch above, hence the repeated decoder check.
    if decode_feats is not None and SAVE_VIS_DIR is not None:
        save_dir = Path(SAVE_VIS_DIR)
        save_dir.mkdir(parents=True, exist_ok=True)
        # CHW -> HWC and clip to [0, 1] before writing each image.
        plt.imsave(save_dir / 'one_step_context_last.png', ctx_last.permute(1, 2, 0).cpu().numpy().clip(0, 1))
        plt.imsave(save_dir / 'one_step_gt_next.png', gt_rgb.permute(1, 2, 0).cpu().numpy().clip(0, 1))
        plt.imsave(save_dir / 'one_step_pred_drift.png', pd_rgb.permute(1, 2, 0).cpu().numpy().clip(0, 1))
        plt.imsave(save_dir / 'one_step_pred_baseline.png', pb_rgb.permute(1, 2, 0).cpu().numpy().clip(0, 1))
        plt.imsave(save_dir / 'one_step_abs_diff_mean.png', diff.cpu().numpy(), cmap='magma')
        print(f'Saved one-step decoded images to: {save_dir}')

    # Side-by-side figure with one column per collected panel.
    plt.figure(figsize=(4 * len(imgs), 4))
    for i, (im, tt) in enumerate(zip(imgs, titles), start=1):
        plt.subplot(1, len(imgs), i)
        plt.imshow(im.permute(1, 2, 0))
        plt.title(tt)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# Optional: qualitative rollout side-by-side (no GT for later steps)
@torch.no_grad()
def rollout(model, frames, steps, text_tokens=None, text_mask=None):
    """Autoregressively predict ``steps`` future feature maps.

    Each step masks the window with ``full_mask`` / ``mask_frames=1``, runs
    the model, stores the postprocessed last-frame prediction on the CPU,
    then shifts the window by one frame and writes the raw (pre-postprocess)
    prediction into the last slot.

    Args:
        model: Prediction model.
        frames: Context frames, passed to ``model.preprocess``.
        steps: Number of rollout steps.
        text_tokens: Optional text tokens (see ``prepare_text_for_model``).
        text_mask: Optional text mask.

    Returns:
        The per-step predictions stacked along dim 1, on the CPU.
    """
    # The sliding window lives in the model's internal feature space (PCA space if enabled).
    window = model.preprocess(frames)
    cond_tokens, cond_mask = prepare_text_for_model(model, text_tokens, text_mask, window.device)
    text_kwargs = dict(text_tokens=cond_tokens, text_mask=cond_mask)

    step_outputs = []
    for _step in range(steps):
        masked_window, token_mask = model.get_mask_tokens(window, mode='full_mask', mask_frames=1)
        token_mask = token_mask.to(window.device)
        if model.args.vis_attn:
            _, pred_internal, _ = model.forward(window, masked_window, token_mask, **text_kwargs)
        else:
            _, pred_internal = model.forward(window, masked_window, token_mask, **text_kwargs)

        # Postprocessed copy of the newest frame, kept for visualization.
        newest = model.postprocess(pred_internal)[:, -1]
        step_outputs.append(newest.detach().cpu())

        # Slide the window in place; feed back the internal-space prediction
        # (not the postprocessed one) to avoid a PCA dimension mismatch.
        window[:, :-1] = window[:, 1:].clone()
        window[:, -1] = pred_internal[:, -1]

    return torch.stack(step_outputs, dim=1)

# Decode and plot the rollouts of both models: drift on the top row, baseline
# on the bottom row, one column per step.
if last_sample is None or decode_feats is None:
    print('Need cached sample + decoder to show rollout.')
else:
    # The cached sample carries no text conditioning, so the first loader
    # batch is fetched again and both its frames and its text are used,
    # keeping the language condition aligned with the frames.
    batch0 = next(iter(loader))
    frames_b, _, text_t, text_m, _ = parse_batch(batch0)
    frames_b = frames_b.to(DEVICE)

    rd = rollout(model_drift, frames_b, ROLLOUT_STEPS, text_tokens=text_t, text_mask=text_m)
    rb = rollout(model_base, frames_b, ROLLOUT_STEPS, text_tokens=text_t, text_mask=text_m)

    # Create the output directory once, up front, rather than on every step.
    rollout_dir = None
    if SAVE_VIS_DIR is not None:
        rollout_dir = Path(SAVE_VIS_DIR) / 'rollout'
        rollout_dir.mkdir(parents=True, exist_ok=True)

    plt.figure(figsize=(3 * ROLLOUT_STEPS, 6))
    for t in range(ROLLOUT_STEPS):
        # Decode step t of each rollout and keep the first sample only.
        img_d = decode_feats(rd[:, t].to(DEVICE))[0].clamp(0, 1)
        img_b = decode_feats(rb[:, t].to(DEVICE))[0].clamp(0, 1)

        if rollout_dir is not None:
            plt.imsave(rollout_dir / f'drift_t{t+1}.png', img_d.permute(1, 2, 0).cpu().numpy().clip(0, 1))
            plt.imsave(rollout_dir / f'base_t{t+1}.png', img_b.permute(1, 2, 0).cpu().numpy().clip(0, 1))

        ax1 = plt.subplot(2, ROLLOUT_STEPS, t + 1)
        ax1.imshow(img_d.permute(1, 2, 0))
        ax1.set_title(f'drift t+{t+1}')
        ax1.axis('off')

        ax2 = plt.subplot(2, ROLLOUT_STEPS, ROLLOUT_STEPS + t + 1)
        ax2.imshow(img_b.permute(1, 2, 0))
        ax2.set_title(f'base t+{t+1}')
        ax2.axis('off')

    plt.tight_layout()
    plt.show()
    if SAVE_VIS_DIR is not None:
        print(f"Saved rollout decoded images to: {Path(SAVE_VIS_DIR) / 'rollout'}")
