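<cell_type>markdown</cell_type># Testing VDPM + GPT-5 mini on VLM4D questions it previously got wrong

This notebook tests whether GPT-5 mini can correctly answer questions it previously got wrong once a VDPM point-cloud trajectory image is added to the input.

**Running on Colab:**
1. Runtime → Change runtime type → GPU (T4)
2. Run all cells in order

**Pipeline:**
1. Install the VDPM dependencies (the runtime restarts once automatically)
2. Download the 10 test videos
3. Generate point clouds with VDPM
4. Render trajectory images
5. Compare GPT-5 mini with and without the trajectory image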

<cell_type>markdown</cell_type>## 1. Install the VDPM dependencies (the runtime restarts on first run)
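
The setup cell below only compares the major version of Plotly. As a rough sketch of a stricter check, the helpers here compare full version tuples. The names `parse_version` and `meets_minimum` are hypothetical and not part of the notebook, and the parser only understands plain numeric versions such as `6.1.1` (anything after the first non-numeric component is ignored).

```python
import re


def parse_version(version: str) -> tuple:
    """Return the leading numeric components of a version string as ints."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component, e.g. "dev0"
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True when `installed` is at least `minimum`, compared component-wise."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "6.0" and "6.0.0" compare as equal.
    a = a + (0,) * (width - len(a))
    b = b + (0,) * (width - len(b))
    return a >= b
```

For production use, `packaging.version.Version` handles pre-release and post-release tags that this sketch deliberately ignores.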

In [None]:
import os

SETUP_FLAG = '/content/.vdpm_gpt_ready_v7'

def check_installation():
    """Check that all dependencies are installed with compatible versions"""
    try:
        import vggt
        import omegaconf
        import plotly
        import openai
        import torch
        import kaleido
        # Check that the plotly version is >= 6.0
        major_version = int(plotly.__version__.split('.')[0])
        if major_version < 6:
            print(f"Plotly version too old: {plotly.__version__}, need >= 6.0")
            return False
        # Check that Chrome is installed (it may live in one of several locations)
        chrome_paths = [
            '/usr/bin/google-chrome',
            '/usr/local/lib/python3.12/dist-packages/choreographer/cli/browser_exe/chrome-linux64/chrome',
            '/usr/local/lib/python3.11/dist-packages/choreographer/cli/browser_exe/chrome-linux64/chrome',
            '/usr/local/lib/python3.10/dist-packages/choreographer/cli/browser_exe/chrome-linux64/chrome',
        ]
        chrome_found = any(os.path.exists(p) for p in chrome_paths)
        if not chrome_found:
            print("Chrome is not installed")
            return False
        return True
    except ImportError:
        return False

if os.path.exists(SETUP_FLAG) and check_installation():
    print("✓ Setup already complete; continue with the cells below")
else:
    # Remove stale setup flags left by earlier versions
    for f in ['/content/.vdpm_gpt_ready_v3', '/content/.vdpm_gpt_ready_v4', '/content/.vdpm_gpt_ready_v5', '/content/.vdpm_gpt_ready_v6', '/content/.vdpm_gpt_ready_v7']:
        if os.path.exists(f):
            os.remove(f)
    
    print("Installing VDPM dependencies...")
    
    # Clone VDPM
    print("\n[1/6] Clone VDPM...")
    !rm -rf /content/vdpm
    !git clone --depth 1 https://github.com/eldar/vdpm.git /content/vdpm
    
    # Fix NumPy
    print("\n[2/6] Fix NumPy...")
    !pip uninstall -y numpy
    !pip install numpy==1.26.4
    
    # Install VGGT
    print("\n[3/6] Install VGGT...")
    !pip install git+https://github.com/facebookresearch/vggt.git@44b3afb
    
    # Install other deps
    print("\n[4/6] Install other deps...")
    !pip install roma omegaconf einops jaxtyping
    
    print("\n[5/6] Install OpenAI & Plotly...")
    !pip install openai aiolimiter tqdm python-dotenv opencv-python pydantic
    # Force-upgrade plotly and kaleido to compatible versions
    !pip install --upgrade "plotly>=6.1.1" "kaleido>=1.2.0"
    
    # Install Chrome (required by Kaleido)
    print("\n[6/6] Install Chrome for Kaleido...")
    # Install the system libraries Chrome depends on first
    !apt-get update && apt-get install -y libnss3 libatk-bridge2.0-0 libcups2 libxcomposite1 libxdamage1 libxfixes3 libxrandr2 libgbm1 libxkbcommon0 libpango-1.0-0 libcairo2 libasound2
    !plotly_get_chrome
    
    # Verify the installation
    print("\nVerifying installation...")
    try:
        import vggt
        import omegaconf
        import plotly
        import kaleido
        print(f"✓ plotly={plotly.__version__}")
        print("✓ kaleido OK")
        print("✓ Chrome installed")
        
        # Write the flag only after verification succeeds
        !touch {SETUP_FLAG}
        
        print("\n" + "="*50)
        print("✓ Installation complete! Restarting the runtime...")
        print("After the restart, re-run all cells from the top")
        print("="*50)
        os._exit(0)
    except ImportError as e:
        print(f"✗ Verification failed: {e}")
        print("Please re-run this cell")

<cell_type>markdown</cell_type>## 2. Imports and configuration

In [None]:
import os, sys
os.chdir('/content/vdpm')
sys.path.insert(0, '/content/vdpm')

import json
import asyncio
import base64
import hashlib
import cv2
import requests
import random
import numpy as np
import torch
import tempfile
from pathlib import Path
from string import Template
from tqdm import tqdm
from openai import AsyncOpenAI

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# OpenAI API Key
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
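if not OPENAI_API_KEY:
    # On Colab, fall back to the secrets stored with the notebook
    try:
        from google.colab import userdata
        OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')
    except Exception:
        # Not running on Colab, or the secret is missing or not accessible
        pass

if not OPENAI_API_KEY:
    OPENAI_API_KEY = input("Enter OPENAI_API_KEY: ")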
    
print(f"API Key: {OPENAI_API_KEY[:10]}...")

<cell_type>markdown</cell_type>## 3. Load the 10 previously failed questions
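
Each record in the cell below stores the ground-truth `answer` as a choice value rather than a letter. A minimal sketch, assuming that record shape, of a sanity check that maps the answer back to its option letter and fails when the answer is missing from the choices. The helper `answer_letter` is hypothetical and not part of VLM4D.

```python
def answer_letter(question: dict) -> str:
    """Return the option letter whose value equals the ground-truth answer."""
    matches = [
        letter
        for letter, value in question["choices"].items()
        if value == question["answer"]
    ]
    if len(matches) != 1:
        raise ValueError(
            f"{question.get('id', '?')}: expected exactly one matching choice, "
            f"found {len(matches)}"
        )
    return matches[0]


# Trimmed copy of one record from the list below.
example = {
    "id": "validation_11",
    "choices": {"A": 4, "B": 8, "C": 1, "D": 0},
    "answer": 0,
}
print(answer_letter(example))  # D
```

Running this over `SELECTED_QUESTIONS` before any API call would catch a mistyped answer early, when it is still cheap to fix.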

In [None]:
# 10 motion questions (direction, rotation, repetition counts) that GPT-5 mini previously answered incorrectly
SELECTED_QUESTIONS = [
    {
        "id": "validation_5",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/baseball.mp4",
        "question": "What direction did the ball come from?",
        "choices": {"A": "right", "B": "left", "C": "below", "D": "above"},
        "answer": "left",
        "model_wrong_answer": "A (right)"
    },
    {
        "id": "validation_11",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/basketball-game.mp4",
        "question": "How many times did the person with the ball dribble the ball with his left hand?",
        "choices": {"A": 4, "B": 8, "C": 1, "D": 0},
        "answer": 0,
        "model_wrong_answer": "A (4)"
    },
    {
        "id": "validation_18",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/bear.mp4",
        "question": "Is the bear spinning clockwise or counter-clockwise?",
        "choices": {"A": "clockwise", "B": "counter-clockwise", "C": "there are no bears in the video", "D": "not spinning"},
        "answer": "not spinning",
        "model_wrong_answer": "A (clockwise)"
    },
    {
        "id": "validation_25",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/bike-packing.mp4",
        "question": "Which direction is the bike moving towards?",
        "choices": {"A": "right", "B": "staying in place", "C": "away from the camera", "D": "left"},
        "answer": "staying in place",
        "model_wrong_answer": "A (right)"
    },
    {
        "id": "validation_34",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/blackswan.mp4",
        "question": "Is the swan spinning clockwise or counter-clockwise?",
        "choices": {"A": "counter-clockwise", "B": "both ways", "C": "clockwise", "D": "not spinning"},
        "answer": "not spinning",
        "model_wrong_answer": "C (clockwise)"
    },
    {
        "id": "validation_37",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/bmx-bumps.mp4",
        "question": "From the camera perspective, what direction is the boy moving towards?",
        "choices": {"A": "left", "B": "not moving", "C": "right", "D": "towards the camera"},
        "answer": "left",
        "model_wrong_answer": "C (right)"
    },
    {
        "id": "validation_46",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/bmx-rider.mp4",
        "question": "Is the bike rider spinning clockwise or counter-clockwise in the air?",
        "choices": {"A": "clockwise", "B": "not spinning", "C": "counter-clockwise", "D": "there are no people in the video"},
        "answer": "counter-clockwise",
        "model_wrong_answer": "A (clockwise)"
    },
    {
        "id": "validation_54",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/boat.mp4",
        "question": "Is the boat rotating clockwise or counter-clockwise?",
        "choices": {"A": "clockwise", "B": "not rotating", "C": "counter-clockwise", "D": "there are no boats in the video"},
        "answer": "not rotating",
        "model_wrong_answer": "C (counter-clockwise)"
    },
    {
        "id": "validation_66",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/breakdance.mp4",
        "question": "How many full revolutions does the dancer spin clockwise?",
        "choices": {"A": 5, "B": 7, "C": 2, "D": 0},
        "answer": 2,
        "model_wrong_answer": "B (7)"
    },
    {
        "id": "validation_72",
        "video": "https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/videos_real/davis/breakdance-flare.mp4",
        "question": "How many full revolutions does the dancer spin counter-clockwise?",
        "choices": {"A": 1, "B": 2, "C": 4, "D": 0},
        "answer": 0,
        "model_wrong_answer": "B (2)"
    }
]

print(f"Loaded {len(SELECTED_QUESTIONS)} previously failed questions")
for i, q in enumerate(SELECTED_QUESTIONS, 1):
    video_name = q['video'].split('/')[-1]
    print(f"{i}. [{video_name}] {q['question'][:50]}...")
    print(f"   Correct answer: {q['answer']}, previous wrong answer: {q['model_wrong_answer']}")

<cell_type>markdown</cell_type>## 4. Download the videos
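
The download cell below treats any existing `.mp4` as a finished download, so an interrupted write would be reused on the next run. One possible safeguard, sketched here with only the standard library, is to write to a temporary file and rename it into place. The helper `write_atomically` is a hypothetical name and is not used by the notebook.

```python
import os
import tempfile
from pathlib import Path


def write_atomically(path: Path, data: bytes) -> None:
    """Write `data` to a temp file next to `path`, then rename it into place."""
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(data)
        # os.replace is atomic when source and target are on the same filesystem.
        os.replace(tmp_name, path)
    except BaseException:
        # Clean up the partial file so it is never mistaken for a cached video.
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise
```

With this in place, `video_path.exists()` would only be true for files that were written completely.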

In [None]:
VIDEO_DIR = Path("/content/vlm4d_videos")
VIDEO_DIR.mkdir(parents=True, exist_ok=True)

for q in SELECTED_QUESTIONS:
    video_name = q['video'].split('/')[-1].replace('.mp4', '')
    video_path = VIDEO_DIR / f"{video_name}.mp4"
    
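    if video_path.exists():
        print(f"✓ Already exists: {video_name}.mp4")
    else:
        print(f"Downloading: {video_name}.mp4 ...", end=" ")
        resp = requests.get(q['video'], timeout=120)
        resp.raise_for_status()  # fail loudly instead of saving an HTTP error page as .mp4
        video_path.write_bytes(resp.content)
        print("done")
    
    q['local_video_path'] = str(video_path)

print(f"\n✓ All videos downloaded to {VIDEO_DIR}")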

<cell_type>markdown</cell_type>## 5. Generate point clouds with VDPM
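
`extract_frames` in the cells below keeps every `interval`-th frame, where `interval = max(int(fps / sample_hz), 1)`. The sketch here reproduces only that index arithmetic, so the number of frames VDPM receives can be estimated without decoding a video. The function name `sampled_frame_indices` is hypothetical.

```python
def sampled_frame_indices(num_frames: int, fps: float, sample_hz: float = 1.0) -> list:
    """Indices of the frames kept when sampling at roughly `sample_hz` frames per second."""
    interval = max(int(fps / sample_hz), 1)
    return list(range(0, num_frames, interval))


print(sampled_frame_indices(90, 30.0))                 # [0, 30, 60]
print(sampled_frame_indices(50, 24.0, sample_hz=2.0))  # [0, 12, 24, 36, 48]
```

At the default 1 Hz, a 3-second clip at 30 fps yields only 3 frames, so short clips give the trajectory renderer very few time steps to work with.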

In [None]:
from omegaconf import OmegaConf
from dpm.model import VDPM
from vggt.utils.load_fn import load_and_preprocess_images

# Load the VDPM model
print("Loading VDPM model...")
cfg = OmegaConf.create({
    'model': {'name': 'dpm-video', 'pretrained': None, 'decoder_depth': 4}
})
model = VDPM(cfg).to(device)

url = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"
weights = torch.hub.load_state_dict_from_url(url, file_name="vdpm_model.pt")
model.load_state_dict(weights, strict=True)
model.eval()
print("✓ Model loaded")

In [None]:
def extract_frames(video_path, output_dir, sample_hz=1.0):
    """Extract video frames at roughly sample_hz frames per second and save them as PNGs"""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    interval = max(int(fps / sample_hz), 1)
    
    paths = []
    count = 0
    frame_idx = 0
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if count % interval == 0:
            path = output_dir / f"{frame_idx:04d}.png"
            cv2.imwrite(str(path), frame)
            paths.append(str(path))
            frame_idx += 1
        count += 1
    cap.release()
    return sorted(paths)


def run_vdpm(video_path, output_dir, ref_frame=0):
    """Run VDPM on a video and save the resulting point cloud sequence"""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    with tempfile.TemporaryDirectory() as tmp:
        frame_paths = extract_frames(video_path, tmp)
        print(f"  Extracted {len(frame_paths)} frames")
        images = load_and_preprocess_images(frame_paths).to(device)
    
    with torch.no_grad():
        result = model.inference(None, images=images.unsqueeze(0))
    
    pointmaps = result['pointmaps']
    pts_list = [pm['pts3d'].detach().cpu().numpy() for pm in pointmaps]
    conf_list = [pm['conf'].detach().cpu().numpy() for pm in pointmaps]
    
    world_points = np.concatenate(pts_list, axis=0)
    world_conf = np.concatenate(conf_list, axis=0)
    
    num_frames = world_points.shape[0]
    all_pts, all_conf = [], []
    
    for t in range(num_frames):
        pts = world_points[t, ref_frame, :, :, :].reshape(-1, 3)
        conf = world_conf[t, ref_frame, :, :].reshape(-1)
        all_pts.append(pts)
        all_conf.append(conf)
    
    np.savez(output_dir / "sequence.npz", points=np.stack(all_pts), conf=np.stack(all_conf))
    return output_dir


# Generate point clouds for all videos
POINTCLOUD_DIR = Path("/content/vlm4d_pointclouds")
POINTCLOUD_DIR.mkdir(parents=True, exist_ok=True)

for q in SELECTED_QUESTIONS:
    video_name = q['video'].split('/')[-1].replace('.mp4', '')
    output_dir = POINTCLOUD_DIR / video_name
    npz_path = output_dir / "sequence.npz"
    
    if npz_path.exists():
        print(f"✓ {video_name}: point cloud already exists")
        q['npz_path'] = str(npz_path)
        continue
    
    print(f"Processing: {video_name}")
    try:
        run_vdpm(q['local_video_path'], output_dir)
        q['npz_path'] = str(npz_path)
        print("  ✓ Done")
    except Exception as e:
        print(f"  ✗ Failed: {e}")
    
    torch.cuda.empty_cache()

print("\n✓ Point cloud generation finished")

<cell_type>markdown</cell_type>## 6. Render trajectory images

Color legend: **light cyan → dark red** encodes time from early to late, which shows the direction of motion.
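
The gradient is a linear interpolation in RGB between light cyan and dark red. The sketch below mirrors the per-frame color arithmetic used by the `"depth"` color mode in the rendering cell. The function name `time_color` is hypothetical and exists only for illustration.

```python
def time_color(t: int, num_frames: int) -> tuple:
    """RGB color for time step `t` out of `num_frames`, light cyan to dark red."""
    ratio = t / max(num_frames - 1, 1)
    r = int(150 + (180 - 150) * ratio)
    g = int(230 - 230 * ratio)
    b = int(255 - 225 * ratio)
    return r, g, b


print(time_color(0, 5))  # (150, 230, 255), the earliest position
print(time_color(4, 5))  # (180, 0, 30), the latest position
```

Because red changes little while green and blue fall toward zero, the trail darkens steadily from its start to its end.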

In [None]:
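"""
VDPM trajectory rendering code (inlined from render_trajectory.py)

Based on the Plotly rendering logic in VDPM's gradio_demo.py;
renders point-cloud images with trajectories directly from .npz files.
"""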

import matplotlib
import matplotlib.colors
import plotly.graph_objects as go
from typing import Tuple, Optional

# Parameters (from gradio_demo.py)
MAX_POINTS_PER_FRAME = 50_000
TRAIL_LENGTH = 16
MAX_TRACKS = 200
STATIC_THRESHOLD = 0.025


def load_vdpm_data(npz_path: str, video_path: str) -> dict:
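    """Load the VDPM point cloud and per-point colors taken from the video frames"""
    data = np.load(npz_path)
    points = data['points']  # (T, N, 3)
    conf = data['conf']      # (T, N)

    T, N, _ = points.shape

    # Infer the resolution heuristically: take the first height in [200, 600)
    # that divides N and gives a plausible aspect ratio
    H, W = None, None
    for h in range(200, 600):
        if N % h == 0:
            w = N // h
            if 0.9 < w / h < 2.1:
                H, W = h, w
                break

    if H is None:
        raise ValueError(f"Cannot infer the resolution; number of points: {N}")

    print(f"Point cloud resolution: {W}x{H}")

    # Read the video frames to get per-point colors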
    video = cv2.VideoCapture(video_path)
    images = []
    while True:
        ret, frame = video.read()
        if not ret:
            break
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) / 255.0
        frame_resized = cv2.resize(frame_rgb, (W, H))
        images.append(frame_resized)
    video.release()

    # Uniformly subsample to T frames
    indices = np.linspace(0, len(images) - 1, T, dtype=int)
    images = np.array([images[i] for i in indices])

    world_points = points.reshape(T, H, W, 3)

    return {
        'world_points': world_points,
        'conf': conf.reshape(T, H, W) if conf is not None else None,
        'images': images,
    }


def compute_scene_bounds(world_points: np.ndarray):
    """Compute the scene bounds"""
    all_pts = world_points.reshape(-1, 3)
    raw_min = all_pts.min(axis=0)
    raw_max = all_pts.max(axis=0)

    center = 0.5 * (raw_min + raw_max)
    half_extent = 0.5 * (raw_max - raw_min) * 1.05

    if np.all(half_extent < 1e-6):
        half_extent[:] = 1.0
    else:
        half_extent[half_extent < 1e-6] = half_extent.max()

    global_min = center - half_extent
    global_max = center + half_extent

    max_half = half_extent.max()
    aspectratio = {
        "x": float(half_extent[0] / max_half),
        "y": float(half_extent[1] / max_half),
        "z": float(half_extent[2] / max_half),
    }
    return global_min, global_max, aspectratio


def prepare_tracks(
    world_points: np.ndarray,
    images: np.ndarray,
    conf: Optional[np.ndarray],
    conf_thres: float = 1.5,
    color_mode: str = "rainbow",
) -> Tuple[Optional[np.ndarray], Optional[list], Optional[np.ndarray], str, int]:
    """Select dynamic, confident tracks and build their colorscale.

    Always returns (tracks_xyz, colorscale, track_ids, color_mode, num_frames);
    the first three are None when no track qualifies.
    """
    S, H, W, _ = world_points.shape
    N = H * W
    if S < 2 or N == 0:
        return None, None, None, color_mode, S

    tracks_xyz = world_points.reshape(S, N, 3)

    disp = np.linalg.norm(tracks_xyz - tracks_xyz[0:1], axis=-1)
    dynamic_mask = disp.max(axis=0) > STATIC_THRESHOLD

    if conf is not None:
        conf_flat = conf.reshape(S, N)
        conf_score = conf_flat.mean(axis=0)
        dynamic_mask &= (conf_score >= conf_thres)

    idx_tracks = np.nonzero(dynamic_mask)[0]
    if idx_tracks.size == 0:
        return None, None, None, color_mode, S

    if idx_tracks.size > MAX_TRACKS:
        step = int(np.ceil(idx_tracks.size / MAX_TRACKS))
        idx_tracks = idx_tracks[::step][:MAX_TRACKS]

    tracks_xyz = tracks_xyz[:, idx_tracks, :]

    order = np.argsort(tracks_xyz[0, :, 1])
    tracks_xyz = tracks_xyz[:, order, :]

    num_tracks = tracks_xyz.shape[1]
    num_frames = tracks_xyz.shape[0]

    if color_mode == "depth":
        # Despite its name, "depth" mode colors each trail by time: light cyan (early) → dark red (late)
        colorscale = []
        for t in range(num_frames):
            ratio = t / max(num_frames - 1, 1)
            r = int(150 + (180 - 150) * ratio)
            g = int(230 - 230 * ratio)
            b = int(255 - 225 * ratio)
            pos = ratio
            colorscale.append([pos, f"rgb({r},{g},{b})"])
        track_ids = np.arange(num_frames, dtype=float)
    else:
        cmap = matplotlib.colormaps["hsv"]  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
        norm = matplotlib.colors.Normalize(vmin=0, vmax=max(num_tracks - 1, 1))

        colorscale = []
        for t in range(num_tracks):
            r, g, b, _ = cmap(norm(t))
            r, g, b = int(r * 255), int(g * 255), int(b * 255)
            pos = t / max(num_tracks - 1, 1)
            colorscale.append([pos, f"rgb({r},{g},{b})"])
        track_ids = np.arange(num_tracks, dtype=float)

    return tracks_xyz, colorscale, track_ids, color_mode, num_frames


def track_segments_for_frame(
    tracks_xyz: Optional[np.ndarray],
    track_ids: Optional[np.ndarray],
    f: int,
    trail_length: int = TRAIL_LENGTH,
    color_mode: str = "rainbow",
    num_frames: int = 1,
):
    """Build the trajectory line segments that end at frame f"""
    if tracks_xyz is None or track_ids is None or f <= 0:
        return np.array([]), np.array([]), np.array([]), np.array([])

    start_t = max(0, f - trail_length)
    num_tracks = tracks_xyz.shape[1]

    xs, ys, zs, cs = [], [], [], []
    for j in range(num_tracks):
        seg = tracks_xyz[start_t: f + 1, j, :]
        if seg.shape[0] < 2:
            continue

        xs.extend([seg[:, 0], np.array([np.nan])])
        ys.extend([seg[:, 1], np.array([np.nan])])
        zs.extend([seg[:, 2], np.array([np.nan])])

        if color_mode == "depth":
            time_indices = np.arange(start_t, f + 1, dtype=float)
            cs.append(np.concatenate([time_indices, np.array([np.nan])]))
        else:
            cs.append(np.full(seg.shape[0] + 1, track_ids[j], dtype=float))

    x = np.concatenate(xs) if xs else np.array([])
    y = np.concatenate(ys) if ys else np.array([])
    z = np.concatenate(zs) if zs else np.array([])
    c = np.concatenate(cs) if cs else np.array([])

    return x, y, z, c


def sample_frame_points(
    world_points: np.ndarray,
    images: np.ndarray,
    conf: Optional[np.ndarray],
    frame_idx: int,
    conf_thres: float = 1.5,
    max_points: int = MAX_POINTS_PER_FRAME,
):
    """Sample the points and colors of one frame"""
    S, H, W, _ = world_points.shape
    pts = world_points[frame_idx].reshape(-1, 3)
    cols = (images[frame_idx].reshape(-1, 3) * 255).astype(np.uint8)

    mask = np.ones(pts.shape[0], dtype=bool)
    if conf is not None:
        conf_flat = conf[frame_idx].reshape(-1)
        mask &= (conf_flat >= conf_thres)

    pts = pts[mask]
    cols = cols[mask]

    n = pts.shape[0]
    if n > max_points:
        step = int(np.ceil(n / max_points))
        pts = pts[::step]
        cols = cols[::step]

    colors_str = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in cols]
    return pts, colors_str


def render_frame_with_tracks(
    data: dict,
    frame_idx: int,
    output_path: str,
    conf_thres: float = 1.5,
    width: int = 1200,
    height: int = 900,
    show_tracks: bool = True,
    camera_eye: tuple = (0.0, 0.0, -2.0),
    camera_center: tuple = (0.0, 0.0, 0.5),
    color_mode: str = "rainbow",
) -> str:
    """Render one frame's point cloud and trajectories and save the result as an image"""
    world_points = data['world_points']
    images = data['images']
    conf = data.get('conf')

    S = world_points.shape[0]
    frame_idx = min(frame_idx, S - 1)

    global_min, global_max, aspectratio = compute_scene_bounds(world_points)

    if show_tracks:
        result = prepare_tracks(
            world_points, images, conf, conf_thres, color_mode
        )
        tracks_xyz, colorscale, track_ids, actual_color_mode, num_frames = result
        if actual_color_mode == "depth":
            track_cmax = max(num_frames - 1, 1)
        else:
            track_cmax = max(len(track_ids) - 1, 1) if track_ids is not None else 1
    else:
        tracks_xyz, colorscale, track_ids = None, None, None
        track_cmax = 1
        actual_color_mode = color_mode
        num_frames = S

    pts, cols = sample_frame_points(
        world_points, images, conf, frame_idx, conf_thres
    )

    x, y, z, c = track_segments_for_frame(
        tracks_xyz, track_ids, frame_idx,
        color_mode=actual_color_mode, num_frames=num_frames
    )

    traces = [
        go.Scatter3d(
            x=pts[:, 0],
            y=pts[:, 1],
            z=pts[:, 2],
            mode="markers",
            marker=dict(size=2, color=cols),
            showlegend=False,
            name="points",
        ),
    ]

    if show_tracks and len(x) > 0:
        traces.append(
            go.Scatter3d(
                x=x,
                y=y,
                z=z,
                mode="lines",
                line=dict(
                    width=3,
                    color=c if c is not None and c.size else None,
                    colorscale=colorscale if colorscale else None,
                    cmin=0,
                    cmax=track_cmax,
                ),
                hoverinfo="skip",
                showlegend=False,
                name="tracks",
            )
        )

    fig = go.Figure(data=traces)

    scene_cfg = dict(
        xaxis=dict(visible=False, showbackground=False, range=[float(global_min[0]), float(global_max[0])]),
        yaxis=dict(visible=False, showbackground=False, range=[float(global_min[1]), float(global_max[1])]),
        zaxis=dict(visible=False, showbackground=False, range=[float(global_min[2]), float(global_max[2])]),
        aspectmode="manual",
        aspectratio=aspectratio,
        # NOTE: the camera is fixed here; the camera_eye / camera_center arguments are not applied
        camera=dict(
            eye=dict(x=0.0, y=0.0, z=-1.0),
            center=dict(x=0.0, y=0.0, z=0.0),
            up=dict(x=0.0, y=-1.0, z=0.0),
        ),
        bgcolor='white',
    )

    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        scene=scene_cfg,
        showlegend=False,
        width=width,
        height=height,
        paper_bgcolor='white',
    )

    fig.write_image(output_path, scale=2)
    print(f"Saved: {output_path}")
    return output_path


def render_trajectory_image(
    npz_path: str,
    video_path: str,
    output_path: str,
    frame_idx: int = -1,
    conf_thres: float = 1.5,
    width: int = 1200,
    height: int = 900,
    camera_eye: tuple = (0.0, 0.0, -2.0),
    camera_center: tuple = (0.0, 0.0, 0.5),
    color_mode: str = "rainbow",
) -> str:
    """Entry point: render a trajectory image from an npz file and its video"""
    print(f"Loading data: {npz_path}")
    data = load_vdpm_data(npz_path, video_path)

    T = data['world_points'].shape[0]
    if frame_idx == -1:
        frame_idx = T - 1

    print(f"Rendering frame {frame_idx}/{T-1}, color mode: {color_mode}")
    return render_frame_with_tracks(
        data, frame_idx, output_path,
        conf_thres=conf_thres,
        width=width,
        height=height,
        camera_eye=camera_eye,
        camera_center=camera_center,
        color_mode=color_mode,
    )


print("Trajectory rendering code loaded")

In [None]:
# Render a trajectory image for every video
RENDER_DIR = "/content/vdpm/vdpm_renders"  # absolute path
os.makedirs(RENDER_DIR, exist_ok=True)

for q in tqdm(SELECTED_QUESTIONS, desc="Rendering trajectories"):
    video_name = q['video'].split('/')[-1].replace('.mp4', '')
    npz_path = q.get('npz_path')
    
    if not npz_path or not os.path.exists(npz_path):
        print(f"Skipping {video_name}: no point cloud")
        continue
    
    render_path = os.path.join(RENDER_DIR, f"{video_name}.png")
    
    if os.path.exists(render_path):
        print(f"Already exists: {video_name}.png")
        q['render_path'] = render_path  # set the path even when the image already exists
        continue
    
    try:
        render_trajectory_image(
            npz_path=npz_path,
            video_path=q['local_video_path'],
            output_path=render_path,
            frame_idx=-1,  # last frame
            color_mode="depth",  # time-coded trails, light to dark
        )
        q['render_path'] = render_path
        print(f"✓ Rendered: {video_name}.png")
    except Exception as e:
        print(f"✗ Rendering failed for {video_name}: {e}")

# Check every trajectory image
print("\nTrajectory image status:")
for q in SELECTED_QUESTIONS:
    video_name = q['video'].split('/')[-1].replace('.mp4', '')
    has_render = 'render_path' in q and os.path.exists(q.get('render_path', ''))
    status = "✓" if has_render else "✗"
    print(f"  {status} {video_name}: {q.get('render_path', 'none')}")

print("\nTrajectory rendering finished")

<cell_type>markdown</cell_type>## 7. Prepare the prompts
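
The prompt templates below use `string.Template`, where `$name` is a placeholder and `$$` produces a literal dollar sign. That is why the templates write `$$LETTER`: the model should see the text `$LETTER`, not a substituted value. A small sketch with a shortened, made-up template:

```python
from string import Template

# Shortened stand-in for the notebook's prompt templates.
demo = Template("Question: $question\nEnd with: 'The final answer is: $$LETTER'")
text = demo.substitute(question="Which way does the ball move?")
print(text)

# substitute() raises KeyError when a placeholder has no value,
# which catches a misspelled field name before any API call is made.
try:
    demo.substitute()
except KeyError as missing:
    print(f"missing placeholder: {missing}")
```

`safe_substitute` would leave unknown placeholders in place instead of raising, which is usually the less useful behavior for prompts.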

In [None]:
# Constants
MAX_TOKENS = 2048  # raised to 2048 to avoid truncated output (1024 previously truncated Q3/Q7)
GENERATION_TEMPERATURE = 1.0
TOTAL_FRAMES = 32

# Original prompt (video frames only)
ORIGINAL_PROMPT = Template("""
Question: $question
$optionized_str

Answer the given multiple-choice question step by step. Begin by explaining your reasoning process clearly. In the last sentence of your response, you must conclude by stating the final answer using the following format: 'Therefore, the final answer is: $$LETTER' (without quotes), where $$LETTER must be only one of the options (A or B or C or D). Think step by step before answering.""")

# New prompt (video frames + trajectory image)
VDPM_PROMPT = Template("""
You are given video frames and a 3D point cloud trajectory visualization.

**About the trajectory image:**
- The trajectory image shows the motion paths of objects in 3D space
- Line colors indicate time: **light cyan = early position, dark red = later position**
- This helps you understand which direction objects are moving towards
- If trajectories are short or clustered, the object may not be moving much
- If trajectories are long and directional, the object is clearly moving in that direction

**Instructions:**
1. First, analyze the video frames to understand the scene
2. Then, use the trajectory image to determine the actual motion direction
3. The color gradient (light→dark) shows you exactly where objects moved from and to

Question: $question
$optionized_str

Answer the given multiple-choice question step by step. Use both the video frames AND the trajectory visualization to reason about motion. In the last sentence of your response, you must conclude by stating the final answer using the following format: 'Therefore, the final answer is: $$LETTER' (without quotes), where $$LETTER must be only one of the options (A or B or C or D). Think step by step before answering.""")

print("Prompt templates ready")

<cell_type>markdown</cell_type>## 8. Utility functions
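
The helpers below send every image to the API as a base64 data URL wrapped in an `image_url` content part. A minimal sketch of that encoding step in isolation; the helper names `to_data_url` and `image_part` are hypothetical, and the dictionary shape simply copies the one used by the message builders in the next cell.

```python
import base64


def to_data_url(raw: bytes, mime: str = "image/jpeg") -> str:
    """Encode raw image bytes as a data URL."""
    encoded = base64.b64encode(raw).decode("utf-8")
    return f"data:{mime};base64,{encoded}"


def image_part(raw: bytes, mime: str = "image/jpeg") -> dict:
    """Wrap image bytes in the content-part shape used by the message builders."""
    return {"type": "image_url", "image_url": {"url": to_data_url(raw, mime)}}


print(to_data_url(b"hi", "image/png"))  # data:image/png;base64,aGk=
```

The MIME type has to match the actual encoding: the video frames are JPEG, while the rendered trajectory image is PNG.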

In [None]:
def read_video_frames(video_path, total_frames):
    """Uniformly sample frames from a video and return them as base64-encoded JPEGs"""
    video = cv2.VideoCapture(video_path)
    if not video.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")
    
    try:
        base64_frames = []
        while True:
            success, frame = video.read()
            if not success:
                break
            _, buffer = cv2.imencode('.jpg', frame)
            frame_base64 = base64.b64encode(buffer).decode('utf-8')
            base64_frames.append(frame_base64)
        
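        if not base64_frames:
            raise ValueError(f"No frames decoded from video: {video_path}")

        # Uniform sampling; when only one frame is requested, draw it with a fixed seed
        if total_frames == 1:
            selected_indices = [random.Random(42).randrange(len(base64_frames))]
        else:
            selected_indices = np.linspace(0, len(base64_frames) - 1, total_frames, dtype=int)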
        
        selected_base64_frames = [base64_frames[i] for i in selected_indices]
        return selected_base64_frames
    finally:
        video.release()


def read_image_as_base64(image_path):
    """Read an image file and return it base64-encoded"""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode('utf-8')


def prepare_message_original(query, total_frames, prompt_template):
    """Build the baseline message (video frames only)"""
    # Build the text
    optionized_list = [f"{key}: {value}" for key, value in query['choices'].items()]
    optionized_str = "\n".join(optionized_list)
    qa_text = prompt_template.substitute(question=query['question'], optionized_str=optionized_str)
    
    # Load the video frames
    base64_frames = read_video_frames(query['local_video_path'], total_frames)
    
    content = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}}
        for frame in base64_frames
    ]
    content.append({"type": "text", "text": qa_text})
    
    return [{"role": "user", "content": content}]


def prepare_message_with_vdpm(query, total_frames, prompt_template):
    """Build the message that also includes the VDPM trajectory image"""
    # Build the text
    optionized_list = [f"{key}: {value}" for key, value in query['choices'].items()]
    optionized_str = "\n".join(optionized_list)
    qa_text = prompt_template.substitute(question=query['question'], optionized_str=optionized_str)
    
    # Load the video frames
    base64_frames = read_video_frames(query['local_video_path'], total_frames)
    
    content = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}}
        for frame in base64_frames
    ]
    
    # Append the trajectory image (after the video frames)
    if 'render_path' in query and os.path.exists(query['render_path']):
        trajectory_base64 = read_image_as_base64(query['render_path'])
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{trajectory_base64}"}
        })
    
    content.append({"type": "text", "text": qa_text})
    
    return [{"role": "user", "content": content}]


print("Utility functions ready")

<cell_type>markdown</cell_type>## 9. API call functions
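
The API wrapper in the next cell retries a failed call several times and returns `None` once it gives up. The sketch below isolates that control flow with a stand-in coroutine, so it can be tried without an API key. The names `call_with_retries`, `flaky`, and `demo` are hypothetical, and the fixed delay is a simplification of the randomized sleeps used in the notebook.

```python
import asyncio


async def call_with_retries(make_call, attempts: int = 3, delay_s: float = 0.0):
    """Await make_call() up to `attempts` times; return None if every attempt fails."""
    for attempt in range(1, attempts + 1):
        try:
            return await make_call()
        except Exception as exc:
            print(f"attempt {attempt} failed: {exc}")
            await asyncio.sleep(delay_s)
    return None


async def demo():
    calls = {"n": 0}

    async def flaky():
        # Stand-in for an API call: fails twice, then succeeds.
        calls["n"] += 1
        if calls["n"] < 3:
            raise RuntimeError("transient error")
        return "ok"

    result = await call_with_retries(flaky, attempts=5)
    return result, calls["n"]
```

In a script, run it with `asyncio.run(demo())`; in a notebook cell, use `await demo()`, because the kernel already runs an event loop.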

In [None]:
from pydantic import BaseModel
import aiolimiter
from tqdm.asyncio import tqdm_asyncio

# LLM evaluation output format (matches the official VLM4D evaluation)
class EvaluationOutput(BaseModel):
    extracted_answer: str
    correct: bool

# Evaluation prompt (from VLM4D utils/eval_utils.py)
EVAL_INSTRUCTION = """Your task is to evaluate whether the model's final answer is correct by comparing it to the ground-truth answer provided for the given question.

You should first extract the final answer from the model's response, and then compare the extracted answer with the choice that matches the ground-truth answer to determine its correctness.
Output your response in the following structured format:
{
    "extracted_answer": // str value "A" "B" "C" "D", followed by a colon and the corresponding answer text, e.g., "A: Answer A text". If the model's response does not contain a valid choice and reasoning, then "No Valid Answer".
    "correct": // boolean value, True if the extracted answer matches the ground-truth answer (correct choice), False otherwise ("No Valid Answer" is also considered False).
}
"""


def prepare_evaluation_message(example, response):
    """Build the evaluation messages (from VLM4D utils/eval_utils.py)"""
    optionized_list = [f"{key}: {value}" for key, value in example['choices'].items()]
    optionized_str = "\n".join(optionized_list)
    question_context = f"Question: {example['question']}\n\nOptions:\n{optionized_str}"
    gt_answer = f"Ground Truth Answer: {example['answer']}"
    model_response = f"Model Response to the Question: {response}"
    
    user_prompt = f"{question_context}\n\n{gt_answer}\n\n{model_response}"
    
    return [
        {"role": "system", "content": EVAL_INSTRUCTION},
        {"role": "user", "content": user_prompt},
    ]


async def _throttled_openai_chat_completion_acreate(
    client,
    model,
    messages,
    temperature,
    max_tokens,
    top_p,
    limiter,
    question_id="unknown",  # question ID, used for debugging
):
    """Single API call (with retries)"""
    async with limiter:
        for attempt in range(10):
            try:
                response = await client.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_completion_tokens=max_tokens,
                    top_p=top_p,
                )
                # Check for a refusal or a problematic finish_reason
                choice = response.choices[0]
                if choice.finish_reason == "content_filter":
                    print(f"⚠️ {question_id}: content was filtered (content_filter)")
                elif choice.finish_reason == "length":
                    print(f"⚠️ {question_id}: output was truncated (length)")
                elif hasattr(choice.message, 'refusal') and choice.message.refusal:
                    print(f"⚠️ {question_id}: model refused to answer: {choice.message.refusal}")
                
                return response
            except Exception as e:
                error_str = str(e).lower()
                if "rate_limit" in error_str:
                    print(f"Rate limit exceeded, retrying (attempt {attempt+1})...")
                    await asyncio.sleep(random.randint(10, 20))
                elif "bad_request" in error_str:
                    print(f"Bad request for {question_id}: {e}")
                    return None
                elif "context_length" in error_str or "too many tokens" in error_str:
                    print(f"⚠️ {question_id}: input too long, exceeds the context limit: {e}")
                    return None
                else:
                    print(f"Error for {question_id} (attempt {attempt+1}): {e}")
                    await asyncio.sleep(random.randint(5, 10))
        print(f"⚠️ {question_id}: still failing after 10 attempts")
        return None


async def generate_from_openai_chat_completion(
    client,
    messages,
    engine_name,
    temperature=1.0,
    max_tokens=512,
    top_p=1.0,
    requests_per_minute=150,
    question_ids=None,  # list of question IDs, used only in debug output
):
    """Call the OpenAI API for a batch of messages."""
    delay = 60.0 / requests_per_minute
    limiter = aiolimiter.AsyncLimiter(1, delay)
    
    if question_ids is None:
        question_ids = [f"Q{i+1}" for i in range(len(messages))]
    
    async_responses = [
        _throttled_openai_chat_completion_acreate(
            client,
            model=engine_name,
            messages=message,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
            question_id=qid,
        )
        for message, qid in zip(messages, question_ids)
    ]
    
    responses = await tqdm_asyncio.gather(*async_responses, desc="API calls")
    
    outputs = []
    for response, qid in zip(responses, question_ids):
        if response is None:
            print(f"⚠️ {qid}: API returned None")
            outputs.append("")
        else:
            try:
                content = response.choices[0].message.content
                if content:
                    outputs.append(content)
                else:
                    # Print the choice details to help debug empty answers
                    choice = response.choices[0]
                    print(f"⚠️ {qid}: response content is empty!")
                    print(f"   finish_reason: {choice.finish_reason}")
                    if hasattr(choice.message, 'refusal') and choice.message.refusal:
                        print(f"   refusal: {choice.message.refusal}")
                    outputs.append("")
            except Exception as e:
                print(f"⚠️ {qid}: failed to extract the response - {e}")
                outputs.append("")
    
    return outputs


async def _throttled_eval_call(client, model, messages, limiter):
    """Make a single evaluation call (o4-mini requires max_completion_tokens)."""
    async with limiter:
        for _ in range(10):
            try:
                response = await client.beta.chat.completions.parse(
                    model=model,
                    messages=messages,
                    temperature=1.0,
                    max_completion_tokens=1024,  # o4-mini requires max_completion_tokens
                    top_p=1.0,
                    response_format=EvaluationOutput,
                )
                return response.choices[0].message.parsed
            except Exception as e:
                if "rate_limit" in str(e).lower():
                    await asyncio.sleep(random.randint(10, 20))
                else:
                    print(f"Eval error: {e}")
                    await asyncio.sleep(random.randint(5, 10))
        return None


async def get_acc_async(examples, client, eval_model="o4-mini"):
    """Evaluate all responses (from VLM4D utils/eval_utils.py)."""
    evaluation_messages = [
        prepare_evaluation_message(example, example['response'])
        for example in examples
    ]
    
    # Evaluate in batch
    delay = 60.0 / 1000  # 1000 requests per minute for eval
    limiter = aiolimiter.AsyncLimiter(1, delay)
    
    tasks = [_throttled_eval_call(client, eval_model, msg, limiter) for msg in evaluation_messages]
    outputs = await tqdm_asyncio.gather(*tasks, desc="Evaluating")
    
    # Tally the results
    count = 0
    results = []
    for example, output in zip(examples, outputs):
        result = {
            "id": example["id"],
            "question": example["question"],
            "choices": example["choices"],
            "response": example["response"],
            "ground_truth_answer": example["answer"],
        }
        try:
            result["extracted_answer"] = output.extracted_answer
            result["correct"] = output.correct
        except Exception as e:
            result["extracted_answer"] = ""
            result["correct"] = False
            print(f"Error: {e}")
        
        results.append(result)
        count += result["correct"]
    
    return count / len(examples) if examples else 0, results


print("API functions and evaluator loaded (consistent with the official VLM4D code)")

<cell_type>markdown</cell_type>## 10. Run the VDPM test
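
The run cell in this section scales its request rate down with the square root of `TOTAL_FRAMES` and hands that rate to the limiter as a per-request delay. The arithmetic can be sketched on its own as follows. The helper name `throttle_settings` and the lower bound of 1 request per minute are assumptions added for illustration and are not part of the notebook.

```python
import math


def throttle_settings(total_frames: int, base_rate: int = 100) -> tuple[int, float]:
    """Return (requests_per_minute, seconds_between_requests).

    Mirrors the notebook's formula int(base_rate / sqrt(total_frames)).
    The max(1, ...) floor is an added assumption so that a very large
    frame count cannot produce a rate of zero.
    """
    if total_frames <= 0:
        raise ValueError("total_frames must be positive")
    requests_per_minute = max(1, int(base_rate / math.sqrt(total_frames)))
    delay = 60.0 / requests_per_minute  # seconds allotted to each request
    return requests_per_minute, delay


rpm, delay = throttle_settings(16)
print(f"{rpm} requests/min, one request every {delay:.2f} s")
```

With 16 frames the rate is a quarter of the base rate, so requests that carry more images are sent less often.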

In [None]:
MODEL_NAME = "gpt-5-mini"
client = AsyncOpenAI(api_key=OPENAI_API_KEY)

print("Starting test...")
print(f"Model: {MODEL_NAME}")
print(f"Frames: {TOTAL_FRAMES}")
print(f"Questions: {len(SELECTED_QUESTIONS)}")

In [None]:
# Step 1: get the model's responses
print("="*50)
print("Step 1: Get GPT-5 mini responses (video frames + VDPM trajectory render)")
print("="*50)

# Fill in any missing paths (handles re-running this cell)
RENDER_DIR = "/content/vdpm/vdpm_renders"
VIDEO_DIR = Path("/content/vlm4d_videos")
POINTCLOUD_DIR = Path("/content/vlm4d_pointclouds")

for q in SELECTED_QUESTIONS:
    video_name = q['video'].split('/')[-1].replace('.mp4', '')
    
    # Fill in local_video_path
    if 'local_video_path' not in q:
        q['local_video_path'] = str(VIDEO_DIR / f"{video_name}.mp4")
    
    # Fill in npz_path
    if 'npz_path' not in q:
        npz_path = POINTCLOUD_DIR / video_name / "sequence.npz"
        if npz_path.exists():
            q['npz_path'] = str(npz_path)
    
    # Fill in render_path
    if 'render_path' not in q:
        render_path = os.path.join(RENDER_DIR, f"{video_name}.png")
        if os.path.exists(render_path):
            q['render_path'] = render_path

# Prepare all messages
all_messages = []
valid_indices = []  # indices of questions that have a trajectory render
question_ids = []  # question IDs, used only in debug output

for i, q in enumerate(SELECTED_QUESTIONS):
    video_name = q['video'].split('/')[-1].replace('.mp4', '')
    
    # Skip questions that have no trajectory render
    if 'render_path' not in q or not os.path.exists(q.get('render_path', '')):
        print(f"Skipping {q['id']}: no trajectory render (render_path={q.get('render_path', 'none')})")
        continue
    
    message = prepare_message_with_vdpm(q, TOTAL_FRAMES, VDPM_PROMPT)
    all_messages.append(message)
    valid_indices.append(i)
    question_ids.append(f"{q['id']} ({video_name})")
    
    # Add the question_type field (needed by the evaluator)
    q['question_type'] = 'multiple-choice'
    
    # Print the render's file size
    render_size = os.path.getsize(q['render_path']) / 1024  # KB
    print(f"  {q['id']}: trajectory render size {render_size:.1f} KB")

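print(f"\nPrepared {len(all_messages)} valid questions")

if len(all_messages) == 0:
    print("\n⚠️  No valid questions! Please check:")
    print("1. Whether the point-cloud generation cell was run")
    print("2. Whether the trajectory rendering cell was run")
    print("3. Whether the trajectory renders exist in", RENDER_DIR)
else:
    # Call the API in batch; the rate shrinks with sqrt(TOTAL_FRAMES)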
    base_rate = 100
    requests_per_minute = int(base_rate / (TOTAL_FRAMES ** 0.5))
    print(f"Request rate: {requests_per_minute}/min")

    responses = await generate_from_openai_chat_completion(
        client=client,
        messages=all_messages,
        engine_name=MODEL_NAME,
        temperature=GENERATION_TEMPERATURE,
        max_tokens=MAX_TOKENS,
        top_p=1.0,
        requests_per_minute=requests_per_minute,
        question_ids=question_ids,  # pass question IDs for debug output
    )

    # Attach each response to its question and report which ones failed
    print("\nResponse status:")
    for idx, response, qid in zip(valid_indices, responses, question_ids):
        SELECTED_QUESTIONS[idx]['response'] = response
        SELECTED_QUESTIONS[idx]['vdpm_response'] = response
        
        if response:
            print(f"  ✓ {qid}: response received ({len(response)} characters)")
        else:
            print(f"  ✗ {qid}: response is empty!")

    success_count = sum(1 for r in responses if r)
    print(f"\nReceived {success_count}/{len(responses)} valid responses")

In [None]:
# Show the model's full responses
print("="*50)
print("Full model responses")
print("="*50)

for i, q in enumerate(SELECTED_QUESTIONS, 1):
    video_name = q['video'].split('/')[-1].replace('.mp4', '')
    response = q.get('response', '')
    
    print(f"\n--- [{q['id']}] {video_name} ---")
    print(f"Question: {q['question']}")
    print(f"Choices: {q['choices']}")
    print(f"Correct answer: {q['answer']}")
    print(f"Previous wrong answer: {q['model_wrong_answer']}")
    print("\nModel response:")
    if response:
        # Show the full response, truncating anything past 800 characters
        if len(response) > 800:
            print(f"{response[:800]}...")
        else:
            print(response)
    else:
        print("(no response - skipped)")
    print()

In [None]:
# Step 2: evaluate with an LLM
print("="*50)
print("Step 2: LLM evaluation with o4-mini")
print("="*50)

# Only evaluate questions that received a response
examples_to_eval = [q for q in SELECTED_QUESTIONS if q.get('response')]

print(f"Evaluating {len(examples_to_eval)} responses...")

vdpm_accuracy, eval_results = await get_acc_async(examples_to_eval, client, eval_model="o4-mini")

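print("\n" + "=" * 50)
# Count correct answers directly; int(accuracy * n) can round down on float error
num_correct = sum(bool(r["correct"]) for r in eval_results)
print(f"VDPM accuracy: {vdpm_accuracy:.0%} ({num_correct}/{len(eval_results)})")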
print("=" * 50)

<cell_type>markdown</cell_type>## 11. Results summary
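
The summary cell in this section sorts every selected question into one of three groups: correct, wrong, or skipped (no evaluation result, for example because the question had no trajectory render or no response). A minimal sketch of that bookkeeping on toy data is shown here. The function name `partition_by_result` and the sample IDs are hypothetical and used only for illustration.

```python
def partition_by_result(questions, eval_results):
    """Split question IDs into (correct, wrong, skipped) lists.

    A question counts as skipped when no evaluation result carries its ID.
    """
    eval_map = {r["id"]: r for r in eval_results}
    correct, wrong, skipped = [], [], []
    for q in questions:
        result = eval_map.get(q["id"])
        if result is None:
            skipped.append(q["id"])
        elif result["correct"]:
            correct.append(q["id"])
        else:
            wrong.append(q["id"])
    return correct, wrong, skipped


# Toy data: three selected questions, only two of which were evaluated.
questions = [{"id": "q1"}, {"id": "q2"}, {"id": "q3"}]
eval_results = [{"id": "q1", "correct": True}, {"id": "q2", "correct": False}]

correct, wrong, skipped = partition_by_result(questions, eval_results)
accuracy = len(correct) / len(eval_results) if eval_results else 0.0
print(correct, wrong, skipped, f"{accuracy:.0%}")
```

Note that the accuracy is computed over evaluated questions only, so skipped questions do not lower it.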

In [None]:
# Detailed evaluation results
print("="*60)
print("Detailed evaluation results (LLM evaluation)")
print("="*60)

# Map each question ID to its evaluation result
eval_map = {r['id']: r for r in eval_results}

correct_ids = []
wrong_ids = []
skipped_ids = []

for q in SELECTED_QUESTIONS:
    qid = q['id']
    if qid in eval_map:
        r = eval_map[qid]
        if r['correct']:
            correct_ids.append(qid)
        else:
            wrong_ids.append(qid)
    else:
        skipped_ids.append(qid)

print(f"\nModel: {MODEL_NAME}")
print("Evaluation model: o4-mini")
print(f"Frames: {TOTAL_FRAMES}")
print(f"Questions tested: {len(SELECTED_QUESTIONS)}")
print(f"\nVDPM accuracy: {vdpm_accuracy:.0%} ({len(correct_ids)}/{len(eval_results)})")
print("(GPT-5 mini previously answered all of these questions incorrectly)")

print(f"\n✓ Correct ({len(correct_ids)}): {', '.join(correct_ids) if correct_ids else 'none'}")
print(f"✗ Incorrect ({len(wrong_ids)}): {', '.join(wrong_ids) if wrong_ids else 'none'}")
if skipped_ids:
    print(f"⊘ Skipped ({len(skipped_ids)}): {', '.join(skipped_ids)}")

print("\n" + "="*60)
print("Per-question details")
print("="*60)

for i, r in enumerate(eval_results, 1):
    status = "✓" if r['correct'] else "✗"
    print(f"\n{status} [{r['id']}]")
    print(f"   Question: {r['question'][:60]}...")
    print(f"   Ground-truth answer: {r['ground_truth_answer']}")
    print(f"   LLM-extracted answer: {r['extracted_answer']}")
    
    # Look up the original question to show the earlier wrong answer
    orig_q = next((q for q in SELECTED_QUESTIONS if q['id'] == r['id']), None)
    if orig_q:
        print(f"   Previous wrong answer: {orig_q['model_wrong_answer']}")
    
    if r['correct']:
        print("   ★ VDPM helped correct this question!")

<cell_type>markdown</cell_type>## 12. Save results
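
The save cell in this section writes the results with `ensure_ascii=False` and `default=str`. The sketch below shows what those two `json.dump` options do, using a temporary directory and made-up sample values rather than the notebook's real results: non-ASCII text is written as-is instead of as `\u` escapes, and values that JSON cannot encode natively (here a path object) are converted with `str`.

```python
import json
import tempfile
from pathlib import Path, PurePosixPath

# Made-up sample payload, for illustration only.
sample = {
    "model": "gpt-5-mini",
    "status": "✓ 轨迹图",  # non-ASCII text, kept readable by ensure_ascii=False
    "render_dir": PurePosixPath("/content/vdpm/vdpm_renders"),  # not JSON-native
}

with tempfile.TemporaryDirectory() as tmp:
    out_path = Path(tmp) / "results.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(sample, f, indent=2, ensure_ascii=False, default=str)

    raw = out_path.read_text(encoding="utf-8")
    loaded = json.loads(raw)

print(raw)
```

After the round trip, `loaded["render_dir"]` is a plain string, so any code that reads the file back has to rebuild richer types itself.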

In [None]:
# Save the results
results = {
    "model": MODEL_NAME,
    "eval_model": "o4-mini",
    "total_frames": TOTAL_FRAMES,
    "vdpm_accuracy": vdpm_accuracy,
    "num_correct": len(correct_ids),
    "num_total": len(eval_results),
    "correct_ids": correct_ids,
    "wrong_ids": wrong_ids,
    "skipped_ids": skipped_ids,
    "eval_results": eval_results,  # full evaluation results
}

output_file = "/content/vdpm_test_results.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False, default=str)

print(f"Results saved to: {output_file}")

# Download the results
from google.colab import files
files.download(output_file)