# Alpamayo-R1 Interactive Inference

Interactive notebook for exploring model inference on driving scenes.

**Requirements:**
- NVIDIA GPU with 24GB+ VRAM
- `pip install -e alpamayo/`

In [None]:

from pathlib import Path

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from alpamayo_r1.models.alpamayo_r1 import AlpamayoR1
from alpamayo_r1 import load_physical_aiavdataset
from alpamayo_r1 import helper

# Paths are resolved relative to the kernel's working directory: the project
# root is taken to be two levels up (presumably the notebook is run from
# <root>/alpamayo/notebooks — confirm if the parquet file is not found).
PROJECT_ROOT = Path.cwd().parent.parent
CLIP_IDS_FILE = PROJECT_ROOT.joinpath("alpamayo", "notebooks", "clip_ids.parquet")

# Report the accelerator up front; later cells move the model and its inputs
# to "cuda".
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Load Model

In [None]:
# Checkpoint identifier handed to AlpamayoR1.from_pretrained.
MODEL_ID = "nvidia/Alpamayo-R1-10B"

print(f"Loading {MODEL_ID}...")

# Load the weights in bfloat16 first, then move the model onto the GPU.
model = AlpamayoR1.from_pretrained(MODEL_ID, dtype=torch.bfloat16)
model = model.to("cuda")

# The processor is built from the model's own tokenizer; later cells use it
# to turn chat messages into model inputs.
processor = helper.get_processor(model.tokenizer)

print("Model loaded!")

## 2. Load Scene Data

In [None]:
# Load available clip IDs (the parquet file must provide a "clip_id" column)
clip_ids_df = pd.read_parquet(CLIP_IDS_FILE)
all_clip_ids = clip_ids_df["clip_id"].tolist()
print(f"Available clips: {len(all_clip_ids)}")

# Select a clip (change index or use random)
CLIP_INDEX = 0  # Change this to explore different clips
# NOTE: `random` is not imported by the setup cell; add `import random`
# before uncommenting the next line.
# CLIP_INDEX = random.randint(0, len(all_clip_ids) - 1)  # Uncomment for random

# clip_id and t0_us are notebook-level globals: besides the loader call below,
# show_camera_views and the trajectory plot read them for their titles.
clip_id = all_clip_ids[CLIP_INDEX]
t0_us = 5_000_000  # 5 seconds into clip (value is in microseconds)

print(f"\nLoading clip: {clip_id}")
print(f"Timestamp: {t0_us / 1e6:.1f}s")

# `data` is consumed by later cells via the keys "image_frames",
# "ego_history_xyz", "ego_history_rot" and "ego_future_xyz".
data = load_physical_aiavdataset(clip_id, t0_us=t0_us)
print(f"Image frames shape: {data['image_frames'].shape}")

## 3. Visualize Scene

In [None]:
def show_camera_views(image_frames, frame_idx=-1, title=None):
    """Display one frame from each camera in a 2x2 grid.

    Args:
        image_frames: Tensor indexed as ``[camera, frame]`` whose entries are
            three-dimensional images; each one is permuted with ``(1, 2, 0)``
            (i.e. treated as channel-first) and cast to ``uint8`` before being
            plotted. The camera order is assumed to match the panel names
            below — TODO confirm against the dataset loader.
        frame_idx: Index of the frame shown for every camera. The default
            ``-1`` selects the last frame.
        title: Figure title. When omitted, the title is built from the
            notebook-level ``clip_id`` and ``t0_us`` globals (the original
            behaviour), so pass it explicitly when plotting any other clip.
    """
    camera_names = [
        "Cross-Left 120°",
        "Front-Wide 120°",
        "Cross-Right 120°",
        "Front-Tele 30°"
    ]

    # Resolve the title before creating the figure so that a missing global
    # fails without leaving an empty figure behind.
    if title is None:
        title = f"Clip: {clip_id[:16]}... | t0 = {t0_us/1e6:.1f}s"

    fig, axes = plt.subplots(2, 2, figsize=(16, 9))

    n_cameras = len(image_frames)
    for idx, (ax, name) in enumerate(zip(axes.flat, camera_names)):
        if idx >= n_cameras:
            # Fewer than four cameras supplied: leave this panel blank
            # instead of raising IndexError.
            ax.axis('off')
            continue
        # detach()/cpu() make the conversion work for GPU or grad-tracking
        # tensors as well; both are no-ops for a plain CPU tensor.
        frame = image_frames[idx, frame_idx].detach().cpu()
        img = frame.permute(1, 2, 0).numpy().astype(np.uint8)
        ax.imshow(img)
        ax.set_title(name, fontsize=12)
        ax.axis('off')

    plt.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()

