# arcOS Benchmark — Video Scene Graph QA with GNN + LLM

This notebook implements the arcOS benchmark pipeline for video question answering:
- **Data:** AGQA 2.0 QA pairs + Action Genome scene graph annotations
- **Graph:** PyG HeteroData scene graphs (object nodes, spatial + temporal edges)
- **Retrieval:** Per-video FAISS k-NN + PCST subgraph extraction
- **GNN:** HeteroGATv2 encoder with per-edge-type attention
- **Verbalization:** Attention-weighted triple formatting for LLM prompts
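- **Evaluation:** retrieval hit rate and attention precision on a test subset (answer-level EM and F1 require an LLM answering stage, which this notebook does not run)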

**Requirements:**
- Google Colab with GPU runtime (T4 or better)
- Google Drive mounted for checkpointing
- ~5GB free space on Drive

In [None]:
# Cell 1: Clone repository
# Deletes any existing checkout under /content, clones the repository's default
# branch fresh, and makes the clone the working directory.
# NOTE(review): `os` is not used in this cell.
import os
%cd /content
# Remove any stale checkout left from a previous run.
!rm -rf arcOS-benchmark-colab
!git clone https://github.com/ashtonalex/arcOS-benchmark-colab
%cd /content/arcOS-benchmark-colab
print('\n✓ Repository cloned')

In [None]:
# Cell 2: Install dependencies via uv
import os, sys, subprocess

# Colab UV workaround
os.environ['UV_CONSTRAINT'] = ''
os.environ['UV_BUILD_CONSTRAINT'] = ''

PY = sys.executable

# Install uv if needed
try:
    subprocess.run(['uv', '--version'], capture_output=True, check=True)
except (FileNotFoundError, subprocess.CalledProcessError):
    %pip install -q uv

# PyTorch + CUDA
!uv pip install --python {PY} torch torchvision torchaudio

# Core deps
!uv pip install --python {PY} faiss-gpu-cu12 sentence-transformers gdown pcst-fast tqdm numpy

# PyG + sparse/scatter
!uv pip install --python {PY} torch_geometric torch_scatter torch_sparse \
    -f https://data.pyg.org/whl/torch-2.8.0+cu128.html

# Verify GPU
import torch
print(f'\nGPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else "NOT AVAILABLE"}')
print('✓ Dependencies installed')

In [None]:
# Cell 3: Clean room import
# Puts the freshly cloned repo first on sys.path and imports the video
# pipeline modules from it, discarding bytecode and module objects left over
# from a previous run of this notebook.
import sys, shutil, importlib
from pathlib import Path

REPO_ROOT = Path('/content/arcOS-benchmark-colab')
SRC_ROOT = REPO_ROOT / 'src'

# Purge bytecode. Collect the matches first so the tree is not modified while
# rglob is still walking it.
for d in list(SRC_ROOT.rglob('__pycache__')):
    shutil.rmtree(d)

# Scrub cached modules: only the `src` package and its submodules, not any
# unrelated top-level module whose name merely starts with "src".
for k in [k for k in sys.modules if k == 'src' or k.startswith('src.')]:
    del sys.modules[k]

# Drop the import system's cached directory listings so source files that
# changed on disk (fresh clone) are picked up by the imports below.
importlib.invalidate_caches()

# Pin path
repo_str = str(REPO_ROOT)
sys.path = [p for p in sys.path if p != repo_str]
sys.path.insert(0, repo_str)

# Fresh imports — all video pipeline modules
from src.config import BenchmarkConfig
from src.utils.seeds import set_seeds
from src.utils.checkpoints import (
    ensure_drive_mounted, checkpoint_exists,
    save_checkpoint, load_checkpoint, create_checkpoint_dirs,
)
from src.data.agqa_loader import AGQALoader
from src.data.ag_converter import load_ag_annotations, convert_all
from src.data.scene_graph_builder import SceneGraphBuilder
from src.retrieval.embeddings import TextEmbedder
from src.retrieval.video_retriever import VideoRetriever
from src.gnn.hetero_encoder import HeteroGATv2Encoder
from src.gnn.hetero_trainer import HeteroGNNTrainer
from src.gnn.hetero_model_wrapper import HeteroGNNModel
from src.verbalization.scene_verbalizer import SceneVerbalizer
from src.evaluation.benchmark import BenchmarkEvaluator

print('✓ All video pipeline modules imported (clean load)')

In [None]:
# Cell 4: Configuration
# Every run setting is collected into one keyword dict; options not listed
# here keep their BenchmarkConfig defaults.
_config_overrides = dict(
    # Reproducibility and storage
    seed=42,
    deterministic=True,
    drive_root='/content/drive/MyDrive/arcOS_benchmark',
    # Video scene graph overrides
    agqa_subset_size=5000,
    ag_frame_sample_rate=3,
    top_k_seeds=10,
    pcst_budget=70,
    pcst_temporal_cost_weight=0.5,
    gnn_hidden_dim=256,
    gnn_num_layers=3,
    gnn_num_heads=4,
    gnn_encoder_type='hetero_gatv2',
    # Training
    num_epochs=10,
    patience=5,
    learning_rate=1e-4,
    batch_size=16,
    # Paths
    ag_annotations_dir='/content/action_genome',
    agqa_data_dir='/content/agqa',
)
config = BenchmarkConfig(**_config_overrides)

config.print_summary()

In [None]:
# Cell 5: Seeds + Drive mount + checkpoint dirs
# Seed the RNGs first, then try to mount Google Drive. The checkpoint and
# results directories are created only when the mount succeeds.
set_seeds(config.seed, config.deterministic)

drive_mounted = ensure_drive_mounted()
if not drive_mounted:
    # NOTE(review): config.checkpoint_dir is left unchanged here; confirm that
    # save_checkpoint really falls back to local storage in this case.
    print('WARNING: Drive not mounted. Checkpoints will use local storage.')
else:
    create_checkpoint_dirs(config.checkpoint_dir, config.results_dir)

## Phase 1: AGQA + Action Genome Data

In [None]:
# Cell 7: Download Action Genome annotations
# Fetches each annotation file from Google Drive into AG_DIR, skipping files
# that are already present. A failed download raises right away with the
# file name, instead of surfacing later as a FileNotFoundError from stat().
import gdown
from pathlib import Path

AG_DIR = Path(config.ag_annotations_dir)
AG_DIR.mkdir(parents=True, exist_ok=True)

# Action Genome annotation files (Google Drive IDs)
AG_FILES = {
    'object_bbox_and_relationship.pkl': '1MBTtSbchUXTdhUzFyk-XfMSD25Y1c_8R',
    'person_bbox.pkl': '1Vbg-hMaIBhbP4VVaKSA9UlG7l3ZoOPl0',
    'frame_list.txt': '1mADeXW_cJbUfrroJMY-EhFPcfDNwKwfB',
    'object_classes.txt': '1gMsXfK8ZdqvNB2XVDSWuixI2F3-jmFqq',
    'relationship_classes.txt': '1lyalRkbSn1zVB8LFR5H6YjJF8SHYqLjr',
}

print('Downloading Action Genome annotations...')
for filename, file_id in AG_FILES.items():
    dest = AG_DIR / filename
    # A zero-byte file (e.g. from an interrupted attempt) is treated as missing.
    if dest.exists() and dest.stat().st_size > 0:
        print(f'  {filename}: already exists')
        continue
    print(f'  {filename}: downloading...')
    downloaded = gdown.download(id=file_id, output=str(dest), quiet=True)
    # Some gdown versions report failure by returning None instead of raising.
    if downloaded is None or not dest.exists() or dest.stat().st_size == 0:
        raise RuntimeError(
            f'Failed to download {filename} (Google Drive id {file_id}); '
            'the file may be rate-limited or no longer shared publicly.'
        )
    print(f'  {filename}: done ({dest.stat().st_size / 1e6:.1f} MB)')

print('\n✓ Action Genome annotations ready')

In [None]:
# Cell 8: Download AGQA 2.0 balanced QA pairs
# Downloads the AGQA 2.0 balanced split files into AGQA_DIR and reports each
# file's size. agqa_paths maps split name -> file path (objects exposing
# .stat(), presumably pathlib.Path).
AGQA_DIR = Path(config.agqa_data_dir)
AGQA_DIR.mkdir(parents=True, exist_ok=True)

print('Downloading AGQA 2.0 balanced splits...')
agqa_paths = AGQALoader.download_agqa(str(AGQA_DIR))

print('\n✓ AGQA data ready')
for split_name, split_file in agqa_paths.items():
    size_mb = split_file.stat().st_size / 1e6
    print(f'  {split_name}: {split_file} ({size_mb:.1f} MB)')

In [None]:
# Cell 9: Parse AGQA, subsample, split by video_id
# Builds, or restores from checkpoint, train/val/test QA splits that should be
# disjoint at the video level. It then reports split sizes, per-split video
# counts and the pairwise video overlap between all three splits.
loader = AGQALoader(config)

splits_path = config.get_checkpoint_path('agqa_splits.pkl')

if checkpoint_exists(splits_path):
    print('Loading AGQA splits from checkpoint...')
    agqa_splits = load_checkpoint(splits_path, format='pickle')
    train_samples = agqa_splits['train']
    val_samples = agqa_splits['val']
    test_samples = agqa_splits['test']
else:
    # Merge every downloaded split; subsampling and re-splitting happen below.
    all_samples = []
    for split, path in agqa_paths.items():
        print(f'Loading {split}...')
        samples = loader.load_from_file(str(path))
        print(f'  {split}: {len(samples)} QA pairs')
        all_samples.extend(samples)

    print(f'\nTotal QA pairs: {len(all_samples)}')

    # Subsample
    all_samples = loader.subsample(all_samples)
    print(f'After subsample: {len(all_samples)}')

    # Split by video_id (no leakage)
    train_samples, val_samples, test_samples = loader.split(all_samples)

    # Save
    agqa_splits = {'train': train_samples, 'val': val_samples, 'test': test_samples}
    save_checkpoint(agqa_splits, splits_path, format='pickle')

print('\nSplit sizes:')
print(f'  Train: {len(train_samples)} QA pairs')
print(f'  Val:   {len(val_samples)} QA pairs')
print(f'  Test:  {len(test_samples)} QA pairs')

# Video coverage
train_vids = loader.get_unique_video_ids(train_samples)
val_vids = loader.get_unique_video_ids(val_samples)
test_vids = loader.get_unique_video_ids(test_samples)
all_vids = train_vids | val_vids | test_vids

print('\nVideo coverage:')
print(f'  Train videos: {len(train_vids)}')
print(f'  Val videos:   {len(val_vids)}')
print(f'  Test videos:  {len(test_vids)}')
print(f'  Total unique: {len(all_vids)}')
print(f'  Overlap train/val: {len(train_vids & val_vids)} (should be 0)')
print(f'  Overlap train/test: {len(train_vids & test_vids)} (should be 0)')
# val/test must also be disjoint, or test videos are already seen during validation.
print(f'  Overlap val/test: {len(val_vids & test_vids)} (should be 0)')

In [None]:
# Cell 10: Convert AG annotations for AGQA video IDs
# Restricts the raw Action Genome annotations to the videos referenced by the
# AGQA splits (all_vids) and converts them with convert_all. The result is
# checkpointed, and frame/object/relation totals are printed.
import pickle

converted_path = config.get_checkpoint_path('ag_converted.pkl')

if not checkpoint_exists(converted_path):
    print('Loading raw AG annotations...')
    raw_ag = load_ag_annotations(str(AG_DIR / 'object_bbox_and_relationship.pkl'))
    print(f'  Raw AG: {len(raw_ag)} frame entries')

    # Optional class map: index i -> the i-th non-blank line of object_classes.txt
    class_file = AG_DIR / 'object_classes.txt'
    class_map = None
    if class_file.exists():
        with open(class_file) as fh:
            class_map = dict(enumerate(ln.strip() for ln in fh if ln.strip()))
        print(f'  Loaded {len(class_map)} object classes')

    print(f'\nConverting annotations for {len(all_vids)} videos...')
    ag_annotations = convert_all(
        raw_ag,
        all_vids,
        frame_sample_rate=config.ag_frame_sample_rate,
        class_map=class_map,
    )
    print(f'  Converted: {len(ag_annotations)} videos with frames')

    save_checkpoint(ag_annotations, converted_path, format='pickle')
    del raw_ag  # the raw dump is only needed for the conversion above
else:
    print('Loading converted AG annotations from checkpoint...')
    ag_annotations = load_checkpoint(converted_path, format='pickle')

print(f'\n✓ AG annotations: {len(ag_annotations)} videos')

# Quick stats over every converted frame of every video
every_frame = [fr for video in ag_annotations.values() for fr in video['frames']]
total_frames = len(every_frame)
total_objects = sum(len(fr['objects']) for fr in every_frame)
total_relations = sum(len(fr['relations']) for fr in every_frame)
print(f'  Total frames: {total_frames}')
print(f'  Total objects: {total_objects}')
print(f'  Total relations: {total_relations}')

In [None]:
# Cell 11: Data inspection
# Prints object-class and predicate frequencies for a seeded sample of up to
# three converted videos, followed by the same distributions over all videos.
from collections import Counter
import random

# Sample 3 videos (deterministic given config.seed)
rng = random.Random(config.seed)
sample_vids = rng.sample(sorted(ag_annotations), min(3, len(ag_annotations)))

rule = '=' * 60
print(rule)
print('DATA INSPECTION — Sample Videos')
print(rule)

for vid in sample_vids:
    frames = ag_annotations[vid]['frames']
    print(f'\n--- Video: {vid} ---')
    print(f'  Frames: {len(frames)}')

    cls_counter = Counter(obj['class'] for fr in frames for obj in fr['objects'])
    pred_counter = Counter(rel['predicate'] for fr in frames for rel in fr['relations'])

    print(f'  Objects: {sum(cls_counter.values())} (unique classes: {len(cls_counter)})')
    print(f'  Relations: {sum(pred_counter.values())} (unique predicates: {len(pred_counter)})')
    print(f'  Top classes: {cls_counter.most_common(5)}')
    print(f'  Top predicates: {pred_counter.most_common(5)}')

# Global distributions
print('\n' + rule)
print('GLOBAL DISTRIBUTIONS')
print(rule)

global_classes = Counter(
    obj['class']
    for ann in ag_annotations.values()
    for fr in ann['frames']
    for obj in fr['objects']
)
global_preds = Counter(
    rel['predicate']
    for ann in ag_annotations.values()
    for fr in ann['frames']
    for rel in fr['relations']
)

print(f'\nObject classes ({len(global_classes)} unique):')
for name, count in global_classes.most_common(10):
    print(f'  {count:6d}  {name}')

print(f'\nRelation predicates ({len(global_preds)} unique):')
for name, count in global_preds.most_common(10):
    print(f'  {count:6d}  {name}')

In [None]:
# Cell 12: Phase 1 validation
# Sanity checks on Phase 1 outputs. The field and frame checks only
# spot-check the first 10 entries.
print('=' * 60)
print('PHASE 1 VALIDATION')
print('=' * 60)

checks = {
    'AG annotations loaded': len(ag_annotations) > 0,
    'Train split non-empty': len(train_samples) > 0,
    'Val split non-empty': len(val_samples) > 0,
    'Test split non-empty': len(test_samples) > 0,
    'No video leakage (train/val)': len(train_vids & val_vids) == 0,
    'No video leakage (train/test)': len(train_vids & test_vids) == 0,
    # val and test must be disjoint too, or test videos are seen during validation.
    'No video leakage (val/test)': len(val_vids & test_vids) == 0,
    'Samples have required fields': all(
        'question' in s and 'answer' in s and 'video_id' in s
        for s in train_samples[:10]
    ),
    'Annotations have frames': all(
        len(ag_annotations[vid]['frames']) > 0
        for vid in list(ag_annotations.keys())[:10]
    ),
}

for check, passed in checks.items():
    print(f"  {'✓' if passed else '✗'} {check}")
all_pass = all(checks.values())

print('\n' + '=' * 60)
if all_pass:
    print('✓ PHASE 1 COMPLETE')
else:
    print('✗ PHASE 1 INCOMPLETE — review failed checks')
print('=' * 60)

## Phase 2: Scene Graph Construction

In [None]:
# Cell 14: Build HeteroData per video
# Builds one scene graph per converted video, or restores the checkpointed
# dict video_id -> graph. Videos whose build raises are counted and skipped.
# Graph-size statistics are then printed; the count lists are also used by
# Cell 15.
from tqdm.auto import tqdm

graphs_path = config.get_checkpoint_path('scene_graphs.pkl')

if checkpoint_exists(graphs_path):
    print('Loading scene graphs from checkpoint...')
    scene_graphs = load_checkpoint(graphs_path, format='pickle')
else:
    print('Building scene graphs with text embeddings...')
    embedder = TextEmbedder(config)
    builder = SceneGraphBuilder(config, embedder=embedder)

    scene_graphs = {}
    failed = 0
    for vid in tqdm(sorted(ag_annotations.keys()), desc='Building graphs'):
        try:
            scene_graphs[vid] = builder.build(ag_annotations[vid])
        except Exception as e:
            # Keep going: one malformed video should not abort the whole build.
            failed += 1
            if failed <= 5:
                print(f'  Warning: {vid} failed: {e}')

    if failed > 0:
        print(f'  {failed} videos failed to build')

    save_checkpoint(scene_graphs, graphs_path, format='pickle')

    # Free embedder GPU memory (best effort; a failure here must not abort the cell)
    try:
        embedder.model.to('cpu')
    except Exception:
        pass

print(f'\n✓ Scene graphs: {len(scene_graphs)} videos')

# Quick stats
import numpy as np
node_counts = [g["object"].x.shape[0] for g in scene_graphs.values()]
spatial_counts = [g["object", "spatial_rel", "object"].edge_index.shape[1] for g in scene_graphs.values()]
temporal_counts = [g["object", "temporal", "object"].edge_index.shape[1] for g in scene_graphs.values()]

if scene_graphs:
    print(f'  Nodes — mean: {np.mean(node_counts):.0f}, min: {min(node_counts)}, max: {max(node_counts)}')
    print(f'  Spatial edges — mean: {np.mean(spatial_counts):.0f}, min: {min(spatial_counts)}, max: {max(spatial_counts)}')
    print(f'  Temporal edges — mean: {np.mean(temporal_counts):.0f}, min: {min(temporal_counts)}, max: {max(temporal_counts)}')
else:
    # min()/max() raise on empty lists, so skip the statistics entirely.
    print('  No scene graphs available — skipping size statistics')

In [None]:
# Cell 15: Scene graph inspection
# Prints structural details for up to two sampled videos that have graphs,
# then plots histograms of node / spatial-edge / temporal-edge counts from
# Cell 14.
import matplotlib.pyplot as plt

print('=' * 60)
print('SCENE GRAPH INSPECTION')
print('=' * 60)

# Sample graph details
for vid in (v for v in sample_vids[:2] if v in scene_graphs):
    g = scene_graphs[vid]
    obj_x = g['object'].x
    spatial_ei = g['object', 'spatial_rel', 'object'].edge_index
    temporal_ei = g['object', 'temporal', 'object'].edge_index
    print(f'\n  Video {vid}:')
    print(f'    Nodes: {obj_x.shape[0]}')
    print(f'    Spatial edges: {spatial_ei.shape[1]}')
    print(f'    Temporal edges: {temporal_ei.shape[1]}')
    print(f'    Feature dim: {obj_x.shape[1]}')
    print(f'    Frames: {g.num_frames}')
    if hasattr(g, 'object_names'):
        distinct_names = set(g.object_names)
        print(f'    Unique object names: {len(distinct_names)}')
        print(f'    Sample names: {list(distinct_names)[:8]}')

# Size histograms: (data, colour, x label, title) per panel
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
panels = (
    (node_counts, '#42A5F5', 'Nodes', 'Node Count Distribution'),
    (spatial_counts, '#66BB6A', 'Edges', 'Spatial Edge Distribution'),
    (temporal_counts, '#FFA726', 'Edges', 'Temporal Edge Distribution'),
)
for panel_ax, (data, color, xlabel, title) in zip(axes, panels):
    panel_ax.hist(data, bins=30, color=color, edgecolor='white')
    panel_ax.set_xlabel(xlabel)
    panel_ax.set_ylabel('Videos')
    panel_ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Cell 16: Phase 2 validation
# Checks that every AGQA video has a scene graph and spot-checks the structure
# of the first 20 graphs (edge types, feature dim, stored name lists).
print('=' * 60)
print('PHASE 2 VALIDATION')
print('=' * 60)

# Check all AGQA videos have scene graphs
missing_vids = all_vids - set(scene_graphs.keys())
graph_sample = list(scene_graphs.values())[:20]
expected_dim = config.embedding_dim

checks = {
    # Guards the spot-checks below, which would pass vacuously on an empty sample.
    'At least one scene graph built': len(scene_graphs) > 0,
    'All AGQA videos have graphs': len(missing_vids) == 0,
    'Both edge types present': all(
        ('object', 'spatial_rel', 'object') in g.edge_types
        and ('object', 'temporal', 'object') in g.edge_types
        for g in graph_sample
    ),
    # Label follows the configured dimension rather than a hard-coded value.
    f'Correct feature dim ({expected_dim})': all(
        g['object'].x.shape[1] == expected_dim for g in graph_sample
    ),
    'object_names stored': all(
        hasattr(g, 'object_names') and g.object_names is not None
        for g in graph_sample
    ),
    'spatial_predicates stored': all(
        hasattr(g, 'spatial_predicates') and g.spatial_predicates is not None
        for g in graph_sample
    ),
}

for check, passed in checks.items():
    print(f"  {'✓' if passed else '✗'} {check}")
all_pass = all(checks.values())

if missing_vids:
    # Sorted for a stable preview; '...' only when the list is truncated.
    preview = sorted(missing_vids, key=str)[:5]
    more = '...' if len(missing_vids) > 5 else ''
    print(f'  Missing videos ({len(missing_vids)}): {preview}{more}')

print('\n' + '=' * 60)
if all_pass:
    print('✓ PHASE 2 COMPLETE')
else:
    print('✗ PHASE 2 INCOMPLETE')
print('=' * 60)

## Phase 3: Per-Video Retrieval

In [None]:
# Cell 18: PCST configuration override
# Edit values below then run this cell.
# Each entry is written onto `config`, then the config's own __post_init__
# validation is re-run on the updated values.
PCST_OVERRIDES = {
    'pcst_budget': 70,
    'pcst_cost': 0.015,
    'pcst_pruning': 'none',
    'top_k_seeds': 10,
    'pcst_temporal_cost_weight': 0.5,
}
for attr_name, attr_value in PCST_OVERRIDES.items():
    setattr(config, attr_name, attr_value)

config.__post_init__()  # re-validate

print('PCST config applied:')
print(f'  budget={config.pcst_budget}, cost={config.pcst_cost}')
print(f'  pruning={config.pcst_pruning!r}, top_k_seeds={config.top_k_seeds}')
print(f'  temporal_cost_weight={config.pcst_temporal_cost_weight}')

In [None]:
# Cell 19: Initialize VideoRetriever
# Creates the text embedder shared by the retriever and, later, by the
# HeteroGNN model (Cell 25 calls set_embedder with it).
print('Initializing VideoRetriever...')
embedder = TextEmbedder(config)
retriever = VideoRetriever(config, embedder=embedder)
ready_msg = (
    f'✓ VideoRetriever ready (top_k_seeds={config.top_k_seeds}, '
    f'pcst_budget={config.pcst_budget})'
)
print(ready_msg)

In [None]:
# Cell 20: Validate retrieval on 50 val examples
# Runs retrieval on the first validation questions whose video has a graph.
# It records latency and subgraph size, and a "hit" when some object name in
# the subgraph equals the answer text (case- and whitespace-insensitive).
import time
import numpy as np

n_val_test = min(50, len(val_samples))
val_subset = val_samples[:n_val_test]

print(f'Validating retrieval on {n_val_test} val examples...\n')

retrieval_times, subgraph_sizes = [], []
hit_count = 0

for i, sample in enumerate(val_subset):
    vid = sample['video_id']
    if vid not in scene_graphs:
        continue

    result = retriever.retrieve(sample['question'], scene_graphs[vid])
    retrieval_times.append(result.retrieval_time_ms)
    subgraph_sizes.append(result.num_nodes)

    target = sample['answer'].strip().lower()
    names_in_subgraph = getattr(result.subgraph, 'object_names', []) or []
    hit = any(nm.strip().lower() == target for nm in names_in_subgraph)
    hit_count += int(hit)

    if i < 5:
        print(f'  [{i+1}] Q: {sample["question"][:60]}...')
        print(f'       Nodes: {result.num_nodes}, Edges: {result.num_edges}, '
              f'Time: {result.retrieval_time_ms:.0f}ms, Hit: {"✓" if hit else "✗"}')

n_tested = len(retrieval_times)
hit_rate = hit_count / n_tested * 100 if n_tested else 0
avg_time = np.mean(retrieval_times) if retrieval_times else 0
avg_size = np.mean(subgraph_sizes) if subgraph_sizes else 0
largest_subgraph = max(subgraph_sizes, default=0)

print(f'\n--- Retrieval Summary ({n_tested} examples) ---')
print(f'  Hit rate: {hit_rate:.1f}% ({hit_count}/{n_tested})')
print(f'  Avg time: {avg_time:.1f}ms')
print(f'  Avg subgraph size: {avg_size:.1f} nodes')
print(f'  Max subgraph size: {largest_subgraph} nodes')

In [None]:
# Cell 21: Diagnostics deep dive — trace single example
# Traces retrieval for the first validation question and plots how many nodes
# remain at each stage: full graph -> k-NN seeds -> PCST output.
import matplotlib.pyplot as plt

# Pick first val example
probe = val_samples[0]
probe_vid = probe['video_id']

if probe_vid not in scene_graphs:
    print(f'Video {probe_vid} not in scene graphs')
else:
    print(f'Probe: Q="{probe["question"]}"')
    print(f'       A="{probe["answer"]}"  Video={probe_vid}')

    result = retriever.retrieve(probe['question'], scene_graphs[probe_vid])

    print('\nRetrieval result:')
    summary_fields = (
        ('Seed indices', result.seed_indices),
        ('Subgraph nodes', result.num_nodes),
        ('Subgraph edges', result.num_edges),
        ('PCST used', result.pcst_used),
    )
    for field_name, field_value in summary_fields:
        print(f'  {field_name}: {field_value}')
    print(f'  Time: {result.retrieval_time_ms:.1f}ms')

    if hasattr(result.subgraph, 'object_names'):
        print('\n  Object names in subgraph:')
        for obj_name in set(result.subgraph.object_names):
            print(f'    - {obj_name}')

    # Funnel plot: (label, node count, colour) from widest to narrowest stage
    funnel = [
        ('Full graph', scene_graphs[probe_vid]['object'].x.shape[0], '#1565C0'),
        ('k-NN seeds', len(result.seed_indices), '#FFA726'),
        ('PCST output', result.num_nodes, '#66BB6A'),
    ]
    # Passed to barh in reverse so 'Full graph' ends up as the top bar.
    bar_labels, bar_values, bar_colors = zip(*reversed(funnel))

    fig, ax = plt.subplots(figsize=(8, 4))
    bars = ax.barh(list(bar_labels), list(bar_values), color=list(bar_colors))
    for bar, count in zip(bars, bar_values):
        ax.text(bar.get_width() + 0.5, bar.get_y() + bar.get_height() / 2,
                str(count), va='center', fontweight='bold')
    ax.set_xlabel('Nodes')
    ax.set_title('Retrieval Funnel')
    plt.tight_layout()
    plt.show()

In [None]:
# Cell 22: Phase 3 validation
# Validates the retrieval run from Cell 20. Every check requires at least one
# retrieved example; otherwise avg_time (0) and all([]) would pass vacuously.
print('=' * 60)
print('PHASE 3 VALIDATION')
print('=' * 60)

n_retrieved = len(subgraph_sizes)
checks = {
    'Retrieval ran on val examples': n_retrieved > 0,
    'Retrieval speed < 1s': n_retrieved > 0 and avg_time < 1000,
    'Subgraph size <= budget': n_retrieved > 0 and max(subgraph_sizes) <= config.pcst_budget,
    'Non-empty subgraphs': n_retrieved > 0 and all(s > 0 for s in subgraph_sizes),
}

for check, passed in checks.items():
    print(f"  {'✓' if passed else '✗'} {check}")
all_pass = all(checks.values())

print('\n' + '=' * 60)
if all_pass:
    print('✓ PHASE 3 COMPLETE')
else:
    print('✗ PHASE 3 INCOMPLETE')
print('=' * 60)

## Phase 4: Heterogeneous GNN Encoder

In [None]:
# Cell 24: Prepare PyG training data
# Informational only: retrieving a subgraph per QA pair and labelling answer
# nodes is done inside HeteroGNNModel.build_from_checkpoint_or_train, which
# Cell 25 calls. This cell just reports the inputs it will receive.
print('PyG training data will be prepared in the next cell.')
for label, count in (
    ('Train examples', len(train_samples)),
    ('Val examples', len(val_samples)),
    ('Scene graphs available', len(scene_graphs)),
):
    print(f'  {label}: {count}')

In [None]:
# Cell 25: Build/train HeteroGATv2 via HeteroGNNModel
# Obtains the HeteroGNN model, presumably loaded from an existing checkpoint
# or trained otherwise, as the method name suggests. It then attaches the
# shared text embedder from Cell 19 for inference.
print('Building HeteroGNN model...\n')

build_kwargs = dict(
    config=config,
    retriever=retriever,
    train_samples=train_samples,
    val_samples=val_samples,
    scene_graphs=scene_graphs,
)
hetero_model = HeteroGNNModel.build_from_checkpoint_or_train(**build_kwargs)

# Set embedder for inference
hetero_model.set_embedder(embedder)

In [None]:
# Cell 26: Test inference on single example
# Encodes the retrieved subgraph for the first validation question, prints
# output shapes and the top-10 attended nodes, and asserts that the attention
# scores lie in [0, 1] (with 0.01 slack on the upper bound).
# Leaves node_emb / attn_scores / names defined for Cells 28 and 30.
import torch

print('Testing HeteroGNN inference...\n')

test_sample = val_samples[0]
test_vid = test_sample['video_id']

if test_vid not in scene_graphs:
    print(f'Video {test_vid} not in scene graphs')
else:
    # Retrieve subgraph
    test_result = retriever.retrieve(test_sample['question'], scene_graphs[test_vid])
    subgraph = test_result.subgraph

    print(f'Q: {test_sample["question"]}')
    print(f'A: {test_sample["answer"]}')
    print(f'Subgraph: {test_result.num_nodes} nodes, {test_result.num_edges} edges')

    # Encode
    node_emb, attn_scores, graph_emb = hetero_model.encode(subgraph, test_sample['question'])

    print('\nGNN output shapes:')
    for label, tensor in (
        ('Node embeddings', node_emb),
        ('Attention scores', attn_scores),
        ('Graph embedding', graph_emb),
    ):
        print(f'  {label}: {tensor.shape}')

    # Top attention nodes
    top_k = min(10, len(attn_scores))
    top_vals, top_idx = torch.topk(attn_scores, top_k)
    names = getattr(subgraph, 'object_names', None)
    print(f'\nTop {top_k} nodes by attention:')
    ranked = zip(top_idx.tolist(), top_vals.tolist())
    for rank, (node_idx, score) in enumerate(ranked, start=1):
        label = names[node_idx] if names and node_idx < len(names) else f'node_{node_idx}'
        print(f'  {rank}. {label}: {score:.4f}')

    # Verify attention in [0, 1]
    lowest, highest = attn_scores.min(), attn_scores.max()
    assert lowest >= 0, f'Negative attention: {lowest}'
    assert highest <= 1.01, f'Attention > 1: {highest}'
    print('\n✓ Attention scores in valid range [0, 1]')

In [None]:
# Cell 27: Training curves
# Plots train/val loss and F1 per epoch from the JSON history written by
# training. If the file is absent, a message is printed instead; `history`
# is then left undefined, which Cell 30 tolerates.
import json
import matplotlib.pyplot as plt

history_path = config.get_checkpoint_path('hetero_gnn_history.json')

try:
    with open(history_path, 'r') as f:
        history = json.load(f)
except FileNotFoundError:
    print('No training history found (model loaded from checkpoint)')
else:
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    # (history key suffix, display name, panel title)
    panels = (('loss', 'Loss', 'Loss Curves'), ('f1', 'F1', 'F1 Curves'))
    for panel_ax, (key, display, title) in zip(axes, panels):
        panel_ax.plot(history[f'train_{key}'], label=f'Train {display}')
        panel_ax.plot(history[f'val_{key}'], label=f'Val {display}')
        panel_ax.set_xlabel('Epoch')
        panel_ax.set_ylabel(display)
        panel_ax.set_title(title)
        panel_ax.legend()
        panel_ax.grid(True)

    plt.tight_layout()
    plt.show()

    best_f1 = max(history['val_f1'])
    print(f'Best validation F1: {best_f1:.3f}')
    print(f'Epochs trained: {len(history["train_loss"])}')

In [None]:
# Cell 28: Attention visualization
# Left: histogram of the per-node attention scores from Cell 26.
# Right: mean attention per object name (top 15), when names are available.
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

if test_vid in scene_graphs:
    # detach().cpu() so NumPy conversion also works if encode() returned a
    # CUDA or grad-tracking tensor; it is a no-op for plain CPU tensors.
    scores = attn_scores.detach().cpu()

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Attention distribution
    ax = axes[0]
    ax.hist(scores.numpy(), bins=30, color='#42A5F5', edgecolor='white')
    ax.set_xlabel('Attention Score')
    ax.set_ylabel('Nodes')
    ax.set_title(f'Attention Distribution ({len(scores)} nodes)')

    # Attention by object class
    ax = axes[1]
    if names:
        class_scores = defaultdict(list)
        # zip stops at the shorter sequence, so nodes without a name are skipped.
        for obj_name, score in zip(names, scores.tolist()):
            class_scores[obj_name].append(score)
        class_means = {k: np.mean(v) for k, v in class_scores.items()}
        sorted_classes = sorted(class_means.items(), key=lambda x: x[1], reverse=True)[:15]
        labels, values = zip(*sorted_classes) if sorted_classes else ([], [])
        ax.barh(list(labels)[::-1], list(values)[::-1], color='#66BB6A')
        ax.set_xlabel('Mean Attention')
        ax.set_title('Attention by Object Class (top 15)')

    plt.tight_layout()
    plt.show()

In [None]:
# Cell 29: Memory check
# Reports CUDA memory usage (decimal GB) and asserts that currently allocated
# memory stays below 14 GB.
import torch

if not torch.cuda.is_available():
    print('GPU not available — running on CPU')
else:
    GB = 1e9
    allocated_gb = torch.cuda.memory_allocated() / GB
    print('GPU Memory:')
    print(f'  Allocated: {allocated_gb:.2f} GB')
    print(f'  Reserved:  {torch.cuda.memory_reserved() / GB:.2f} GB')
    print(f'  Max alloc: {torch.cuda.max_memory_allocated() / GB:.2f} GB')
    assert allocated_gb < 14.0, 'Memory leak!'
    print('  ✓ Memory OK')

In [None]:
# Cell 30: Phase 4 validation
# Validates the HeteroGNN stage using the training history (Cell 27) and the
# single-example inference outputs (Cell 26). Either may legitimately be
# absent: no history file, or Cell 26 skipped because its video had no graph.
print('=' * 60)
print('PHASE 4 VALIDATION')
print('=' * 60)

F1_THRESHOLD = 0.1  # Relaxed threshold for scene graphs

# Check training history for F1. `history` is undefined when Cell 27 found
# no file; an empty or malformed history is treated the same way.
try:
    best_val_f1 = max(history['val_f1'])
except (NameError, KeyError, TypeError, ValueError):
    best_val_f1 = None
# Skipped (counted as a pass) when there is no history, e.g. loaded from checkpoint.
f1_pass = best_val_f1 is None or best_val_f1 > F1_THRESHOLD

# Cell 26 only defines these when its test video had a scene graph.
have_inference = 'node_emb' in globals() and 'attn_scores' in globals()

checks = {
    'Model loads/trains successfully': hetero_model is not None,
    'Inference produces valid shapes': (
        have_inference and node_emb.shape[1] == config.gnn_hidden_dim
    ),
    'Attention in [0, 1]': (
        have_inference
        and bool(attn_scores.min() >= 0 and attn_scores.max() <= 1.01)
    ),
    'No OOM': True,  # If we got here, no OOM
}
if best_val_f1 is not None:
    checks[f'Val F1 > {F1_THRESHOLD} (best {best_val_f1:.3f})'] = f1_pass

for check, passed in checks.items():
    print(f"  {'✓' if passed else '✗'} {check}")
all_pass = all(checks.values())

print('\n' + '=' * 60)
if all_pass:
    print('✓ PHASE 4 COMPLETE')
else:
    print('✗ PHASE 4 INCOMPLETE')
print('=' * 60)

## Phase 5: Verbalization & Evaluation

In [None]:
# Cell 32: Verbalize examples — weighted vs unweighted side-by-side
# For the first three validation questions with a scene graph, prints the
# attention-weighted and the unweighted verbalization of the retrieved
# subgraph, each truncated to 500 characters.
verbalizer = SceneVerbalizer(config)

rule = '=' * 60
print(rule)
print('VERBALIZATION EXAMPLES')
print(rule)

for example_no, sample in enumerate(val_samples[:3], start=1):
    vid = sample['video_id']
    if vid not in scene_graphs:
        continue

    question = sample['question']
    result = retriever.retrieve(question, scene_graphs[vid])
    node_emb, attn_scores_i, graph_emb = hetero_model.encode(result.subgraph, question)

    renderings = (
        ('Attention-weighted verbalization:', verbalizer.verbalize(result.subgraph, attn_scores_i)),
        ('Unweighted verbalization:', verbalizer.verbalize_unweighted(result.subgraph)),
    )

    print(f'\n--- Example {example_no} ---')
    print(f'Q: {question}')
    print(f'A: {sample["answer"]}')
    for heading, text in renderings:
        print(f'\n{heading}')
        print(text[:500] if text else '  (empty)')
    print()

In [None]:
# Cell 33: Run evaluation on test subset
# For up to 100 test questions with a scene graph, this cell retrieves and
# encodes the subgraph. Answer nodes are the subgraph objects whose name
# equals the answer text (case- and whitespace-insensitive). The cell records
# retrieval hit rate over all subgraph nodes and top-5 attention precision,
# then aggregates them into `agg` via BenchmarkEvaluator.
from tqdm.auto import tqdm

evaluator = BenchmarkEvaluator()

n_test = min(100, len(test_samples))
test_subset = test_samples[:n_test]

print(f'Evaluating on {n_test} test examples...\n')

eval_results = []
for qa in tqdm(test_subset, desc='Evaluating'):
    qa_vid = qa['video_id']
    if qa_vid not in scene_graphs:
        continue

    # Retrieve + encode
    subgraph = retriever.retrieve(qa['question'], scene_graphs[qa_vid]).subgraph
    node_emb, attn_scores_i, graph_emb = hetero_model.encode(subgraph, qa['question'])

    # Answer nodes by exact (normalized) name match
    target = qa['answer'].strip().lower()
    node_names = getattr(subgraph, 'object_names', []) or []
    answer_nodes = [pos for pos, nm in enumerate(node_names) if nm.strip().lower() == target]

    # Every node of the retrieved subgraph counts as "selected"
    all_nodes = list(range(subgraph['object'].x.shape[0]))
    eval_results.append({
        'retrieval_hit_rate': evaluator.retrieval_hit_rate(all_nodes, answer_nodes),
        'attention_precision': evaluator.attention_precision(attn_scores_i, answer_nodes, top_k=5),
    })

# Aggregate
if not eval_results:
    print('No examples evaluated')
else:
    agg = evaluator.aggregate(eval_results)
    print(f'\n--- Evaluation Results ({len(eval_results)} examples) ---')
    for metric, value in agg.items():
        print(f'  {metric}: {value:.4f}')

In [None]:
# Cell 34: Results visualization
# Left panel: the aggregate metrics from Cell 33 as labelled bars.
# Right panel: overlaid histograms of the per-example hit rate and attention
# precision.
import matplotlib.pyplot as plt

if eval_results:
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    ax_agg, ax_dist = axes

    # Bar chart of aggregate metrics
    metrics = list(agg.keys())
    values = list(agg.values())
    colors = ['#42A5F5', '#66BB6A', '#FFA726', '#EF5350'][:len(metrics)]
    bars = ax_agg.bar(metrics, values, color=colors)
    for bar, val in zip(bars, values):
        ax_agg.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    f'{val:.3f}', ha='center', fontweight='bold')
    ax_agg.set_ylabel('Score')
    ax_agg.set_title('Aggregate Metrics')
    ax_agg.set_ylim(0, 1.1)
    ax_agg.grid(axis='y', alpha=0.3)

    # Distribution of per-example scores: (result key, legend label, colour)
    series = (
        ('retrieval_hit_rate', 'Hit Rate', '#42A5F5'),
        ('attention_precision', 'Attention Precision', '#66BB6A'),
    )
    for key, legend_label, color in series:
        ax_dist.hist([r[key] for r in eval_results], bins=20, alpha=0.6,
                     label=legend_label, color=color)
    ax_dist.set_xlabel('Score')
    ax_dist.set_ylabel('Examples')
    ax_dist.set_title('Per-Example Score Distribution')
    ax_dist.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# Cell 35: Save results to Drive
# Writes a JSON record of this run to the results directory. The record holds
# every config value that shapes the results, data sizes, and the aggregate
# metrics. Includes the PCST settings overridden in Cell 18, which the
# snapshot previously omitted and which made saved runs irreproducible.
import json
from datetime import datetime, timezone

results_data = {
    'timestamp': datetime.now(timezone.utc).isoformat(),
    'config': {
        'seed': config.seed,
        'deterministic': config.deterministic,
        'agqa_subset_size': config.agqa_subset_size,
        'ag_frame_sample_rate': config.ag_frame_sample_rate,
        'embedding_dim': config.embedding_dim,
        'pcst_budget': config.pcst_budget,
        'pcst_cost': config.pcst_cost,
        'pcst_pruning': config.pcst_pruning,
        'pcst_temporal_cost_weight': config.pcst_temporal_cost_weight,
        'top_k_seeds': config.top_k_seeds,
        'gnn_hidden_dim': config.gnn_hidden_dim,
        'gnn_num_layers': config.gnn_num_layers,
        'gnn_num_heads': config.gnn_num_heads,
        'gnn_encoder_type': config.gnn_encoder_type,
        'num_epochs': config.num_epochs,
        'patience': config.patience,
        'learning_rate': config.learning_rate,
        'batch_size': config.batch_size,
    },
    'data_stats': {
        'train_samples': len(train_samples),
        'val_samples': len(val_samples),
        'test_samples': len(test_samples),
        'scene_graphs': len(scene_graphs),
    },
    # Cast to built-in float so NumPy/torch scalars serialize as JSON numbers.
    'metrics': {k: float(v) for k, v in agg.items()} if eval_results else {},
    'n_test_requested': n_test,
    'n_evaluated': len(eval_results),
}

results_path = config.get_results_path('video_benchmark_results.json')
save_checkpoint(results_data, results_path, format='json')
print(f'Results saved to {results_path}')

In [None]:
# Cell 36: Final summary
# Prints a coarse completion status per phase (based on whether each phase
# produced output), the aggregate metrics, and data counts.
rule = '=' * 60
print(rule)
print('arcOS BENCHMARK — FINAL SUMMARY')
print(rule)

phases = {
    'Phase 1: Data Loading': len(ag_annotations) > 0 and len(train_samples) > 0,
    'Phase 2: Scene Graphs': len(scene_graphs) > 0,
    'Phase 3: Retrieval': len(retrieval_times) > 0,
    'Phase 4: HeteroGNN': hetero_model is not None,
    'Phase 5: Evaluation': len(eval_results) > 0,
}

for phase, passed in phases.items():
    print(f"  {'✓' if passed else '✗'} {phase}")

if eval_results:
    print('\nKey Metrics:')
    for metric, value in agg.items():
        print(f'  {metric}: {value:.4f}')

n_qa_pairs = sum(map(len, (train_samples, val_samples, test_samples)))
print('\nData:')
print(f'  Videos: {len(scene_graphs)}')
print(f'  QA pairs: {n_qa_pairs}')
print(f'  Test evaluated: {len(eval_results)}')

all_complete = all(phases.values())
print('\n' + rule)
print('✓ BENCHMARK COMPLETE' if all_complete
      else '✗ BENCHMARK INCOMPLETE — review phase statuses above')
print(rule)