# SNOW VLM4D Evaluation on Google Colab

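This notebook runs the SNOW pipeline on the VLM4D benchmark. In this Colab version, the MapAnything stage of step 1 is replaced by a simulated point-cloud generator (see Section 3):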
1. Video frames → MapAnything → 3D point clouds
2. Point clouds → HDBSCAN clustering → Object segmentation
3. Objects → Cross-frame tracking → Temporal tracks
4. Tracks → 4D Scene Graph → Text serialization
5. Text 4DSG → **Gemma3-4B-IT** → Answer

**Requirements:** GPU runtime (T4 recommended)

In [None]:
# Check GPU availability: list the NVIDIA GPU(s) attached to this runtime.
!nvidia-smi

## 1. Setup Environment

In [None]:
# Install dependencies (IPython shell magics; -q keeps pip output quiet).
# NOTE(review): `tqdm` is imported later but not installed here; it is
# presumably preinstalled on Colab — confirm if running elsewhere.
!pip install -q torch torchvision
!pip install -q hdbscan scipy numpy opencv-python pillow
!pip install -q google-genai
!pip install -q huggingface_hub transformers

In [None]:
# Set Google AI API Key
import os
from google.colab import userdata

# Option 2 fallback: replace the string below with your key if you do not
# use Colab secrets.
MANUAL_API_KEY = 'YOUR_API_KEY_HERE'  # Replace with your key

# Option 1: Use Colab secrets (recommended)
try:
    api_key = userdata.get('GOOGLE_AI_API_KEY')
except Exception as err:
    # e.g. the secret does not exist or notebook access was not granted.
    # Catch Exception (not a bare except) so KeyboardInterrupt still works.
    print(f"Could not read Colab secret 'GOOGLE_AI_API_KEY': {err}")
    api_key = None

# Option 2: Set manually
if not api_key:
    api_key = MANUAL_API_KEY

os.environ['GOOGLE_AI_API_KEY'] = api_key
# The untouched placeholder is not a usable key, so report it as not set.
key_is_set = bool(api_key) and api_key != 'YOUR_API_KEY_HERE'
print("API Key set:", "Yes" if key_is_set else "No")

## 2. Define SNOW Core Components

Because the SNOW repository is private and cannot be cloned here, the essential SNOW components are defined inline below.

In [None]:
import torch
import cv2
import numpy as np
from pathlib import Path
from PIL import Image
import tempfile
import urllib.request
from tqdm import tqdm
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from enum import Enum
import json

# Report whether PyTorch can see a CUDA device. Informational only: no later
# cell in this notebook moves data to the GPU.
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# ==== Token Definitions ====

@dataclass(frozen=True)
class PatchToken:
    """One grid cell of an object mask, as emitted by ``mask_to_patch_tokens``."""
    row: int  # grid row index (0 = first/top rows of the mask array)
    col: int  # grid column index (0 = first/left columns of the mask array)
    iou: float  # fraction of the cell's pixels that are set in the mask

@dataclass(frozen=True)
class CentroidToken:
    """Mean 3-D position of an object's points (see ``build_centroid_token``)."""
    x: float
    y: float
    z: float

@dataclass(frozen=True)
class ShapeToken:
    """Per-axis statistics of an object's points (see ``build_shape_token``).

    For each axis: ``*_mu`` mean, ``*_sigma`` (population) standard
    deviation, ``*_min`` / ``*_max`` extremes.
    """
    # x axis
    x_mu: float
    x_sigma: float
    x_min: float
    x_max: float
    # y axis
    y_mu: float
    y_sigma: float
    y_min: float
    y_max: float
    # z axis
    z_mu: float
    z_sigma: float
    z_min: float
    z_max: float

@dataclass(frozen=True)
class TemporalToken:
    """Frame span of an observation.

    NOTE(review): the pipeline builds these as (frame_idx, frame_idx + 1),
    which suggests a half-open [t_start, t_end) interval — confirm.
    """
    t_start: int
    t_end: int

@dataclass(frozen=True)
class STEPToken:
    """Descriptor of one object observation: mask patches, 3-D geometry and frame span.

    NOTE(review): the dataclass is frozen, but ``patch_tokens`` is a list, so
    ``hash()`` on an instance raises TypeError (equality still works).
    """
    patch_tokens: List[PatchToken]  # mask cells above the coverage threshold
    centroid: CentroidToken  # mean 3-D position of the object's points
    shape: ShapeToken  # per-axis point statistics
    temporal: TemporalToken  # frame span of this observation

def mask_to_patch_tokens(mask, grid_size=16, iou_threshold=0.5):
    """Convert a 2-D mask into patch tokens on a ``grid_size`` x ``grid_size`` grid.

    Each cell spans ``H // grid_size`` rows and ``W // grid_size`` columns.
    Remainder rows/columns at the bottom/right edge belong to no cell. A
    cell's score is the fraction of its pixels set in ``mask``, stored in the
    token's ``iou`` field. Cells scoring strictly above ``iou_threshold`` are
    returned in row-major order.
    """
    height, width = mask.shape
    cell_h = height // grid_size
    cell_w = width // grid_size

    def _coverage(r, c):
        cell = mask[r * cell_h:(r + 1) * cell_h, c * cell_w:(c + 1) * cell_w]
        # A zero-sized cell (mask smaller than the grid) scores 0.
        return cell.sum() / cell.size if cell.size > 0 else 0

    selected = []
    for r in range(grid_size):
        for c in range(grid_size):
            score = _coverage(r, c)
            if score > iou_threshold:
                selected.append(PatchToken(row=r, col=c, iou=float(score)))
    return selected

def build_centroid_token(points_xyz):
    """Return the per-axis mean of ``points_xyz`` as a CentroidToken.

    Assumes an (N, >=3) array; only the first three columns are used.
    An empty array yields the origin.
    """
    if not len(points_xyz):
        return CentroidToken(0, 0, 0)
    centre = points_xyz.mean(axis=0)
    return CentroidToken(x=float(centre[0]), y=float(centre[1]), z=float(centre[2]))

def build_shape_token(points_xyz):
    """Summarise mean, std, min and max of each of the x, y, z columns.

    ``points_xyz`` is indexed as an (N, >=3) array. An empty array yields a
    ShapeToken with every statistic equal to 0.
    """
    if len(points_xyz) == 0:
        return ShapeToken(*([0] * 12))
    stats = {}
    for axis_idx, axis in enumerate("xyz"):
        column = points_xyz[:, axis_idx]
        stats[f"{axis}_mu"] = float(column.mean())
        stats[f"{axis}_sigma"] = float(column.std())
        stats[f"{axis}_min"] = float(column.min())
        stats[f"{axis}_max"] = float(column.max())
    return ShapeToken(**stats)

def build_step_token(mask, points_xyz, t_start, t_end, grid_size=16, iou_threshold=0.5):
    """Assemble a STEPToken from a 2-D mask, the object's 3-D points and a frame span.

    ``grid_size`` and ``iou_threshold`` are forwarded to
    ``mask_to_patch_tokens``; ``points_xyz`` feeds both the centroid and the
    shape statistics.
    """
    patches = mask_to_patch_tokens(mask, grid_size, iou_threshold)
    centroid = build_centroid_token(points_xyz)
    shape = build_shape_token(points_xyz)
    span = TemporalToken(t_start, t_end)
    return STEPToken(patch_tokens=patches, centroid=centroid, shape=shape, temporal=span)


print("Token components defined!")

In [None]:
# ==== HDBSCAN Clustering ====
import hdbscan

@dataclass(frozen=True)
class HDBSCANConfig:
    """Parameters forwarded to ``hdbscan.HDBSCAN`` by ``cluster_points``."""
    min_cluster_size: int = 30
    min_samples: int = 5

@dataclass(frozen=True)
class ClusterResult:
    """Output of ``cluster_points``."""
    labels: np.ndarray  # HDBSCAN label per input point (-1 = noise)
    clusters: List[np.ndarray]  # point-index arrays, one per non-noise label, ascending label order

def cluster_points(points_xyz, config):
    """Cluster points with HDBSCAN and group their indices by cluster label.

    Args:
        points_xyz: point array passed straight to ``fit_predict``.
        config: an ``HDBSCANConfig`` supplying ``min_cluster_size`` and
            ``min_samples``.

    Returns:
        ClusterResult holding the raw labels and one index array per cluster.
        Noise (label -1) is excluded, and clusters are ordered by ascending
        label.
    """
    labels = hdbscan.HDBSCAN(
        min_cluster_size=config.min_cluster_size,
        min_samples=config.min_samples,
    ).fit_predict(points_xyz)
    clusters = [
        np.where(labels == lab)[0]
        for lab in sorted(set(labels))
        if lab != -1
    ]
    return ClusterResult(labels=labels, clusters=clusters)


print("HDBSCAN clustering defined!")

In [None]:
# ==== H-hop Filter ====

@dataclass(frozen=True)
class HHopConfig:
    max_extent: float = 50.0
    max_sigma: float = 10.0
    max_aspect_ratio: float = 20.0

def filter_implausible(steps, config):
    """Filter out geometrically implausible STEP tokens.

    Using each token's ``shape`` statistics, a token is dropped when any of
    the following holds:
      * its largest axis extent exceeds ``config.max_extent``;
      * largest extent / smallest extent exceeds ``config.max_aspect_ratio``.
        Only extents > 0.01 count as "smallest"; if none qualifies, the
        divisor is 1;
      * any per-axis sigma exceeds ``config.max_sigma``.

    Returns a new dict with the surviving entries, in input order.
    """
    def _keep(shape):
        spans = [shape.x_max - shape.x_min, shape.y_max - shape.y_min, shape.z_max - shape.z_min]
        widest = max(spans)
        usable = [e for e in spans if e > 0.01]
        narrowest = min(usable) if usable else 1
        ratio = widest / narrowest if narrowest > 0 else 1
        if widest > config.max_extent:
            return False
        if ratio > config.max_aspect_ratio:
            return False
        if max(shape.x_sigma, shape.y_sigma, shape.z_sigma) > config.max_sigma:
            return False
        return True

    return {key: step for key, step in steps.items() if _keep(step.shape)}


print("H-hop filter defined!")

In [None]:
# ==== Object Tracker ====
from scipy.optimize import linear_sum_assignment

@dataclass(frozen=True)
class TrackerConfig:
    """Parameters for :class:`ObjectTracker`.

    Only ``max_centroid_distance`` and ``max_age`` are read by the tracker
    defined here; the remaining fields are currently unused.
    """
    geometric_weight: float = 0.5
    semantic_weight: float = 0.5
    # Association gate: a detection may only join a track whose last centroid
    # is strictly closer than this (same units as the point cloud).
    max_centroid_distance: float = 5.0
    max_association_cost: float = 2.0
    # Consecutive unmatched frames a track survives; beyond this it is dropped
    # and can no longer be matched.
    max_age: int = 5


@dataclass
class Track:
    """One tracked object: its STEP tokens in arrival order plus a miss counter."""
    track_id: int
    steps: List["STEPToken"] = field(default_factory=list)
    age: int = 0  # consecutive frames without a matched detection

    def update(self, step):
        """Append ``step`` to the track and reset its miss counter."""
        self.steps.append(step)
        self.age = 0


class ObjectTracker:
    """Frame-to-frame tracker that links STEP tokens by 3-D centroid distance.

    Each :meth:`update` solves a minimum-cost assignment between the current
    detections and the existing tracks, using the Euclidean distance to each
    track's last centroid. Pairs at or beyond
    ``config.max_centroid_distance`` are gated out *before* the assignment is
    solved, so an infeasible pair can never displace a feasible match.
    Unmatched detections start new tracks. Tracks unmatched for more than
    ``config.max_age`` consecutive frames are removed.
    """

    # Cost for gated-out pairs. It must exceed any sum of feasible distances so
    # the solver maximises the number of feasible matches first.
    _GATED_COST = 1e6

    def __init__(self, config=None):
        self.config = config or TrackerConfig()
        self.tracks = {}
        self.next_id = 0

    def _start_track(self, step):
        """Create a new track seeded with ``step`` and register it under a fresh id."""
        track = Track(track_id=self.next_id)
        track.update(step)
        self.tracks[self.next_id] = track
        self.next_id += 1

    def _age_and_prune(self, stale_tracks):
        """Increment the miss counter of ``stale_tracks``; drop those past ``max_age``."""
        for track in stale_tracks:
            track.age += 1
            if track.age > self.config.max_age:
                self.tracks.pop(track.track_id, None)

    def update(self, detections, frame_idx):
        """Associate this frame's detections with existing tracks.

        Args:
            detections: mapping of detection key -> STEP token. Only the
                values are used; each must expose ``centroid.x/y/z``.
            frame_idx: index of the current frame (currently unused).
        """
        det_list = list(detections.values())
        if not self.tracks:
            for step in det_list:
                self._start_track(step)
            return

        track_list = list(self.tracks.values())
        if not det_list:
            self._age_and_prune(track_list)
            return

        # Distance from every detection to every track's last centroid;
        # tracks without any step keep +inf and are never feasible.
        dist = np.full((len(det_list), len(track_list)), np.inf)
        for j, track in enumerate(track_list):
            if not track.steps:
                continue
            prev = track.steps[-1].centroid
            for i, det in enumerate(det_list):
                cur = det.centroid
                dist[i, j] = np.sqrt(
                    (cur.x - prev.x) ** 2 + (cur.y - prev.y) ** 2 + (cur.z - prev.z) ** 2
                )

        feasible = dist < self.config.max_centroid_distance
        cost = np.where(feasible, dist, self._GATED_COST)
        row_ind, col_ind = linear_sum_assignment(cost)

        matched_dets = set()
        matched_tracks = set()
        for r, c in zip(row_ind, col_ind):
            if feasible[r, c]:
                track_list[c].update(det_list[r])
                matched_dets.add(r)
                matched_tracks.add(c)

        # New tracks for unmatched detections
        for i, det in enumerate(det_list):
            if i not in matched_dets:
                self._start_track(det)

        # Age (and possibly retire) tracks that received no detection
        self._age_and_prune(
            [t for j, t in enumerate(track_list) if j not in matched_tracks]
        )

    def get_tracks(self):
        """Return tracks whose miss counter does not exceed ``config.max_age``."""
        return {tid: t for tid, t in self.tracks.items() if t.age <= self.config.max_age}


print("Object tracker defined!")

In [None]:
# ==== Scene Graph & 4DSG ====

class HorizontalRelation(Enum):
    """Horizontal relation labels.

    NOTE(review): not referenced by ``compute_spatial_relation``, which builds
    plain strings with the same words.
    """
    LEFT = "left"
    RIGHT = "right"
    FRONT = "front"
    BACK = "back"

class VerticalRelation(Enum):
    """Vertical relation labels.

    NOTE(review): not referenced by ``compute_spatial_relation``, which builds
    plain strings with the same words.
    """
    ABOVE = "above"
    BELOW = "below"
    LEVEL = "level"

@dataclass(frozen=True)
class SceneNode:
    """A node of a single-frame scene graph.

    NOTE(review): ``position`` is an ndarray, so comparing two nodes with
    distinct position arrays likely raises ValueError and hashing raises
    TypeError — avoid using nodes in sets or ``==`` checks.
    """
    node_id: int
    step: STEPToken
    position: np.ndarray  # [centroid.x, centroid.y, centroid.z] of ``step``

@dataclass(frozen=True)
class SceneEdge:
    """A pairwise spatial relation between two nodes of one frame."""
    src: int
    dst: int
    distance: float  # Euclidean centroid distance
    relation: str  # from compute_spatial_relation on (dst - src): where dst lies relative to src

@dataclass(frozen=True)
class SceneGraph:
    """Single-frame scene graph: nodes keyed by node id plus pairwise edges."""
    nodes: Dict[int, SceneNode]
    edges: List[SceneEdge]
    frame_idx: int

@dataclass(frozen=True)
class TemporalTrack:
    """Immutable snapshot of one tracker track: its id and STEP tokens in order."""
    track_id: int
    steps: List[STEPToken]

@dataclass(frozen=True)
class TemporalWindow:
    """All tracks of a video clip, keyed by track id."""
    tracks: Dict[int, TemporalTrack]

@dataclass(frozen=True)
class FourDSceneGraph:
    """4D scene graph: per-frame spatial graphs plus cross-frame object tracks."""
    spatial_graphs: List[SceneGraph]  # one per processed frame, in frame order
    temporal_window: TemporalWindow
    # frame index -> pose list; serialize_4dsg prints the first three values
    # as the position (the pipeline stores 7 values per pose).
    ego_poses: Dict[int, List[float]]

def compute_spatial_relation(dx, dy, dz, dist):
    """Describe an offset (dx, dy, dz) as ``"<horizontal>_<vertical>_<range>"``.

    * horizontal: the dominant of the x / y components. When
      ``|dx| > |dy|`` it is ``right``/``left`` by the sign of ``dx``;
      otherwise it is ``front``/``back`` by the sign of ``dy``.
    * vertical: ``above`` if ``dz > 0.3``, ``below`` if ``dz < -0.3``,
      else ``level``.
    * range: ``near`` (dist < 3), ``medium`` (dist < 8), else ``far``.

    NOTE(review): this assumes +x = right, +y = front, +z = up. The simulated
    point cloud uses a pinhole camera frame (x right, y down the image,
    z = depth) — confirm which convention the relations should use.
    """
    # Parenthesised explicitly: the previous chained ternary returned "right"
    # for any dx > 0, even when the y component dominated.
    if abs(dx) > abs(dy):
        horizontal = "right" if dx > 0 else "left"
    else:
        horizontal = "front" if dy > 0 else "back"

    if dz > 0.3:
        vertical = "above"
    elif dz < -0.3:
        vertical = "below"
    else:
        vertical = "level"

    if dist < 3:
        proximity = "near"
    elif dist < 8:
        proximity = "medium"
    else:
        proximity = "far"

    return f"{horizontal}_{vertical}_{proximity}"

def build_scene_graph(steps, frame_idx=0, max_edge_distance=50.0):
    """Build a single-frame scene graph from STEP tokens.

    Args:
        steps: mapping of node id -> STEPToken. The key becomes the node id.
        frame_idx: frame index stored on the returned graph.
        max_edge_distance: only pairs whose centroid distance is strictly
            below this value get an edge (default 50, as before).

    Returns:
        SceneGraph with one node per step and one edge per close pair
        ``(src, dst)`` where ``src < dst``. ``edge.relation`` is computed from
        the offset ``dst - src``, so it describes where ``dst`` lies relative
        to ``src``.
    """
    nodes = {
        nid: SceneNode(
            node_id=nid,
            step=step,
            position=np.array([step.centroid.x, step.centroid.y, step.centroid.z]),
        )
        for nid, step in steps.items()
    }

    edges = []
    ids = sorted(nodes.keys())
    for i, src_id in enumerate(ids):
        src_pos = nodes[src_id].position
        for dst_id in ids[i + 1:]:
            delta = nodes[dst_id].position - src_pos
            dist = float(np.linalg.norm(delta))
            if dist < max_edge_distance:
                rel = compute_spatial_relation(delta[0], delta[1], delta[2], dist)
                edges.append(SceneEdge(src=src_id, dst=dst_id, distance=dist, relation=rel))

    return SceneGraph(nodes=nodes, edges=edges, frame_idx=frame_idx)


print("Scene graph components defined!")

In [None]:
# ==== 4DSG Serialization ====

def serialize_4dsg(four_dsg):
    """Serialize a 4D scene graph to plain text for the VLM prompt.

    Sections:
      * ego trajectory: one line per frame of ``ego_poses``, printing the
        first three pose values as the position;
      * objects: position, size and frame span of the last five observations
        of every track;
      * spatial relations: up to 20 edges of the last spatial graph.

    Edge relations are computed from the offset ``dst - src`` (see
    ``build_scene_graph``), so they describe where ``dst`` lies relative to
    ``src``. Each sentence therefore uses ``dst`` as the subject.
    """
    lines = ["=== 4D Scene Graph ===", ""]

    # Ego poses
    lines.append("Ego Agent Trajectory:")
    for t, pose in sorted(four_dsg.ego_poses.items()):
        lines.append(f"  Frame {t}: position=({pose[0]:.2f}, {pose[1]:.2f}, {pose[2]:.2f})")
    lines.append("")

    # Object tracks
    tracks = four_dsg.temporal_window.tracks
    lines.append(f"Objects ({len(tracks)} tracked):")
    for tid, track in tracks.items():
        lines.append(f"\n  Object {tid}:")
        for step in track.steps[-5:]:  # Last 5 observations keep the prompt short
            c = step.centroid
            s = step.shape
            size_x = s.x_max - s.x_min
            size_y = s.y_max - s.y_min
            size_z = s.z_max - s.z_min
            lines.append(f"    Position: ({c.x:.2f}, {c.y:.2f}, {c.z:.2f})")
            lines.append(f"    Size: {size_x:.2f}m x {size_y:.2f}m x {size_z:.2f}m")
            lines.append(f"    Visible: frames {step.temporal.t_start}-{step.temporal.t_end}")
    lines.append("")

    # Spatial relations (from last frame)
    if four_dsg.spatial_graphs:
        last_sg = four_dsg.spatial_graphs[-1]
        lines.append(f"Spatial Relations (Frame {last_sg.frame_idx}):")
        for edge in last_sg.edges[:20]:
            # relation describes dst relative to src, so dst is the subject.
            lines.append(
                f"  Object {edge.dst} is {edge.relation} of Object {edge.src} ({edge.distance:.1f}m)"
            )

    return "\n".join(lines)


print("Serialization defined!")

In [None]:
# ==== VLM Client ====
from google import genai
from google.genai import types

# Create the Gemini API client from the key set earlier. The same global
# `client` is used later by run_snow_pipeline.
client = genai.Client(api_key=os.environ['GOOGLE_AI_API_KEY'])

# Test connection: one short request to the Gemma model used for evaluation,
# so configuration problems (key, model name) show up before the eval loop.
response = client.models.generate_content(
    model='gemma-3-4b-it',
    contents='What is 2+2? Answer with just the number.',
)
print(f"Gemma test response: {response.text}")

## 3. Point Cloud Generation (Simulated)

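MapAnything requires additional setup, so this section simulates point clouds instead. Random pixels from each frame are back-projected with a simple pinhole camera model and a synthetic depth that depends only on the pixel's image row. No learned depth estimation is performed, and pixel colours are not used.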

In [None]:
def download_video(url, cache_dir='video_cache'):
    """Download video from URL."""
    os.makedirs(cache_dir, exist_ok=True)
    filename = url.split('/')[-1]
    local_path = os.path.join(cache_dir, filename)
    if not os.path.exists(local_path):
        urllib.request.urlretrieve(url, local_path)
    return local_path

def extract_frames(video_path, num_frames=5):
    """Extract up to ``num_frames`` evenly spaced frames from a video.

    Frame indices are spread uniformly over ``[0, frame_count - 1]``. Frames
    that fail to decode are skipped, so fewer than ``num_frames`` may be
    returned. An empty list is returned when the container reports no frame
    count or ``num_frames`` is not positive. The capture is always released.

    Raises:
        IOError: if OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            # Fail loudly instead of silently answering from an empty scene.
            raise IOError(f"Could not open video: {video_path}")
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total <= 0 or num_frames <= 0:
            return []
        indices = np.linspace(0, total - 1, min(num_frames, total), dtype=int).tolist()

        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frames.append(frame)
        return frames
    finally:
        cap.release()

def simulate_point_cloud(frame, num_points=5000, seed=None):
    """Generate a placeholder 3-D point cloud for ``frame``.

    Stand-in for MapAnything. Only the frame's height and width are used;
    pixel colours are ignored. ``num_points`` random pixels are
    back-projected through a simplified pinhole camera (fx = W, fy = H,
    principal point at the image centre). The synthetic depth falls
    linearly from 25 at the top row towards 5 at the bottom row, so lower
    pixels are closer, as in a ground-plane view.

    Args:
        frame: image array whose first two dimensions are (H, W).
        num_points: number of points to sample.
        seed: optional seed for a private ``np.random.RandomState``. With
            ``None``, the global NumPy RNG is used as before.

    Returns:
        float32 array of shape (num_points, 3) with columns X, Y, Z in the
        camera frame (x right, y down the image, z = depth).
    """
    H, W = frame.shape[:2]
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Sample random pixel locations
    ys = rng.randint(0, H, num_points)
    xs = rng.randint(0, W, num_points)

    # Pseudo-depth from image row: objects at the bottom are closer
    # (top row 25 m, bottom rows approach 5 m).
    depths = 25.0 - (ys / H) * 20.0

    # Convert to 3D (simple pinhole camera model)
    fx, fy = W, H  # Simplified focal lengths
    cx, cy = W / 2, H / 2

    X = (xs - cx) * depths / fx
    Y = (ys - cy) * depths / fy
    Z = depths

    return np.stack([X, Y, Z], axis=1).astype(np.float32)


print("Video processing functions defined!")

In [None]:
def _frame_step_tokens(points, cluster_result, frame_idx, min_points=20):
    """Build one STEP token per cluster holding at least ``min_points`` points.

    Keys are the cluster's position in ``cluster_result.clusters``. The mask
    is a placeholder all-ones 32x32 grid, because the simulated point cloud
    carries no per-pixel correspondence.
    """
    frame_steps = {}
    for cluster_idx, cluster_indices in enumerate(cluster_result.clusters):
        if len(cluster_indices) < min_points:
            continue
        mask = np.ones((32, 32), dtype=bool)
        frame_steps[cluster_idx] = build_step_token(
            mask, points[cluster_indices], frame_idx, frame_idx + 1
        )
    return frame_steps


def _key_by_track_id(tracker, frame_steps):
    """Re-key ``frame_steps`` by the id of the track each step was assigned to.

    After ``tracker.update`` every detection is the most recent step of
    exactly one track (matched or newly created), so an identity lookup on
    each track's last step recovers the assignment.
    """
    owner = {id(t.steps[-1]): tid for tid, t in tracker.tracks.items() if t.steps}
    return {owner[id(step)]: step for step in frame_steps.values() if id(step) in owner}


def _query_gemma(scene_text, question):
    """Ask Gemma to answer ``question`` from the serialized scene; return the response text."""
    full_prompt = f"""You are a spatial reasoning assistant analyzing a 4D scene.

{scene_text}

Based on the scene information above, answer the following question:

{question}

Think step by step about the spatial relationships and provide your final answer."""

    response = client.models.generate_content(
        model='gemma-3-4b-it',
        contents=full_prompt,
        config=types.GenerateContentConfig(
            max_output_tokens=1024,
            temperature=0.0,
        )
    )
    return response.text


def run_snow_pipeline(video_path, question, num_frames=5):
    """Run the (simulated) SNOW pipeline on one video and answer ``question``.

    Stages: frame extraction -> simulated point cloud -> HDBSCAN clustering
    -> STEP tokens -> H-hop plausibility filter -> tracking -> per-frame
    scene graphs -> 4DSG serialization -> Gemma query.

    Returns:
        (model response text, serialized scene text).

    Raises:
        ValueError: if no frames could be extracted from ``video_path``.
    """
    # Configs
    hdbscan_config = HDBSCANConfig(min_cluster_size=30)
    hhop_config = HHopConfig()
    tracker_config = TrackerConfig()

    # Step 1: Extract frames
    frames = extract_frames(video_path, num_frames)
    print(f"  Extracted {len(frames)} frames")
    if not frames:
        # Without frames the model would answer from an empty scene.
        raise ValueError(f"No frames extracted from {video_path}")

    # Step 2: Generate point clouds (simulated), cluster, filter, track
    tracker = ObjectTracker(tracker_config)
    spatial_graphs = []

    for frame_idx, frame in enumerate(frames):
        points = simulate_point_cloud(frame)
        print(f"  Frame {frame_idx}: {len(points)} points")

        cluster_result = cluster_points(points, hdbscan_config)
        print(f"    Found {len(cluster_result.clusters)} clusters")

        frame_steps = _frame_step_tokens(points, cluster_result, frame_idx)
        filtered_steps = filter_implausible(frame_steps, hhop_config)
        print(f"    {len(filtered_steps)} valid objects after H-hop filter")

        tracker.update(filtered_steps, frame_idx)

        # Label graph nodes with track ids, so "Object N" in the spatial
        # relations refers to the same object as "Object N" in the track list.
        tracked_steps = _key_by_track_id(tracker, filtered_steps)
        spatial_graphs.append(build_scene_graph(tracked_steps, frame_idx=frame_idx))

    # Get tracks and build 4DSG
    tracks = tracker.get_tracks()
    temporal_tracks = {
        tid: TemporalTrack(track_id=tid, steps=t.steps)
        for tid, t in tracks.items()
    }
    temporal_window = TemporalWindow(tracks=temporal_tracks)
    # Placeholder ego motion: 0.5 per frame along y. The remaining four values
    # are presumably an identity orientation quaternion; only the first three
    # values are serialized.
    ego_poses = {i: [0.0, i * 0.5, 0.0, 1.0, 0.0, 0.0, 0.0] for i in range(len(frames))}

    four_dsg = FourDSceneGraph(
        temporal_window=temporal_window,
        spatial_graphs=spatial_graphs,
        ego_poses=ego_poses,
    )
    print(f"  4DSG: {len(tracks)} tracks, {len(spatial_graphs)} frames")

    # Serialize 4DSG and query Gemma
    scene_text = serialize_4dsg(four_dsg)
    return _query_gemma(scene_text, question), scene_text


print("SNOW pipeline defined!")

## 4. Download VLM4D Data & Run Evaluation

In [None]:
# Download the VLM4D question file (real_mc.json; presumably the real-video
# multiple-choice split — confirm on the dataset card).
!mkdir -p data/vlm4d
!wget -q https://huggingface.co/datasets/shijiezhou/VLM4D/resolve/main/real_mc.json -O data/vlm4d/real_mc.json

# Load the question list. Later cells read each entry's 'id', 'video',
# 'question', 'choices' and 'answer' fields.
with open('data/vlm4d/real_mc.json') as f:
    queries = json.load(f)
print(f"Loaded {len(queries)} questions from real_mc.json")

In [None]:
# Run evaluation (limit samples for testing)
MAX_SAMPLES = 10
NUM_FRAMES = 5


def _format_mc_question(query):
    """Render a VLM4D query and its lettered choices as a prompt question.

    The closing instruction lists the query's own choice keys; for keys
    A-D it reads "(A, B, C, or D)".
    """
    question = f"{query['question']}\n"
    for key, value in query['choices'].items():
        question += f"{key}. {value}\n"
    keys = [str(k) for k in query['choices']]
    if len(keys) <= 2:
        letter_list = " or ".join(keys)
    else:
        letter_list = ", ".join(keys[:-1]) + f", or {keys[-1]}"
    question += f"\nAnswer with the letter of your choice ({letter_list})."
    return question


test_queries = queries[:MAX_SAMPLES]
results = []

for i, query in enumerate(tqdm(test_queries, desc="Processing")):
    print(f"\n[{i+1}/{len(test_queries)}] {query['id']}")

    try:
        video_path = download_video(query['video'])
        question = _format_mc_question(query)
        response, scene_text = run_snow_pipeline(video_path, question, NUM_FRAMES)
        query['response'] = response
        print(f"  Response: {response[:100]}...")

    except Exception as e:
        print(f"  Error: {e}")
        query['response'] = f"Error: {e}"
        # Explicit flag so scoring does not have to infer failures from text.
        query['error'] = str(e)

    results.append(query)

# Save results
with open('snow_vlm4d_results.json', 'w') as f:
    json.dump(results, f, indent=2)
print(f"\nSaved results to snow_vlm4d_results.json")

## 5. Evaluate Results

In [None]:
import re

def extract_answer(response, choices):
    """Extract the chosen option letter from a free-form model response.

    Valid letters are the single-letter keys of ``choices`` (``A``-``D`` when
    ``choices`` provides none). The cue words ("final answer", "answer") are
    matched case-insensitively. The option letter itself is matched
    case-sensitively and must stand alone, so words such as "depends" or the
    article "a" are not read as answers.

    Patterns, in priority order:
      1. "final answer [is] X"
      2. "answer [is] X"
      3. "(X)"
      4. a line consisting of "X", optionally followed by ".", ":" or ")"
      5. fallback: any standalone option letter

    For each pattern the *last* match wins, because step-by-step responses
    usually state their final choice at the end.

    Returns:
        The upper-case letter, or ``''`` if none is found.
    """
    if not response:
        return ''
    text = str(response)

    letters = sorted({
        str(k).strip().upper() for k in (choices or {})
        if len(str(k).strip()) == 1 and str(k).strip().isalpha()
    }) or list('ABCD')
    opt = '[' + ''.join(letters) + ']'

    patterns = [
        rf'(?i:final\s+answer)(?:\s+is)?[\s:*]+\(?({opt})\b',
        rf'(?i:answer)(?:\s+is)?[\s:*]+\(?({opt})\b',
        rf'\(({opt})\)',
        rf'^[\s*]*({opt})(?:[.:)]|[\s*]*$)',
        rf'\b({opt})\b',
    ]
    for pattern in patterns:
        matches = re.findall(pattern, text, re.MULTILINE)
        if matches:
            return matches[-1].upper()

    return ''

def evaluate_results(results):
    """Compute accuracy metrics."""
    correct = 0
    total = 0
    
    for r in results:
        if 'Error' in r.get('response', ''):
            continue
        
        pred = extract_answer(r['response'], r['choices'])
        gt = r['answer']
        
        if isinstance(gt, str) and len(gt) == 1:
            gt_letter = gt.upper()
        else:
            gt_letter = None
            for key, value in r['choices'].items():
                if str(value).lower() == str(gt).lower():
                    gt_letter = key.upper()
                    break
            if gt_letter is None:
                gt_letter = str(gt).upper()
        
        is_correct = pred == gt_letter
        if is_correct:
            correct += 1
        total += 1
        
        print(f"{r['id']}: pred={pred}, gt={gt_letter}, correct={is_correct}")
    
    accuracy = correct / total if total > 0 else 0
    print(f"\n{'='*50}")
    print(f"Accuracy: {accuracy:.2%} ({correct}/{total})")
    print(f"{'='*50}")
    return accuracy

# Score the records collected by the evaluation loop and print the accuracy.
accuracy = evaluate_results(results)

## 6. Download Results

In [None]:
# Download the saved results file from the Colab runtime to the local machine.
from google.colab import files
files.download('snow_vlm4d_results.json')