# The clip loaded above is the one described by the global clip_id / t0_us,
# so the default title is correct for this call.
show_camera_views(data['image_frames'])

## 4. Run Inference

In [None]:
# Build the prompt from the camera frames; flatten(0, 1) merges the first two
# axes of image_frames into one before the message is created.
messages = helper.create_message(data["image_frames"].flatten(0, 1))

# Tokenize through the chat template and return PyTorch tensors in a dict.
# NOTE(review): add_generation_prompt=False with continue_final_message=True
# presumably leaves the last message open for the model to continue — confirm
# against the processor's documentation.
inputs = processor.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=False,
    continue_final_message=True, return_dict=True, return_tensors="pt",
)

# Bundle the tokenized prompt with the ego-motion history and move it all to
# the GPU in one step.
model_inputs = helper.to_device(
    {
        "tokenized_data": inputs,
        "ego_history_xyz": data["ego_history_xyz"],
        "ego_history_rot": data["ego_history_rot"],
    },
    "cuda",
)

print("Running inference...")
# Fixed CUDA seed before sampling.
torch.cuda.manual_seed_all(42)

# Sample a single trajectory under bfloat16 autocast; return_extra=True also
# yields the `extra` dict that the reasoning cell reads.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    pred_xyz, pred_rot, extra = model.sample_trajectories_from_data_with_vlm_rollout(
        data=model_inputs, top_p=0.98, temperature=0.6,
        num_traj_samples=1, max_generation_length=256, return_extra=True,
    )

print("Inference complete!")

## 5. Chain-of-Causation Reasoning

In [None]:
# Take the first entry of the "cot" extras returned by the sampler; fall back
# to a placeholder when that entry is missing or falsy.
if extra.get("cot"):
    coc_text = extra["cot"][0]
else:
    coc_text = "No CoC generated"

# Banner, heading, banner, then the reasoning text — one item per line.
print("=" * 80, "CHAIN-OF-CAUSATION REASONING", "=" * 80, coc_text, sep="\n")

## 6. Trajectory Comparison

In [None]:
# Pull the trajectories out as NumPy arrays, dropping the leading axes by
# taking index 0 on each. Shape notes are carried over from the original
# annotations — TODO confirm against the loader/model.
gt_xyz = data["ego_future_xyz"][0, 0].cpu().numpy()  # (64, 3)
pred_xyz_np = pred_xyz[0, 0, 0].cpu().numpy()  # (64, 3)
history_xyz = data["ego_history_xyz"][0, 0].cpu().numpy()  # (16, 3)

# Per-step Euclidean distance in the XY plane, averaged over the horizon.
# Only one trajectory was sampled (num_traj_samples=1), so the "min" in
# minADE is over that single sample.
diff = np.linalg.norm(pred_xyz_np[:, :2] - gt_xyz[:, :2], axis=1)
min_ade = diff.mean()

print(f"minADE: {min_ade:.3f} meters")

fig, ax = plt.subplots(figsize=(10, 10))

# (points, line format, legend label, marker size) for each curve, drawn in
# this order so the legend lists history, ground truth, then prediction.
_curves = (
    (history_xyz, 'b.-', 'History (1.6s)', 6),
    (gt_xyz, 'g.-', 'Ground Truth (6.4s)', 4),
    (pred_xyz_np, 'r.-', f'Predicted (minADE={min_ade:.2f}m)', 4),
)
for _pts, _fmt, _label, _msize in _curves:
    ax.plot(_pts[:, 0], _pts[:, 1], _fmt,
            label=_label, linewidth=2, markersize=_msize)

# Mark the origin, labelled as the ego position at t0.
ax.scatter([0], [0], c='black', s=200, marker='*', label='t0 (ego)', zorder=5)

ax.set(
    xlabel='X (meters)',
    ylabel='Y (meters)',
    title=f'Trajectory Comparison\nClip: {clip_id[:16]}...',
)
ax.legend()
ax.grid(True, alpha=0.3)
# Equal aspect keeps distances comparable along both axes.
ax.set_aspect('equal')

