# arcOS Benchmark - Causal QA with GNN + LLM

**Phase 1: Environment & Data Foundation**

This notebook implements the complete arcOS benchmark pipeline:
- Graph Neural Network structural reasoning over knowledge graphs
- LLM text generation with graph-guided prompts
- Evaluation on RoG-WebQSP question answering dataset

**Requirements:**
- Google Colab with GPU runtime (T4 or better)
- Google Drive mounted for checkpointing
- ~10GB free space on Drive

**Architecture:**
- Dataset: RoG-WebQSP (4,706 QA pairs with Freebase subgraphs)
- Graph DB: NetworkX in-memory
- GNN: Graph Attention Network (GATv2)
- LLM: OpenRouter API (Claude 3.5 Sonnet)
- Verbalization: Hard prompts (text-based, not soft embeddings)

## Cell 0: Cloning Repository
Update notebook with latest updates from GitHub repo.

In [1]:
import os

# Ensure we are in a safe directory before operations
%cd /content

# Remove existing directory if it exists, then clone fresh
!rm -rf arcOS-benchmark-colab
!git clone https://github.com/ashtonalex/arcOS-benchmark-colab

/content
Cloning into 'arcOS-benchmark-colab'...
remote: Enumerating objects: 236, done.[K
remote: Counting objects: 100% (236/236), done.[K
remote: Compressing objects: 100% (148/148), done.[K
remote: Total 236 (delta 125), reused 183 (delta 75), pack-reused 0 (from 0)[K
Receiving objects: 100% (236/236), 229.72 KiB | 4.42 MiB/s, done.
Resolving deltas: 100% (125/125), done.


## Cell 1: Environment Setup

Install dependencies and verify GPU availability.

In [2]:
# ============================================================================
# ENVIRONMENT SETUP WITH UV PACKAGE MANAGER
# Ensures absolute environment parity between kernel and installed packages
# ============================================================================

import sys
import os
import subprocess
from pathlib import Path

# Colab UV workaround: Clear broken constraint files
os.environ["UV_CONSTRAINT"] = ""
os.environ["UV_BUILD_CONSTRAINT"] = ""

print("="*70)
print("STEP 1: ENVIRONMENT PATH VERIFICATION")
print("="*70)

# Capture current Python executable
current_python = sys.executable
print(f"Current kernel executable: {current_python}")
print(f"Python version: {sys.version}")
print(f"Site packages: {sys.path[0] if sys.path else 'N/A'}")

# Check if uv is available
def check_uv_available():
    """Check if uv is installed and accessible."""
    try:
        result = subprocess.run(
            ['uv', '--version'],
            capture_output=True,
            text=True,
            timeout=5
        )
        return result.returncode == 0
    except (FileNotFoundError, subprocess.TimeoutExpired):
        return False

uv_available = check_uv_available()

if not uv_available:
    print("\n⚠ UV not found. Installing uv package manager...")
    %pip install -q uv
    uv_available = check_uv_available()

if uv_available:
    # Get uv version
    uv_version = subprocess.run(
        ['uv', '--version'],
        capture_output=True,
        text=True
    ).stdout.strip()
    print(f"✓ UV available: {uv_version}")
else:
    print("✗ UV installation failed. Will fall back to pip.")

print("\n" + "="*70)
print("STEP 2: PACKAGE INSTALLATION")
print("="*70)

# Define packages to install
packages = [
    "datasets",
    "networkx",
    "tqdm",
    "faiss-gpu-cu12" # Added faiss-gpu
]

# PyTorch with CUDA support
torch_packages = "torch torchvision torchaudio"
torch_index = "https://download.pytorch.org/whl/cu118"

if uv_available:
    print(f"Installing packages using UV with --python {current_python}\n")

    # Install PyTorch with CUDA
    print("Installing PyTorch with CUDA 11.8 support...")
    !uv pip install --python {current_python} {torch_packages} --index-url {torch_index}

    # Install other packages
    print("\nInstalling additional dependencies...")
    for package in packages:
        !uv pip install --python {current_python} {package}
else:
    print("Falling back to standard pip installation\n")

    # Install PyTorch with CUDA
    print("Installing PyTorch with CUDA 11.8 support...")
    %pip install -q {torch_packages} --index-url {torch_index}

    # Install other packages
    print("\nInstalling additional dependencies...")
    %pip install -q {' '.join(packages)}

print("\n" + "="*70)
print("STEP 3: INSTALLATION VERIFICATION")
print("="*70)

# Verify installed packages are in the correct location
def verify_package_location(package_name):
    """Verify package is installed in current kernel's site-packages."""
    try:
        module = __import__(package_name)
        module_path = Path(module.__file__).parent

        # Check if module is in one of sys.path locations
        in_sys_path = any(str(module_path).startswith(p) for p in sys.path if p)

        # Get version if available
        version = getattr(module, '__version__', 'unknown')

        return {
            'installed': True,
            'version': version,
            'location': str(module_path),
            'in_sys_path': in_sys_path
        }
    except ImportError:
        return {'installed': False}

# Verify key packages
verification_packages = ['torch', 'datasets', 'networkx', 'tqdm', 'faiss'] # Added faiss for verification
print("\nVerifying installed packages:\n")

all_verified = True
for pkg in verification_packages:
    info = verify_package_location(pkg)
    if info['installed']:
        status = "✓" if info['in_sys_path'] else "⚠"
        print(f"{status} {pkg:12s} v{info['version']:12s}")
        print(f"  Location: {info['location']}")
        if not info['in_sys_path']:
            print(f"  WARNING: Not in sys.path!")
            all_verified = False
    else:
        print(f"✗ {pkg:12s} NOT INSTALLED")
        all_verified = False
    print()

print("="*70)
print("STEP 4: GPU VERIFICATION")
print("="*70)

import torch
gpu_available = torch.cuda.is_available()
print(f"\nGPU available: {gpu_available} {'✓' if gpu_available else '✗'}")

if gpu_available:
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"CUDA version: {torch.version.cuda}")
else:
    print("⚠ Warning: No GPU detected.")
    print("  Go to: Runtime -> Change runtime type -> Select T4 GPU")

print("\n" + "="*70)
print("ENVIRONMENT SETUP SUMMARY")
print("="*70)
print(f"Package manager: {'UV' if uv_available else 'pip'}")
print(f"Python executable: {current_python}")
print(f"All packages verified: {'✓ YES' if all_verified else '✗ NO'}")
print(f"GPU available: {'✓ YES' if gpu_available else '✗ NO'}")
print("="*70)

if not all_verified:
    print("\n⚠ WARNING: Some packages failed verification. Check output above.")
else:
    print("\n✓ Environment setup complete with full parity!")


STEP 1: ENVIRONMENT PATH VERIFICATION
Current kernel executable: /usr/bin/python3
Python version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
Site packages: /content
✓ UV available: uv 0.9.26

STEP 2: PACKAGE INSTALLATION
Installing packages using UV with --python /usr/bin/python3

Installing PyTorch with CUDA 11.8 support...
[2mUsing Python 3.12.12 environment at: /usr[0m
[2mAudited [1m3 packages[0m [2min 27ms[0m[0m

Installing additional dependencies...
[2mUsing Python 3.12.12 environment at: /usr[0m
[2mAudited [1m1 package[0m [2min 31ms[0m[0m
[2mUsing Python 3.12.12 environment at: /usr[0m
[2mAudited [1m1 package[0m [2min 14ms[0m[0m
[2mUsing Python 3.12.12 environment at: /usr[0m
[2mAudited [1m1 package[0m [2min 11ms[0m[0m
[2mUsing Python 3.12.12 environment at: /usr[0m
[2mAudited [1m1 package[0m [2min 13ms[0m[0m

STEP 3: INSTALLATION VERIFICATION

Verifying installed packages:

✓ torch        v2.8.0+cu128 
  Location: /usr/local/lib/py

## Cell 2: Clean Room Import

Purge bytecode caches, scrub `sys.modules`, pin `sys.path`, and verify source file integrity before importing any `src/` modules. Run this cell after every `git pull` or code change.

In [3]:
# ============================================================================
# CLEAN ROOM IMPORT — guarantees fresh module loading
# Run after every git pull / code edit to eliminate stale bytecode & caches
# ============================================================================

import sys, os, shutil, hashlib, importlib
from pathlib import Path
from datetime import datetime, timezone

# ── CONFIG ──────────────────────────────────────────────────────────────────
REPO_ROOT = Path("/content/arcOS-benchmark-colab")
SRC_ROOT  = REPO_ROOT / "src"
# Files to fingerprint (add any core logic files you want to verify)
VERIFY_FILES = [
    SRC_ROOT / "config.py",
    SRC_ROOT / "retrieval" / "pcst_solver.py",
    SRC_ROOT / "gnn" / "encoder.py",
]
# Module prefixes to scrub from sys.modules
SCRUB_PREFIXES = ("src", "src.")

print("=" * 70)
print("CLEAN ROOM IMPORT")
print("=" * 70)

# ── STEP 1: Bytecode Purge ─────────────────────────────────────────────────
print("\n[1/4] Purging __pycache__ and .pyc files...")
cache_dirs_removed = 0
pyc_files_removed  = 0

for cache_dir in SRC_ROOT.rglob("__pycache__"):
    shutil.rmtree(cache_dir)
    cache_dirs_removed += 1

for pyc_file in SRC_ROOT.rglob("*.pyc"):
    pyc_file.unlink()
    pyc_files_removed += 1

print(f"  Removed {cache_dirs_removed} __pycache__ dirs, {pyc_files_removed} .pyc files")

# ── STEP 2: sys.modules Scrub ──────────────────────────────────────────────
print("\n[2/4] Scrubbing src.* from sys.modules...")
stale_keys = [k for k in sys.modules if k.startswith(SCRUB_PREFIXES)]
for key in stale_keys:
    del sys.modules[key]
print(f"  Evicted {len(stale_keys)} cached modules: {stale_keys[:8]}{'...' if len(stale_keys) > 8 else ''}")

# ── STEP 3: Path Priority ─────────────────────────────────────────────────
print("\n[3/4] Pinning sys.path priority...")
repo_str = str(REPO_ROOT)
# Remove any existing entries to avoid duplicates
sys.path = [p for p in sys.path if p != repo_str]
# Insert at position 0 so our src/ wins over any pip-installed version
sys.path.insert(0, repo_str)
print(f"  sys.path[0] = {sys.path[0]}")

# ── STEP 4: Source File Validation ─────────────────────────────────────────
print("\n[4/4] Verifying source file integrity...")

def file_fingerprint(path: Path) -> dict:
    """Return last-modified timestamp and MD5 hash for a file."""
    data = path.read_bytes()
    md5  = hashlib.md5(data).hexdigest()
    mtime = datetime.fromtimestamp(path.stat().st_mtime, tz=timezone.utc)
    return {"md5": md5, "modified": mtime.strftime("%Y-%m-%d %H:%M:%S UTC"), "size": len(data)}

for fpath in VERIFY_FILES:
    if fpath.exists():
        info = file_fingerprint(fpath)
        rel  = fpath.relative_to(REPO_ROOT)
        print(f"  ✓ {rel}")
        print(f"    Modified : {info['modified']}")
        print(f"    MD5      : {info['md5']}")
        print(f"    Size     : {info['size']:,} bytes")
    else:
        print(f"  ⚠ {fpath} — NOT FOUND (skipped)")

# ── STEP 5: Fresh Imports ──────────────────────────────────────────────────
print("\n" + "-" * 70)
print("Importing modules (fresh)...")

from src.config import BenchmarkConfig
from src.utils.seeds import set_seeds
from src.utils.checkpoints import (
    ensure_drive_mounted,
    checkpoint_exists,
    save_checkpoint,
    load_checkpoint,
    create_checkpoint_dirs,
)
from src.data.dataset_loader import RoGWebQSPLoader
from src.data.graph_builder import GraphBuilder

print("✓ All imports successful (clean load)")

print("\n" + "=" * 70)
print("CLEAN ROOM COMPLETE")
print("=" * 70)


CLEAN ROOM IMPORT

[1/4] Purging __pycache__ and .pyc files...
  Removed 0 __pycache__ dirs, 0 .pyc files

[2/4] Scrubbing src.* from sys.modules...
  Evicted 0 cached modules: []

[3/4] Pinning sys.path priority...
  sys.path[0] = /content/arcOS-benchmark-colab

[4/4] Verifying source file integrity...
  ✓ src/config.py
    Modified : 2026-02-15 15:05:28 UTC
    MD5      : 65dd4dbe33e660c99769a730c98bbbb4
    Size     : 6,168 bytes
  ✓ src/retrieval/pcst_solver.py
    Modified : 2026-02-15 15:05:28 UTC
    MD5      : ab2028fd300e477aefdecbb3e2c8bd9e
    Size     : 13,877 bytes
  ✓ src/gnn/encoder.py
    Modified : 2026-02-15 15:05:28 UTC
    MD5      : 14360a080ee2e3b637976156dce7a66f
    Size     : 9,077 bytes

----------------------------------------------------------------------
Importing modules (fresh)...
✓ All imports successful (clean load)

CLEAN ROOM COMPLETE


## Cell 3: Configuration

Initialize benchmark configuration with hyperparameters.

In [4]:
# Initialize configuration
config = BenchmarkConfig(
    seed=42,
    deterministic=True,
    drive_root="/content/drive/MyDrive/arcOS_benchmark",
)

# Print configuration summary
config.print_summary()

arcOS Benchmark Configuration
Seed: 42 (deterministic=True)
Dataset: rmanluo/RoG-webqsp
Drive root: /content/drive/MyDrive/arcOS_benchmark
Checkpoint dir: /content/drive/MyDrive/arcOS_benchmark/checkpoints
Results dir: /content/drive/MyDrive/arcOS_benchmark/results

--- Retrieval ---
Embedding model: sentence-transformers/all-MiniLM-L6-v2
Top-K entities: 15
PCST budget: 70
PCST local budget: 300
PCST edge cost: 1.0
PCST pruning: gw
PCST base prize ratio: 1.0

--- GNN ---
Hidden dim: 256
Num layers: 3
Num heads: 4
Pooling: attention

--- LLM ---
Model: anthropic/claude-3.5-sonnet
Provider: openrouter
Temperature: 0.0


## Cell 4: Seed Initialization

Set random seeds for reproducibility.

In [5]:
# Set seeds for reproducibility
set_seeds(config.seed, config.deterministic)

✓ Random seeds set to 42 (deterministic=True)


## Cell 5: Google Drive Setup

Mount Google Drive and create checkpoint/results directories.

In [6]:
# Mount Google Drive
drive_mounted = ensure_drive_mounted()

if drive_mounted:
    # Create checkpoint and results directories
    create_checkpoint_dirs(config.checkpoint_dir, config.results_dir)
else:
    print("⚠ Warning: Drive not mounted. Checkpointing will not work.")
    print("  Continuing with local /content/ storage (temporary)")

✓ Google Drive already mounted at /content/drive
✓ Checkpoint directory: /content/drive/MyDrive/arcOS_benchmark/checkpoints
✓ Results directory: /content/drive/MyDrive/arcOS_benchmark/results


## Cell 6: Dataset Loading

Load RoG-WebQSP dataset from HuggingFace with Drive caching.

In [7]:
# Initialize dataset loader
cache_dir = config.checkpoint_dir / "huggingface_cache"
loader = RoGWebQSPLoader(cache_dir=cache_dir)

# Check for cached dataset
dataset_checkpoint_path = config.get_checkpoint_path("dataset.pkl")

dataset = None # Initialize dataset to None

if checkpoint_exists(dataset_checkpoint_path):
    print("Loading dataset from checkpoint...")
    try:
        dataset = load_checkpoint(dataset_checkpoint_path, format="pickle")
        print("✓ Dataset loaded from checkpoint.")
    except FileNotFoundError as e:
        print(f"⚠ Warning: Failed to load dataset from checkpoint due to missing files: {e}")
        print("  Falling back to downloading dataset from HuggingFace...")
        # If loading fails, proceed to download
        pass # dataset remains None, so the next block will execute

if dataset is None: # If dataset was not loaded successfully or checkpoint didn't exist
    print("Downloading dataset from HuggingFace...")
    dataset = loader.load(dataset_name=config.dataset_name)
    save_checkpoint(dataset, dataset_checkpoint_path, format="pickle")
    print("✓ Dataset downloaded and saved to checkpoint.")

# Slice dataset to desired sizes
dataset = loader.slice_dataset(
    dataset,
    train_size=900,
    val_size=90,
    test_size=None  # Keep all test examples
)

# Inspect dataset schema
loader.inspect_schema(dataset, num_examples=1)

# Compute statistics
loader.compute_statistics(dataset)

# Validate split counts (updated for sliced dataset)
split_valid = loader.validate_split_counts(
    dataset,
    expected_train=900,
    expected_val=90,
    expected_test=1628,  # Keep original test size
)

✓ HuggingFace cache directory: /content/drive/MyDrive/arcOS_benchmark/checkpoints/huggingface_cache
Loading dataset from checkpoint...
✓ Checkpoint loaded: /content/drive/MyDrive/arcOS_benchmark/checkpoints/dataset.pkl (pickle)
✓ Dataset loaded from checkpoint.

Dataset Slicing
Original sizes:
  Train: 2826
  Validation: 246
  Test: 1628

Sliced sizes:
  Train: 900
  Validation: 90
  Test: 1628

✓ Dataset sliced successfully

Dataset Schema Inspection
Inspecting first split: train

Fields:
  - id: Value('string')
  - question: Value('string')
  - answer: List(Value('string'))
  - q_entity: List(Value('string'))
  - a_entity: List(Value('string'))
  - graph: List(List(Value('string')))
  - choices: List(Value('null'))

✓ All expected fields present

Sample Examples (first 1):

--- Example 0 ---
ID: WebQTrn-0
Question: what is the name of justin bieber brother
Answer: ['Jaxon Bieber']
Question Entity: ['Justin Bieber']
Answer Entity: ['Jaxon Bieber']
Choices: []
Graph: 9088 triples
  Sam

In [None]:
# ============================================================================
# DATA STRUCTURE ANALYSIS — Inspect first 3 rows in detail
# Understand entity naming, triple structure, and relation patterns
# before building the graph and embeddings
# ============================================================================

import json
from collections import Counter

print("=" * 70)
print("DATA STRUCTURE ANALYSIS — First 3 examples")
print("=" * 70)

for i in range(3):
    ex = dataset["train"][i]
    print(f"\n{'─' * 70}")
    print(f"EXAMPLE {i}: {ex['id']}")
    print(f"{'─' * 70}")
    print(f"  Question:  {ex['question']}")
    print(f"  Answer:    {ex['answer']}")
    print(f"  q_entity:  {ex['q_entity']}  (type: {type(ex['q_entity']).__name__})")
    print(f"  a_entity:  {ex['a_entity']}  (type: {type(ex['a_entity']).__name__})")
    print(f"  Graph:     {len(ex['graph'])} triples")

    triples = ex['graph']

    # Collect all unique entities and relations
    subjects = set()
    objects_ = set()
    relations = []
    for t in triples:
        subjects.add(t[0])
        objects_.add(t[2])
        relations.append(t[1])

    all_entities = subjects | objects_
    rel_counts = Counter(relations)

    print(f"\n  Unique subjects:  {len(subjects)}")
    print(f"  Unique objects:   {len(objects_)}")
    print(f"  Unique entities:  {len(all_entities)}")
    print(f"  Unique relations: {len(set(relations))}")

    # Show first 10 triples
    print(f"\n  First 10 triples:")
    for j, t in enumerate(triples[:10]):
        print(f"    [{j}] {t[0]!r}  --({t[1]})-->  {t[2]!r}")

    # Show last 5 triples (may reveal different patterns)
    print(f"\n  Last 5 triples:")
    for j, t in enumerate(triples[-5:]):
        print(f"    [{len(triples)-5+j}] {t[0]!r}  --({t[1]})-->  {t[2]!r}")

    # Check if q_entity and a_entity appear as graph nodes
    q_ents = ex['q_entity'] if isinstance(ex['q_entity'], list) else [ex['q_entity']]
    a_ents = ex['a_entity'] if isinstance(ex['a_entity'], list) else [ex['a_entity']]

    print(f"\n  q_entity in graph nodes: ", end="")
    for qe in q_ents:
        found = qe in all_entities
        print(f"{qe!r} -> {'YES' if found else 'NO'}", end="  ")
    print()

    print(f"  a_entity in graph nodes: ", end="")
    for ae in a_ents:
        found = ae in all_entities
        print(f"{ae!r} -> {'YES' if found else 'NO'}", end="  ")
    print()

    # Top 5 relations by frequency
    print(f"\n  Top 5 relations:")
    for rel, cnt in rel_counts.most_common(5):
        print(f"    {cnt:4d}x  {rel}")

    # Entity name analysis — are they readable or opaque IDs?
    sample_entities = sorted(all_entities)[:15]
    print(f"\n  Sample entity names (first 15 alphabetically):")
    for ent in sample_entities:
        print(f"    {ent!r}")

# ── Cross-example summary ──────────────────────────────────────────────
print(f"\n{'=' * 70}")
print("CROSS-EXAMPLE SUMMARY")
print(f"{'=' * 70}")

# Check entity overlap between examples
all_ex_entities = []
for i in range(3):
    ex = dataset["train"][i]
    ents = set()
    for t in ex['graph']:
        ents.add(t[0])
        ents.add(t[2])
    all_ex_entities.append(ents)

overlap_01 = all_ex_entities[0] & all_ex_entities[1]
overlap_02 = all_ex_entities[0] & all_ex_entities[2]
overlap_12 = all_ex_entities[1] & all_ex_entities[2]

print(f"\nEntity overlap between examples:")
print(f"  Ex0 ∩ Ex1: {len(overlap_01)} shared entities")
print(f"  Ex0 ∩ Ex2: {len(overlap_02)} shared entities")
print(f"  Ex1 ∩ Ex2: {len(overlap_12)} shared entities")
if overlap_01:
    print(f"  Sample shared (0∩1): {sorted(overlap_01)[:5]}")


## Cell 7: Graph Construction

Build NetworkX graphs from dataset triples.

In [8]:
# Initialize graph builder
graph_builder = GraphBuilder(directed=config.graph_directed)

# Check for cached unified graph
unified_graph_path = config.get_checkpoint_path("unified_graph.pkl")

if checkpoint_exists(unified_graph_path):
    print("Loading unified graph from checkpoint...")
    unified_graph = load_checkpoint(unified_graph_path, format="pickle")
else:
    print("Building unified graph from training split...")
    unified_graph = graph_builder.build_unified_graph(dataset["train"])
    save_checkpoint(unified_graph, unified_graph_path, format="pickle")

# Print graph statistics
graph_builder.print_graph_info(unified_graph, name="Unified Training Graph")

# Validate graph size
graph_valid = graph_builder.validate_graph_size(
    unified_graph,
    min_nodes=config.unified_graph_min_nodes,
    min_edges=config.unified_graph_min_edges,
)

# Build sample per-example graph for demonstration
print("\nBuilding sample per-example graph...")
sample_example = dataset["train"][0]
sample_graph = graph_builder.build_from_triples(
    sample_example["graph"],
    graph_id=sample_example["id"]
)
graph_builder.print_graph_info(sample_graph, name="Sample Per-Example Graph")

✓ GraphBuilder initialized (directed=True)
Loading unified graph from checkpoint...
✓ Checkpoint loaded: /content/drive/MyDrive/arcOS_benchmark/checkpoints/unified_graph.pkl (pickle)

Unified Training Graph Information
Nodes: 543170
Edges: 1553935
Directed: True
Density: 0.000005
Weakly connected: True

Relation Statistics:
Unique relations: 4797
Top 10 relations:
  - common.topic.notable_types: 73649
  - freebase.valuenotation.is_reviewed: 39352
  - location.location.containedby: 34265
  - common.topic.notable_for: 33116
  - people.person.profession: 25576
  - people.person.gender: 22162
  - common.topic.article: 19550
  - people.person.nationality: 19116
  - common.topic.webpage: 18736
  - common.webpage.topic: 18719

Degree Statistics:
Average in-degree: 2.86
Average out-degree: 2.86

Graph Size Validation
Nodes: 543170 ✓
Edges: 1553935 ✓

✓ Graph meets size requirements

Building sample per-example graph...

Sample Per-Example Graph Information
Nodes: 1723
Edges: 8286
Directed: Tru

## Cell 8: Phase 1 Validation

Automated validation of all Phase 1 success criteria.

In [None]:
print("\n" + "="*60)
print("Phase 1 Success Criteria Validation")
print("="*60)

# Collect validation results
validation_results = {
    "GPU Available": torch.cuda.is_available(),
    "All Imports Successful": True,  # If we got here, imports worked
    "Dataset Splits Valid": split_valid,
    "Unified Graph Size Valid": graph_valid,
}

# Test checkpoint round-trip
test_checkpoint_path = config.get_checkpoint_path("test_roundtrip.pkl")
test_data = {"test": "round-trip", "value": 42}
try:
    save_checkpoint(test_data, test_checkpoint_path, format="pickle")
    loaded_data = load_checkpoint(test_checkpoint_path, format="pickle")
    checkpoint_roundtrip_ok = (loaded_data == test_data)
    validation_results["Checkpoint Round-Trip"] = checkpoint_roundtrip_ok
except Exception as e:
    print(f"Checkpoint round-trip failed: {e}")
    validation_results["Checkpoint Round-Trip"] = False

# Print results
print("\nValidation Results:")
all_passed = True
for criterion, passed in validation_results.items():
    status = "✓" if passed else "✗"
    print(f"  {status} {criterion}")
    if not passed:
        all_passed = False

print("\n" + "="*60)
if all_passed:
    print("✓ PHASE 1 COMPLETE - All criteria passed!")
    print("\nReady to proceed to Phase 2: Retrieval Pipeline")
else:
    print("✗ PHASE 1 INCOMPLETE - Some criteria failed")
    print("\nPlease review failed criteria above")
print("="*60)

# Print summary statistics
print("\nPhase 1 Summary:")
print(f"  Dataset: {config.dataset_name}")
print(f"  Training examples: {len(dataset['train'])}")
print(f"  Validation examples: {len(dataset['validation'])}")
print(f"  Test examples: {len(dataset['test'])}")
print(f"  Unified graph nodes: {unified_graph.number_of_nodes()}")
print(f"  Unified graph edges: {unified_graph.number_of_edges()}")
print(f"  Checkpoints saved to: {config.checkpoint_dir}")

## Cell 8.5: PCST Configuration

Override retrieval hyperparameters **without editing  or pushing to GitHub**.
These assignments mutate the  object created in Cell 3 — run this cell before Cell 9.


In [None]:
# ============================================================
# PCST CONFIGURATION — override config.py defaults here
# Edit values below, then run this cell.
# Works whether Cell 9 has already run or not.
# ============================================================

# Max nodes in the extracted subgraph
config.pcst_budget = 70

# BFS neighbourhood size: nodes collected around seeds before PCST runs
config.pcst_local_budget = 300

# Edge traversal cost. Lower = more edges included.
# With cosine-sim prizes in [0, 1], 0.015 lets a 4-hop path reach a 0.4-prize node.
config.pcst_cost = 0.015

# PCST pruning strategy: "none", "gw" (Goemans-Williamson), or "strong"
config.pcst_pruning = "gw"

# Query-aware edge cost scaling [0, 1].
# 0 = uniform costs  |  1 = fully query-aware (similar edges are cheaper)
config.pcst_edge_weight_alpha = 0.5

# Bridge disconnected PCST components via shortest paths
config.pcst_bridge_components = True

# Max relay hops when bridging disconnected components
config.pcst_bridge_max_hops = 6

# ── Validate ─────────────────────────────────────────────────
config.__post_init__()  # re-run validation with updated values

# ── Live-patch retriever if already built ─────────────────────
# If Cell 9 already ran, update the live PCSTSolver directly so
# you do not need to re-run Cell 9.
try:
    solver = retriever.pcst_solver
    solver.cost              = config.pcst_cost
    solver.budget            = config.pcst_budget
    solver.local_budget      = config.pcst_local_budget
    solver.pruning           = config.pcst_pruning
    solver.edge_weight_alpha = config.pcst_edge_weight_alpha
    solver.bridge_components = config.pcst_bridge_components
    solver.bridge_max_hops   = config.pcst_bridge_max_hops
    _src = "config + live retriever"
except NameError:
    _src = "config only (retriever not built yet — run Cell 9)"

print(f"PCST configuration applied ({_src}):")
print(f"  pcst_budget            = {config.pcst_budget}")
print(f"  pcst_local_budget      = {config.pcst_local_budget}")
print(f"  pcst_cost              = {config.pcst_cost}")
print(f"  pcst_pruning           = {config.pcst_pruning\!r}")
print(f"  pcst_edge_weight_alpha = {config.pcst_edge_weight_alpha}")
print(f"  pcst_bridge_components = {config.pcst_bridge_components}")
print(f"  pcst_bridge_max_hops   = {config.pcst_bridge_max_hops}")


## Cell 9: Build Retrieval Pipeline

Initialize retrieval components (embeddings, FAISS index, PCST solver).

In [9]:
!uv pip install pcst-fast

[2mUsing Python 3.12.12 environment at: /usr[0m
[2mAudited [1m1 package[0m [2min 82ms[0m[0m


In [10]:
print("=" * 60)
print("PHASE 2: RETRIEVAL PIPELINE")
print("=" * 60)

from src.retrieval import Retriever

# Build retriever (uses checkpoints if available)
retriever = Retriever.build_from_checkpoint_or_new(
    config=config,
    unified_graph=unified_graph  # From Phase 1 Cell 7
)

print("\n✓ Retrieval pipeline initialized")
print(f"  - Entity embeddings: {len(retriever.entity_index)} entities")
print(f"  - Top-K: {config.top_k_entities}")
print(f"  - PCST budget: {config.pcst_budget} nodes")

PHASE 2: RETRIEVAL PIPELINE
BUILDING RETRIEVAL PIPELINE

[1/4] Initializing text embedder...
⚠ CUDA not available, falling back to CPU for embeddings


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


✓ Loaded embedding model: sentence-transformers/all-MiniLM-L6-v2
  - Device: cpu
  - Embedding dimension: 384

[2/4] Loading/computing entity embeddings...
Loading cached embeddings from entity_embeddings.pkl
✓ Checkpoint loaded: /content/drive/MyDrive/arcOS_benchmark/checkpoints/entity_embeddings.pkl (pickle)
✓ Loaded 543170 entity embeddings

[3/4] Loading/computing relation embeddings...
Loading cached relation embeddings from relation_embeddings.pkl
✓ Checkpoint loaded: /content/drive/MyDrive/arcOS_benchmark/checkpoints/relation_embeddings.pkl (pickle)
✓ Loaded 4797 relation embeddings

[4/4] Loading/building FAISS index...
Loading cached FAISS index from faiss_index.bin
✓ Loaded FAISS index from /content/drive/MyDrive/arcOS_benchmark/checkpoints/faiss_index.bin
  - 543170 entities indexed

Initializing PCST solver...
✓ PCST solver ready (cost: 1.0, budget: 70, local: 300, pruning: gw, base_prize_ratio: 1.0)

✓ RETRIEVAL PIPELINE READY

✓ Retrieval pipeline initialized
  - Entity e

## Cell 10: Retrieval Validation

Test retrieval pipeline on 10 validation examples.

In [None]:
print("\n" + "=" * 60)
print("RETRIEVAL VALIDATION (up to 50 examples)")
print("=" * 60)

# Use up to 50 validation examples for statistically meaningful results
n_val = min(50, len(dataset["validation"]))
val_examples = list(dataset["validation"].select(range(n_val)))

hit_count = 0
total_time_ms = 0
subgraph_sizes = []

for i, example in enumerate(val_examples):
    question = example["question"]
    answer_entities = example.get("a_entity", [])
    if isinstance(answer_entities, str):
        answer_entities = [answer_entities]

    # Extract topic entities from dataset
    q_entities = example.get("q_entity", [])
    if isinstance(q_entities, str):
        q_entities = [q_entities]

    # Print header BEFORE retrieve() so PCST verbose output appears
    # under the correct example (not visually under the previous one)
    print(f"\n[{i+1}/{n_val}] Q: {question[:60]}...")
    print(f"  Topic entities (q_entity): {q_entities}")
    print(f"  Answer entities: {answer_entities}")

    # Retrieve subgraph (q_entity used as primary seed when available)
    result = retriever.retrieve(question, q_entity=q_entities, answer_entities=answer_entities)

    # has_answer is now set on the result by retrieve()
    hit = result.has_answer

    if hit:
        hit_count += 1

    total_time_ms += result.retrieval_time_ms
    subgraph_sizes.append(result.num_nodes)

    # Print result summary (after PCST verbose output)
    print(f"  Seeds used: {result.seed_entities[:5]}{'...' if len(result.seed_entities) > 5 else ''}")
    print(f"  Subgraph: {result.num_nodes} nodes, {result.num_edges} edges")
    print(f"  Hit: {'✓' if hit else '✗'}")
    print(f"  Time: {result.retrieval_time_ms:.1f}ms")

# Summary metrics
hit_rate = hit_count / len(val_examples) * 100
avg_time = total_time_ms / len(val_examples)
avg_size = sum(subgraph_sizes) / len(subgraph_sizes)

print("\n" + "=" * 60)
print("VALIDATION SUMMARY")
print("=" * 60)
print(f"Hit rate: {hit_rate:.1f}% ({hit_count}/{len(val_examples)})")
print(f"Avg retrieval time: {avg_time:.1f}ms")
print(f"Avg subgraph size: {avg_size:.1f} nodes")
print(f"Max subgraph size: {max(subgraph_sizes)} nodes")

## Cell 10a: PCST Deep Dive — Diagnostics

Traces a **single failing example** (smallest subgraph output from Cell 10) through
every step of the retrieval pipeline to pinpoint why PCST collapses to 1 node.

**Steps traced:**

| Step | What it checks |
|------|---------------|
| **A** | Query embedding norm |
| **B** | k-NN seed scores and graph membership |
| **C** | Prize construction — how many seeds get prizes and whether any prize > edge cost |
| **D** | BFS localisation — local graph size, number of WCCs, root component size, answer entity reachability |
| **E** | Local prize computation — entity-embedding coverage, how many nodes clear the `local_prize_threshold`, how many exceed `cost` |
| **F** | Raw `pcst_fast` call — prize/cost ratio, output format (labels vs indices), final node count |

**Visualisations:**
- Prize distribution histogram vs edge-cost line (Plot 1)
- Pipeline funnel: node count at each stage (Plot 2)
- Subgraph-size distribution across all 50 val examples (Plot 3)
- Root-component graph coloured by node type / prize magnitude (Plot 4)

In [None]:
# ============================================================================
# CELL 10a: PCST DEEP DIVE — Step-by-step diagnostics for one failing example
#
# Traces: query → k-NN → prizes → BFS localise → root component →
#         local prizes → raw pcst_fast call → output
# Highlights exactly where and why PCST collapses to 1 node.
# ============================================================================

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import networkx as nx
from collections import Counter

try:
    import pcst_fast as _pcst_fast_mod
except ImportError:
    _pcst_fast_mod = None
    print("⚠ pcst_fast not importable — Step F will be skipped")

# ── Pick probe example: prefer one where PCST returned ≤ 2 nodes ──────────
_failing   = [i for i, s in enumerate(subgraph_sizes) if s <= 2]
_probe_idx = _failing[0] if _failing else 0
_ex        = val_examples[_probe_idx]
_q         = _ex["question"]
_q_ents    = _ex.get("q_entity", [])
if isinstance(_q_ents, str): _q_ents = [_q_ents]
_a_ents    = _ex.get("a_entity", [])
if isinstance(_a_ents, str): _a_ents = [_a_ents]
_solver    = retriever.pcst_solver

print("=" * 70)
print(f"PROBE  [{_probe_idx}]: {_ex['id']}")
print(f"  Q        : {_q}")
print(f"  q_entity : {_q_ents}")
print(f"  a_entity : {_a_ents}")
print("=" * 70)

# ── A: Query embedding ─────────────────────────────────────────────────────
_qemb = retriever.embedder.embed_texts([_q], show_progress=False)[0]
print(f"\n[A] Query embedding  shape={_qemb.shape}  norm={np.linalg.norm(_qemb):.4f}")

# ── B: k-NN search ────────────────────────────────────────────────────────
_knn = retriever.entity_index.search(_qemb, k=config.top_k_entities)
print(f"\n[B] k-NN top-{config.top_k_entities}  (score | in_graph | name):")
for _ent, _sc in _knn:
    _tags = ("  ← q_entity" if _ent in _q_ents else "") + \
            ("  ← a_entity" if _ent in _a_ents else "")
    print(f"  {_sc:.4f}  {'✓' if _ent in retriever.unified_graph else '✗'}  "
          f"{_ent[:55]}{_tags}")

# ── C: Seed list & prize construction (mirrors retriever.retrieve) ─────────
_seed_ents, _sim_scores, _q_ent_names = [], {}, set()
for _qe in _q_ents:
    if _qe in retriever.unified_graph:
        _seed_ents.append(_qe); _sim_scores[_qe] = 1.0; _q_ent_names.add(_qe)
for _ent, _sc in _knn:
    if _ent not in _sim_scores:
        _seed_ents.append(_ent); _sim_scores[_ent] = _sc

_SIM_THRESH = 0.4
_prizes = {_qe: 1.0 for _qe in _q_ent_names}
for _ent, _sc in _knn:
    if _ent not in _prizes and _sc >= _SIM_THRESH:
        _prizes[_ent] = float(_sc)

print(f"\n[C] Seeds & global prizes:")
print(f"  Seeds total         : {len(_seed_ents)}")
print(f"  Seeds in graph      : {sum(1 for s in _seed_ents if s in retriever.unified_graph)}")
print(f"  Prized nodes (global): {len(_prizes)}")
if _prizes:
    _pv = list(_prizes.values())
    print(f"  Prize range         : {min(_pv):.4f} – {max(_pv):.4f}")
print(f"  PCST edge cost      : {_solver.cost}")
_c_above = sum(1 for p in _prizes.values() if p > _solver.cost) if _prizes else 0
print(f"  Prizes > cost       : {_c_above}/{len(_prizes)}")
if _prizes and _c_above == 0:
    print(f"  ⚠ ALL GLOBAL PRIZES ≤ COST — every edge is net-negative before local prizes")
elif not _prizes:
    print(f"  ⚠ NO PRIZES — no k-NN seed scored ≥ {_SIM_THRESH} similarity threshold")

# ── D: BFS localisation ────────────────────────────────────────────────────
_root_arg  = list(_q_ent_names) or None
_local_g   = _solver._localize(retriever.unified_graph, _seed_ents, root_entities=_root_arg)
_root_node = _solver._pick_root(_local_g, _seed_ents, _root_arg, _prizes)
_root_comp, _n_comps = _solver._root_component(_local_g, _root_node)
_prized_in_root = {n: _prizes[n] for n in _prizes if n in _root_comp}

print(f"\n[D] BFS localisation (budget={_solver.local_budget}):")
print(f"  Local graph         : {len(_local_g)} nodes, {_local_g.number_of_edges()} edges")
print(f"  WCC in local graph  : {_n_comps}")
print(f"  Selected root       : '{_root_node[:60]}'")
print(f"  Root component      : {len(_root_comp)} nodes, {_root_comp.number_of_edges()} edges")
print(f"  k-NN prizes in root : {len(_prized_in_root)} / {len(_prizes)}")
if len(_prized_in_root) < len(_prizes):
    _missing = [n for n in _prizes if n not in _root_comp]
    print(f"  Prized nodes NOT in root component: {_missing[:5]}"
          f"{'...' if len(_missing) > 5 else ''}")

print(f"\n  Answer entity trace:")
for _ae in _a_ents[:4]:
    _in_g = _ae in retriever.unified_graph
    _in_l = _ae in _local_g
    _in_r = _ae in _root_comp
    _dist = "?"
    if _in_r and _ae != _root_node:
        try:
            _dist = nx.shortest_path_length(_root_comp.to_undirected(), _root_node, _ae)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            _dist = "disconnected"
    elif _ae == _root_node:
        _dist = 0
    print(f"    '{_ae[:45]}': in_graph={_in_g}, in_local={_in_l}, "
          f"in_root={_in_r}, hops_from_root={_dist}")

# ── E: Local prize computation ─────────────────────────────────────────────
_local_prizes = _solver._compute_local_prizes(
    _root_comp, _prizes, _qemb, retriever.entity_embeddings)
_all_prized = {n: p for n, p in _local_prizes.items() if p > 0}
_above_cost = {n: p for n, p in _local_prizes.items() if p > _solver.cost}
_n_with_emb = sum(1 for n in _root_comp if n in retriever.entity_embeddings)

print(f"\n[E] Local prizes (merged: global k-NN + local cosine sim + existence):")
print(f"  Nodes in root comp      : {len(_root_comp)}")
print(f"  With entity embedding   : {_n_with_emb} ({_n_with_emb/max(len(_root_comp),1)*100:.0f}%)")
print(f"  local_prize_threshold   : {_solver.local_prize_threshold}")
print(f"  existence_prize         : {_solver.existence_prize}")
print(f"  Nodes with prize > 0    : {len(_all_prized)}")
print(f"  Nodes with prize > cost : {len(_above_cost)}  (cost={_solver.cost})")
if _all_prized:
    _lpv = sorted(_all_prized.values(), reverse=True)
    print(f"  Prize top-5             : {[f'{p:.4f}' for p in _lpv[:5]]}")
    print(f"  Prize mean / max        : {sum(_lpv)/len(_lpv):.4f} / {_lpv[0]:.4f}")

if not _all_prized:
    print(f"\n  ⚠ CRITICAL: Zero prizes in root component — PCST will return only root!")
elif not _above_cost:
    print(f"\n  ⚠ CRITICAL: Max prize ({max(_all_prized.values()):.4f}) ≤ "
          f"cost ({_solver.cost}) — PCST returns only root!")
else:
    print(f"\n  ✓ {len(_above_cost)} node(s) have prize > cost — PCST should expand from root")

# ── F: Raw pcst_fast call ──────────────────────────────────────────────────
print(f"\n[F] Raw pcst_fast call:")
_sel = np.array([0])  # fallback
if _pcst_fast_mod is not None and _root_comp.number_of_edges() > 0:
    _G_und   = _root_comp.to_undirected()
    _nodes   = list(_G_und.nodes())
    _n2i     = {n: i for i, n in enumerate(_nodes)}
    _edges_l = [(int(_n2i[u]), int(_n2i[v])) for u, v in _G_und.edges()]
    _ea      = np.array(_edges_l, dtype=np.int32)
    _pa      = np.zeros(len(_nodes), dtype=np.float64)
    for _n, _p in _local_prizes.items():
        if _n in _n2i:
            _pa[_n2i[_n]] = max(_p, 0.0)
    _ca = np.full(len(_edges_l), _solver.cost, dtype=np.float64)
    _ca = np.maximum(_ca, 1e-9)

    _re = next((_qe for _qe in _q_ents if _qe in _n2i), None)
    if _re is None:
        _re = max((_s for _s in _seed_ents if _s in _n2i),
                  key=lambda s: _local_prizes.get(s, 0.0), default=_nodes[0])
    _ri = int(_n2i[_re])

    print(f"  nodes={len(_nodes)}, edges={len(_edges_l)}, "
          f"root='{_re[:40]}' (idx={_ri})")
    print(f"  edge cost={_solver.cost}, pruning='{_solver.pruning}'")
    print(f"  scored nodes       : {int(np.count_nonzero(_pa))}")
    print(f"  prize array        : max={_pa.max():.6f}  "
          f"mean(>0)={_pa[_pa>0].mean():.6f if _pa.any() else 'N/A'}")
    print(f"  prize / cost ratio : {_pa.max()/_solver.cost:.4f}x  "
          f"(need >1.0 for PCST to expand)")

    _rn, _ = _pcst_fast_mod.pcst_fast(_ea, _pa, _ca, _ri, 1, _solver.pruning, 0)
    _rn = np.asarray(_rn, dtype=np.int64)

    print(f"\n  Raw output array length : {len(_rn)}")
    if len(_rn) == len(_nodes):
        print(f"  ➜ LABELS format (len == num_nodes={len(_nodes)})")
        _rl  = int(_rn[_ri])
        _sel = np.where(_rn == _rl)[0]
        print(f"  Root label={_rl}, nodes in root cluster: {len(_sel)}")
    else:
        print(f"  ➜ INDICES format")
        _sel = np.unique(_rn)
        _sel = _sel[(_sel >= 0) & (_sel < len(_nodes))]

    print(f"\n  *** PCST selected {len(_sel)} node(s) ***")
    for _i in _sel[:10]:
        print(f"    [{_i}] '{_nodes[_i][:55]}'  prize={_pa[_i]:.4f}")
    if len(_sel) > 10:
        print(f"    ... and {len(_sel)-10} more")
else:
    print("  Skipped — no edges in root component or pcst_fast not available")
    _nodes = list(_root_comp.nodes())
    _n2i   = {n: i for i, n in enumerate(_nodes)}
    _pa    = np.zeros(len(_nodes))
    for _n, _p in _local_prizes.items():
        if _n in _n2i: _pa[_n2i[_n]] = max(_p, 0.0)

# ── VISUALISATIONS ─────────────────────────────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
fig.suptitle(f"PCST Diagnostics | {_ex['id']}: \"{_q[:70]}\"", fontsize=10)

# Plot 1 – Prize distribution vs edge cost
ax = axes[0]
_pviz = [p for p in _local_prizes.values() if p > 0]
if _pviz:
    ax.hist(_pviz, bins=min(40, len(_pviz)), color='#42A5F5', edgecolor='white', alpha=0.85)
    ax.axvline(_solver.cost, color='#EF5350', lw=2.5, linestyle='--',
               label=f'Edge cost = {_solver.cost}')
    ax.set_xlabel("Prize value"); ax.set_ylabel("# nodes")
    ax.set_title("Prize Distribution vs Edge Cost"); ax.legend()
    _nb = sum(1 for p in _pviz if p <= _solver.cost)
    _na = sum(1 for p in _pviz if p > _solver.cost)
    ax.text(0.97, 0.97, f"≤ cost: {_nb}\n> cost: {_na}",
            transform=ax.transAxes, va='top', ha='right', fontsize=9,
            bbox=dict(boxstyle='round', fc='lightyellow', alpha=0.8))
else:
    ax.text(0.5, 0.5, "No prizes in\nroot component", ha='center', va='center',
            fontsize=14, color='#EF5350', transform=ax.transAxes)
    ax.set_title("Prize Distribution vs Edge Cost")

# Plot 2 – Pipeline funnel (node counts at each stage)
ax = axes[1]
_funnel = [
    ("k-NN seeds",      len(_seed_ents)),
    ("BFS local",       len(_local_g)),
    ("Root comp",       len(_root_comp)),
    ("Prized (>0)",     len(_all_prized)),
    ("Above cost",      len(_above_cost)),
    ("PCST output",     len(_sel)),
]
_fl, _fv = [f[0] for f in _funnel][::-1], [f[1] for f in _funnel][::-1]
_fc = ['#0D47A1','#1B5E20','#E65100','#4A148C','#B71C1C','#004D40'][::-1]
_fb = ax.barh(_fl, _fv, color=_fc, alpha=0.85)
for _bar, _val in zip(_fb, _fv):
    ax.text(_bar.get_width() + max(_fv)*0.01 + 0.5,
            _bar.get_y() + _bar.get_height()/2,
            str(_val), va='center', fontsize=10, fontweight='bold')
ax.set_xlabel("# nodes"); ax.set_title("Pipeline Funnel (single example)")
ax.grid(axis='x', alpha=0.3); ax.set_xlim(0, max(max(_fv), 1) * 1.18)

# Plot 3 – Subgraph size distribution across all 50 val examples
ax = axes[2]
ax.hist(subgraph_sizes, bins=range(0, max(subgraph_sizes) + 2),
        color='#42A5F5', edgecolor='white', alpha=0.85)
ax.axvline(1.5, color='#EF5350', lw=2, linestyle='--', label='Size ≤ 1 (degenerate)')
ax.set_xlabel("Subgraph size (nodes)"); ax.set_ylabel("# examples")
ax.set_title(f"Subgraph Sizes across {len(subgraph_sizes)} val examples"); ax.legend()
_ndeg = sum(1 for s in subgraph_sizes if s <= 1)
ax.text(0.97, 0.97, f"Degenerate (≤1): {_ndeg}/{len(subgraph_sizes)}\n"
        f"Hit rate: {hit_rate:.1f}%\nMean size: {avg_size:.1f}",
        transform=ax.transAxes, va='top', ha='right', fontsize=9,
        bbox=dict(boxstyle='round', fc='lightyellow', alpha=0.8))

plt.tight_layout(); plt.show()

# ── Graph: root component coloured by prize ────────────────────────────────
_top_n = min(60, len(_root_comp))
_top_nodes = [n for n, _ in sorted(_local_prizes.items(),
              key=lambda x: x[1], reverse=True)][:_top_n]
_G_viz = _root_comp.subgraph(_top_nodes).copy()

if len(_G_viz) > 1:
    _pos = nx.spring_layout(_G_viz, k=2.0/max(len(_G_viz)**0.5, 1),
                             iterations=60, seed=42)
    _ncolors, _nsizes = [], []
    for _nd in _G_viz.nodes():
        _pr = _local_prizes.get(_nd, 0.0)
        if _nd in _a_ents:         _ncolors.append('#FF1744')  # answer
        elif _nd in _q_ent_names:  _ncolors.append('#00C853')  # root/q_entity
        elif _pr > _solver.cost:   _ncolors.append('#1565C0')  # high prize
        elif _pr > 0:              _ncolors.append('#90CAF9')  # low prize
        else:                      _ncolors.append('#BDBDBD')  # relay only
        _nsizes.append(max(150, min(_pr * 4000, 1500)))

    fig2, ax2 = plt.subplots(figsize=(14, 9))
    nx.draw_networkx_edges(_G_viz, _pos, ax=ax2, alpha=0.12, arrows=True, arrowsize=7)
    nx.draw_networkx_nodes(_G_viz, _pos, ax=ax2, node_color=_ncolors,
                           node_size=_nsizes, alpha=0.9)
    _top15     = [n for n, _ in sorted(_local_prizes.items(),
                  key=lambda x: x[1], reverse=True)[:15] if n in _G_viz]
    nx.draw_networkx_labels(_G_viz, _pos, {n: n[:18] for n in _top15},
                            ax=ax2, font_size=7)
    ax2.legend(handles=[
        mpatches.Patch(color='#00C853', label='q_entity (root)'),
        mpatches.Patch(color='#FF1744', label='answer entity'),
        mpatches.Patch(color='#1565C0', label=f'prize > cost={_solver.cost}'),
        mpatches.Patch(color='#90CAF9', label='prize ∈ (0, cost]'),
        mpatches.Patch(color='#BDBDBD', label='prize = 0 (relay only)'),
    ], loc='upper right', fontsize=8, title="Node type")
    ax2.set_title(
        f"Root component — top-{_top_n} nodes by prize\n"
        f"Root: '{_root_node[:60]}'\n"
        f"Q: \"{_q[:90]}\"", fontsize=10)
    ax2.axis('off'); plt.tight_layout(); plt.show()
else:
    print(f"Root component has only {len(_G_viz)} visible node(s) — nothing to visualise.")

# ── Diagnostic summary ─────────────────────────────────────────────────────
print("\n" + "=" * 70)
print("DIAGNOSTIC SUMMARY")
print("=" * 70)
_diag = [
    ("example",              _ex['id']),
    ("q_entity in graph",    str([q in retriever.unified_graph for q in _q_ents])),
    ("global prizes",        f"{len(_prizes)} nodes"),
    ("BFS local size",       f"{len(_local_g)} nodes ({_n_comps} WCC)"),
    ("root comp size",       f"{len(_root_comp)} nodes"),
    ("prized in root",       f"{len(_all_prized)} nodes"),
    ("above cost",           f"{len(_above_cost)} nodes (cost={_solver.cost})"),
    ("entity emb coverage",  f"{_n_with_emb}/{len(_root_comp)} ({_n_with_emb/max(len(_root_comp),1)*100:.0f}%)"),
    ("PCST output",          f"{len(_sel)} node(s)"),
    ("solver.pruning",       repr(_solver.pruning)),
    ("local_prize_threshold",str(_solver.local_prize_threshold)),
    ("existence_prize",      str(_solver.existence_prize)),
]
for _k, _v in _diag:
    print(f"  {_k:<26}: {_v}")


## Cell 11: Phase 2 Success Criteria

Validate Phase 2 completion criteria.

In [None]:
import networkx as nx

print("\n" + "=" * 60)
print("PHASE 2 SUCCESS CRITERIA")
print("=" * 60)

# Criterion 1: Retrieval speed < 1 second
speed_pass = avg_time < 1000  # ms
print(f"[{'✓' if speed_pass else '✗'}] Retrieval completes in <1 second: {avg_time:.1f}ms")

# Criterion 2: Hit rate > 60%
hit_pass = hit_rate >= 60.0
print(f"[{'✓' if hit_pass else '✗'}] Subgraph contains answer entity >60%: {hit_rate:.1f}%")

# Criterion 3: All subgraphs connected
all_connected = all(
    nx.is_weakly_connected(
        retriever.retrieve(
            example["question"],
            q_entity=example.get("q_entity")
        ).subgraph
    )
    for example in val_examples[:5]  # Check first 5
)
print(f"[{'✓' if all_connected else '✗'}] All subgraphs are connected")

# Criterion 4: Subgraph size respects budget
size_pass = max(subgraph_sizes) <= config.pcst_budget
print(f"[{'✓' if size_pass else '✗'}] Subgraph size ≤ budget ({max(subgraph_sizes)} ≤ {config.pcst_budget})")

# Overall pass
all_pass = speed_pass and hit_pass and all_connected and size_pass
print("\n" + "=" * 60)
if all_pass:
    print("✓ PHASE 2 COMPLETE - All criteria met!")
    print("\nReady to proceed to Phase 3: GNN Encoder")
else:
    print("⚠ PHASE 2 INCOMPLETE - Review failed criteria above")
print("=" * 60)

# Phase 3: GNN Encoder
## Cell 12: Build/Load GNN Model

In [None]:
!uv pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0

In [None]:
# Install torch_geometric and its dependencies
print("Installing torch_geometric...")
!uv pip install torch_geometric torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.8.0+cu128.html
print("✓ torch_geometric installed.")

In [None]:
# ============================================================
# PHASE 3: GNN Encoder
# ============================================================

print("Building GNN Model...")
print("This will either:")
print("  1. Load pre-trained model from checkpoint, OR")
print("  2. Prepare data and train from scratch (~30 min)")
print()

from src.gnn import GNNModel

# Build GNN model (handles checkpoint loading or training automatically)
gnn_model = GNNModel.build_from_checkpoint_or_train(
    config=config,
    retriever=retriever,
    train_data=dataset["train"],
    val_data=dataset["validation"],
    encoder_type="gatv2",  # or "graphsage"
    pooling_type="attention",  # or "mean", "max"
)

print("\n" + "="*60)
print("GNN Model Ready")
print("="*60)

## Cell 13: Test GNN Inference

In [None]:
# Test GNN encoding on a single example
print("Testing GNN inference on example query...\n")

test_question = "Who is Justin Bieber's brother?"
print(f"Question: {test_question}")

# Retrieve subgraph
retrieved = retriever.retrieve(test_question)
print(f"Retrieved subgraph: {retrieved.num_nodes} nodes, {retrieved.num_edges} edges")

# Encode with GNN
gnn_output = gnn_model.encode(retrieved, test_question)

print(f"\nGNN Output:")
print(f"  Node embeddings shape: {gnn_output.node_embeddings.shape}")
print(f"  Graph embedding shape: {gnn_output.graph_embedding.shape}")
print(f"  Attention scores: {len(gnn_output.attention_scores)} nodes")

# Get top attention nodes
top_nodes = gnn_model.get_top_attention_nodes(gnn_output, top_k=10)
print(f"\nTop 10 nodes by attention score:")
for i, (node, score) in enumerate(top_nodes, 1):
    print(f"  {i}. {node}: {score:.4f}")

## Cell 14: Validate GNN Metrics

In [None]:
# Load training history and display metrics
import json
import matplotlib.pyplot as plt

history_path = config.get_checkpoint_path("gnn_training_history.json")
with open(history_path, "r") as f:
    history = json.load(f)

print("="*60)
print("PHASE 3 VALIDATION: GNN Metrics")
print("="*60)

# Best metrics
best_val_f1 = max(history["val_f1"])
best_val_loss = min(history["val_loss"])
final_val_f1 = history["val_f1"][-1]

print(f"\nTraining Summary:")
print(f"  Epochs trained: {len(history['train_loss'])}")
print(f"  Best validation F1: {best_val_f1:.3f}")
print(f"  Best validation loss: {best_val_loss:.4f}")
print(f"  Final validation F1: {final_val_f1:.3f}")

# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss curve
axes[0].plot(history["train_loss"], label="Train Loss")
axes[0].plot(history["val_loss"], label="Val Loss")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training and Validation Loss")
axes[0].legend()
axes[0].grid(True)

# F1 curve
axes[1].plot(history["train_f1"], label="Train F1")
axes[1].plot(history["val_f1"], label="Val F1")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("F1 Score")
axes[1].set_title("Answer Node Prediction F1")
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.show()

# Success criteria check
print(f"\n{'='*60}")
print("Success Criteria:")
print(f"{'='*60}")

criteria = [
    ("Validation F1 > 0.5", best_val_f1 > 0.5, best_val_f1),
    ("Training completed < 30 min", True, "N/A"),  # User observation
    ("No OOM errors", True, "N/A"),  # User observation
]

for criterion, passed, value in criteria:
    status = "✓ PASS" if passed else "✗ FAIL"
    print(f"{status} - {criterion}: {value}")

if all(c[1] for c in criteria):
    print(f"\n{'='*60}")
    print("SUCCESS: Phase 3 Complete")
    print(f"{'='*60}")
else:
    print(f"\n{'='*60}")
    print("FAILED: Some criteria not met")
    print(f"{'='*60}")

## Cell 15: Visualize Attention on Example

In [None]:
# Visualize GNN attention on a subgraph
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

def visualize_gnn_attention(
    subgraph: nx.DiGraph,
    attention_scores: dict,
    answer_entities: list,
    question: str,
    top_k: int = 20,
):
    """
    Visualize GNN attention on a subgraph.

    Args:
        subgraph: NetworkX DiGraph
        attention_scores: Dict[node_name, float]
        answer_entities: List of ground truth answer entities
        question: Question text
        top_k: Show only top-K nodes by attention
    """
    # Get top-K nodes by attention
    sorted_nodes = sorted(
        attention_scores.items(), key=lambda x: x[1], reverse=True
    )[:top_k]
    top_nodes = [node for node, _ in sorted_nodes]

    # Create subgraph with only top nodes
    G_viz = subgraph.subgraph(top_nodes).copy()

    # Node colors (red = answer, blue = high attention, gray = low attention)
    node_colors = []
    for node in G_viz.nodes():
        if node in answer_entities:
            node_colors.append("red")
        else:
            # Scale by attention (darker = higher attention)
            attn = attention_scores.get(node, 0.0)
            intensity = min(attn * 10, 1.0)  # Scale for visibility
            node_colors.append((0.2, 0.4, 0.8, 0.3 + 0.7 * intensity))

    # Node sizes proportional to attention
    node_sizes = [
        300 + 2000 * attention_scores.get(node, 0.0) for node in G_viz.nodes()
    ]

    # Layout
    pos = nx.spring_layout(G_viz, k=0.5, iterations=50, seed=42)

    # Plot
    plt.figure(figsize=(14, 10))
    plt.title(f"GNN Attention Visualization\nQ: {question}", fontsize=12)

    # Draw edges
    nx.draw_networkx_edges(
        G_viz, pos, alpha=0.3, arrows=True, arrowsize=10, width=1.0
    )

    # Draw nodes
    nx.draw_networkx_nodes(
        G_viz, pos, node_color=node_colors, node_size=node_sizes, alpha=0.9
    )

    # Draw labels (only for top 10)
    labels = {node: node[:20] for node in list(G_viz.nodes())[:10]}
    nx.draw_networkx_labels(G_viz, pos, labels, font_size=8)

    plt.axis("off")
    plt.tight_layout()
    plt.show()

    # Print attention scores
    print("Top 10 attention scores:")
    for i, (node, score) in enumerate(sorted_nodes[:10], 1):
        is_answer = "✓ ANSWER" if node in answer_entities else ""
        print(f"  {i}. {node[:30]}: {score:.4f} {is_answer}")


# Test visualization
test_question = "Who is Barack Obama's spouse?"
test_answer = ["Michelle Obama", "m.025s5v9"]  # Freebase ID

retrieved = retriever.retrieve(test_question)
gnn_output = gnn_model.encode(retrieved, test_question)

visualize_gnn_attention(
    subgraph=retrieved.subgraph,
    attention_scores=gnn_output.attention_scores,
    answer_entities=test_answer,
    question=test_question,
    top_k=20,
)

## Cell 16: Memory Check

In [None]:
# Check GPU memory usage
if torch.cuda.is_available():
    print("GPU Memory Summary:")
    print(f"  Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"  Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
    print(f"  Max allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

    # Verify no memory leak
    assert torch.cuda.memory_allocated() / 1e9 < 14.0, "Memory leak detected!"
    print("\n✓ Memory usage within acceptable range")
else:
    print("GPU not available")