plt.tight_layout()
plt.show()

## 7. Quick Test Function

Convenience function to test different clips quickly.

In [None]:
def test_clip(clip_id: str, t0_us: int = 5_000_000, show_images: bool = True,
              seed: "int | None" = None):
    """Run the load -> prompt -> sample -> score pipeline on a single clip.

    Uses the notebook-level ``model`` and ``processor`` and the same sampling
    settings as the main inference cell.

    Args:
        clip_id: Identifier of the clip, passed to the dataset loader.
        t0_us: Timestamp passed to the loader as ``t0_us`` (default
            5_000_000, the value used by the main cell).
        show_images: When true, plot the camera views first. Note that
            ``show_camera_views`` is called with the frames only, so the
            figure title comes from the notebook-level ``clip_id``/``t0_us``
            rather than from the arguments given here.
        seed: When not ``None``, seed all CUDA generators with this value
            right before sampling, mirroring the main inference cell (which
            uses 42). The default ``None`` leaves the RNG state untouched,
            which is the previous behaviour.

    Returns:
        A dict with keys ``"clip_id"``, ``"coc"`` (first reasoning string, or
        ``""`` when none was returned), ``"min_ade"`` (mean XY displacement
        between the sampled and ground-truth trajectories), ``"data"`` (the
        loaded clip) and ``"pred_xyz"`` (the raw prediction returned by the
        model).
    """

    # Load data
    data = load_physical_aiavdataset(clip_id, t0_us=t0_us)

    if show_images:
        show_camera_views(data['image_frames'])

    # Prepare inputs
    messages = helper.create_message(data["image_frames"].flatten(0, 1))
    inputs = processor.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=False,
        continue_final_message=True, return_dict=True, return_tensors="pt",
    )
    model_inputs = helper.to_device({
        "tokenized_data": inputs,
        "ego_history_xyz": data["ego_history_xyz"],
        "ego_history_rot": data["ego_history_rot"],
    }, "cuda")

    # Optional seeding so a clip can be re-run with the same sampling noise.
    if seed is not None:
        torch.cuda.manual_seed_all(seed)

    # Inference
    with torch.autocast("cuda", dtype=torch.bfloat16):
        pred_xyz, pred_rot, extra = model.sample_trajectories_from_data_with_vlm_rollout(
            data=model_inputs, top_p=0.98, temperature=0.6,
            num_traj_samples=1, max_generation_length=256, return_extra=True,
        )

    # Results: only the XY components enter the displacement error; with a
    # single sampled trajectory the "min" is over that one sample.
    coc = extra["cot"][0] if extra.get("cot") else ""
    gt_xy = data["ego_future_xyz"].cpu().numpy()[0, 0, :, :2]
    pred_xy = pred_xyz.cpu().numpy()[0, 0, 0, :, :2]
    min_ade = np.linalg.norm(pred_xy - gt_xy, axis=1).mean()

    print(f"\n{'='*60}")
    print(f"Clip: {clip_id}")
    print(f"minADE: {min_ade:.3f} m")
    print(f"{'='*60}")
    print(f"CoC: {coc}")

    return {"clip_id": clip_id, "coc": coc, "min_ade": min_ade, "data": data, "pred_xyz": pred_xyz}

# Example: test a random clip (add `import random` first; the setup cell does not import it)
# result = test_clip(random.choice(all_clip_ids))

## 8. Batch Test (Optional)

Test multiple clips and collect statistics.

In [None]:
# Uncomment to run batch test (also add `import random`; the setup cell does not import it)
# N_CLIPS = 10
# test_clips = random.sample(all_clip_ids, N_CLIPS)
# 
# results = []
# for i, cid in enumerate(test_clips):
#     print(f"\n[{i+1}/{N_CLIPS}] Testing {cid[:16]}...")
#     try:
#         r = test_clip(cid, show_images=False)
#         results.append(r)
#     except Exception as e:
#         print(f"  Error: {e}")
# 
# ades = [r["min_ade"] for r in results]
# print(f"\n\nBatch Results ({len(results)} clips):")
# print(f"  minADE: {np.mean(ades):.3f} ± {np.std(ades):.3f} m")