# arcOS Benchmark - Causal QA with GNN + LLM

**Phases 1–3: Environment & Data Foundation, Retrieval Pipeline, GNN Encoder**

This notebook implements the complete arcOS benchmark pipeline:
- Graph Neural Network structural reasoning over knowledge graphs
- LLM text generation with graph-guided prompts
- Evaluation on RoG-WebQSP question answering dataset

**Requirements:**
- Google Colab with GPU runtime (T4 or better)
- Google Drive mounted for checkpointing
- ~10GB free space on Drive

**Architecture:**
- Dataset: RoG-WebQSP (4,706 QA pairs with Freebase subgraphs)
- Graph DB: NetworkX in-memory
- GNN: Graph Attention Network (GATv2)
- LLM: OpenRouter API (Claude 3.5 Sonnet)
- Verbalization: Hard prompts (text-based, not soft embeddings)

## Cell 0: Cloning Repository
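Clone a fresh copy of the GitHub repository so the notebook runs against the latest `src/` code (any local edits under `/content/arcOS-benchmark-colab` are discarded).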

In [None]:
# Remove existing directory if it exists, then clone fresh
# NOTE: `rm -rf` discards any local edits under /content/arcOS-benchmark-colab.
# Re-run Cell 2 (Clean Room Import) afterwards so the fresh sources are imported.
!rm -rf /content/arcOS-benchmark-colab
!git clone https://github.com/ashtonalex/arcOS-benchmark-colab /content/arcOS-benchmark-colab

## Cell 1: Environment Setup

Install dependencies and verify GPU availability.

In [None]:
# ============================================================================
# ENVIRONMENT SETUP WITH UV PACKAGE MANAGER
# Ensures absolute environment parity between kernel and installed packages
# ============================================================================

import sys
import os
import subprocess
from pathlib import Path

# Colab UV workaround: clear the constraint files uv would otherwise honour
# (they are reported to be broken on Colab).
import sysconfig

os.environ["UV_CONSTRAINT"] = ""
os.environ["UV_BUILD_CONSTRAINT"] = ""

print("="*70)
print("STEP 1: ENVIRONMENT PATH VERIFICATION")
print("="*70)

# Capture current Python executable; the uv installs below pass it via
# --python so packages land in the interpreter this kernel actually runs.
current_python = sys.executable
print(f"Current kernel executable: {current_python}")
print(f"Python version: {sys.version}")
# sys.path[0] is the notebook/script directory, not site-packages; report the
# interpreter's real pure-Python install directory instead.
print(f"Site packages: {sysconfig.get_paths().get('purelib', 'N/A')}")

# Check if uv is available
def check_uv_available():
    """Return True if the ``uv`` CLI can be executed successfully.

    Runs ``uv --version`` with a 5-second timeout. Any failure to launch the
    binary (missing from PATH, present but not executable, ...) or a timeout
    is reported as "not available" (False) instead of raising.
    """
    try:
        result = subprocess.run(
            ['uv', '--version'],
            capture_output=True,
            text=True,
            timeout=5
        )
    except (OSError, subprocess.TimeoutExpired):
        # OSError covers FileNotFoundError (not on PATH) and PermissionError
        # (not executable); either way uv is unusable from this kernel.
        return False
    return result.returncode == 0

uv_available = check_uv_available()

if not uv_available:
    print("\n⚠ UV not found. Installing uv package manager...")
    %pip install -q uv
    uv_available = check_uv_available()

if uv_available:
    # Get uv version
    uv_version = subprocess.run(
        ['uv', '--version'],
        capture_output=True,
        text=True
    ).stdout.strip()
    print(f"✓ UV available: {uv_version}")
else:
    print("✗ UV installation failed. Will fall back to pip.")

print("\n" + "="*70)
print("STEP 2: PACKAGE INSTALLATION")
print("="*70)

# Define packages to install
packages = [
    "datasets",
    "networkx",
    "tqdm",
    "faiss-gpu-cu12" # Added faiss-gpu
]

# PyTorch with CUDA support
torch_packages = "torch torchvision torchaudio"
torch_index = "https://download.pytorch.org/whl/cu118"

if uv_available:
    print(f"Installing packages using UV with --python {current_python}\n")

    # Install PyTorch with CUDA
    print("Installing PyTorch with CUDA 11.8 support...")
    !uv pip install --python {current_python} {torch_packages} --index-url {torch_index}

    # Install other packages
    print("\nInstalling additional dependencies...")
    for package in packages:
        !uv pip install --python {current_python} {package}
else:
    print("Falling back to standard pip installation\n")

    # Install PyTorch with CUDA
    print("Installing PyTorch with CUDA 11.8 support...")
    %pip install -q {torch_packages} --index-url {torch_index}

    # Install other packages
    print("\nInstalling additional dependencies...")
    %pip install -q {' '.join(packages)}

print("\n" + "="*70)
print("STEP 3: INSTALLATION VERIFICATION")
print("="*70)

# Verify installed packages are in the correct location
def verify_package_location(package_name):
    """Import ``package_name`` and report where the kernel loaded it from.

    Returns ``{'installed': False}`` if the import fails. Otherwise returns a
    dict with:
      - ``installed``: True
      - ``version``: ``__version__`` as a str, or 'unknown'
      - ``location``: directory the module was loaded from, or 'built-in'
        when the module has no file/search path
      - ``in_sys_path``: True when that directory lies under some non-empty
        ``sys.path`` entry (compared by whole path components)
    """
    import importlib

    try:
        # import_module returns the named (sub)module itself; __import__
        # would return the top-level package for dotted names.
        module = importlib.import_module(package_name)
    except ImportError:
        return {'installed': False}

    module_file = getattr(module, '__file__', None)
    search_locations = list(getattr(module, '__path__', None) or [])
    if module_file:
        module_path = Path(module_file).parent
    elif search_locations:
        # Namespace packages have no __file__ but expose their directories.
        module_path = Path(search_locations[0])
    else:
        module_path = None  # built-in / frozen module without a file

    # is_relative_to compares path components, so "/content" does not
    # accidentally match "/content2/...".
    in_sys_path = module_path is not None and any(
        module_path.is_relative_to(p) for p in sys.path if p
    )

    return {
        'installed': True,
        'version': str(getattr(module, '__version__', 'unknown')),
        'location': str(module_path) if module_path is not None else 'built-in',
        'in_sys_path': in_sys_path,
    }

# Verify key packages by import name (the faiss-gpu-cu12 wheel imports as `faiss`)
verification_packages = ['torch', 'datasets', 'networkx', 'tqdm', 'faiss']
print("\nVerifying installed packages:\n")

all_verified = True
for pkg in verification_packages:
    info = verify_package_location(pkg)
    if info['installed']:
        status = "✓" if info['in_sys_path'] else "⚠"
        # `!s` coerces first: __version__ is not guaranteed to be a str, and
        # the `s` format spec raises on other types.
        print(f"{status} {pkg:12s} v{info['version']!s:12}")
        print(f"  Location: {info['location']}")
        if not info['in_sys_path']:
            print("  WARNING: Not in sys.path!")
            all_verified = False
    else:
        print(f"✗ {pkg:12s} NOT INSTALLED")
        all_verified = False
    print()

print("="*70)
print("STEP 4: GPU VERIFICATION")
print("="*70)

import torch

gpu_available = torch.cuda.is_available()
availability_mark = '✓' if gpu_available else '✗'
print(f"\nGPU available: {gpu_available} {availability_mark}")

if not gpu_available:
    print("⚠ Warning: No GPU detected.")
    print("  Go to: Runtime -> Change runtime type -> Select T4 GPU")
else:
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    device_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU memory: {device_total_bytes / 1e9:.2f} GB")
    print(f"CUDA version: {torch.version.cuda}")

# One line per setup outcome, printed between rules.
summary_rows = (
    f"Package manager: {'UV' if uv_available else 'pip'}",
    f"Python executable: {current_python}",
    f"All packages verified: {'✓ YES' if all_verified else '✗ NO'}",
    f"GPU available: {'✓ YES' if gpu_available else '✗ NO'}",
)
print("\n" + "="*70)
print("ENVIRONMENT SETUP SUMMARY")
print("="*70)
for row in summary_rows:
    print(row)
print("="*70)

if all_verified:
    print("\n✓ Environment setup complete with full parity!")
else:
    print("\n⚠ WARNING: Some packages failed verification. Check output above.")


## Cell 2: Clean Room Import

Purge bytecode caches, scrub `sys.modules`, pin `sys.path`, and verify source file integrity before importing any `src/` modules. Run this cell after every `git pull` or code change.

In [None]:
# ============================================================================
# CLEAN ROOM IMPORT — guarantees fresh module loading
# Run after every git pull / code edit to eliminate stale bytecode & caches
# ============================================================================

import sys, os, shutil, hashlib, importlib
from pathlib import Path
from datetime import datetime, timezone

# ── CONFIG ──────────────────────────────────────────────────────────────────
REPO_ROOT = Path("/content/arcOS-benchmark-colab")
SRC_ROOT = REPO_ROOT / "src"

# Core source files whose fingerprints are printed in step 4 (extend as needed),
# each given as path components relative to SRC_ROOT.
_VERIFY_PARTS = (
    ("config.py",),
    ("retrieval", "pcst_solver.py"),
    ("gnn", "encoder.py"),
)
VERIFY_FILES = [SRC_ROOT.joinpath(*parts) for parts in _VERIFY_PARTS]

# Module-name prefixes whose sys.modules entries are evicted in step 2
SCRUB_PREFIXES = ("src", "src.")

clean_room_rule = "=" * 70
print(clean_room_rule)
print("CLEAN ROOM IMPORT")
print(clean_room_rule)

# ── STEP 1: Bytecode Purge ─────────────────────────────────────────────────
print("\n[1/4] Purging __pycache__ and .pyc files...")

def _purge_matches(pattern, remove):
    """Apply `remove` to every path under SRC_ROOT matching `pattern`; return the count."""
    removed = 0
    for match in SRC_ROOT.rglob(pattern):
        remove(match)
        removed += 1
    return removed

# Cache directories first; any stray .pyc files outside them afterwards.
cache_dirs_removed = _purge_matches("__pycache__", shutil.rmtree)
pyc_files_removed = _purge_matches("*.pyc", lambda stale_file: stale_file.unlink())

print(f"  Removed {cache_dirs_removed} __pycache__ dirs, {pyc_files_removed} .pyc files")

# ── STEP 2: sys.modules Scrub ──────────────────────────────────────────────
print("\n[2/4] Scrubbing src.* from sys.modules...")
# Match on package boundaries. A module is stale if its name equals a configured
# root ("src." is normalised to "src") or is a dotted submodule of that root.
# A plain str.startswith("src") would also evict unrelated modules that merely
# begin with those letters.
_scrub_roots = {prefix.rstrip(".") for prefix in SCRUB_PREFIXES if prefix.rstrip(".")}
stale_keys = [
    name
    for name in list(sys.modules)  # snapshot: sys.modules must not change while iterating
    if any(name == root or name.startswith(root + ".") for root in _scrub_roots)
]
for key in stale_keys:
    sys.modules.pop(key, None)
print(f"  Evicted {len(stale_keys)} cached modules: {stale_keys[:8]}{'...' if len(stale_keys) > 8 else ''}")

# ── STEP 3: Path Priority ─────────────────────────────────────────────────
print("\n[3/4] Pinning sys.path priority...")
repo_str = str(REPO_ROOT)
# Put the repo root first (dropping older copies of the same entry) so the
# local src/ package shadows any pip-installed module of the same name.
sys.path = [repo_str] + [entry for entry in sys.path if entry != repo_str]
print(f"  sys.path[0] = {sys.path[0]}")

# ── STEP 4: Source File Validation ─────────────────────────────────────────
print("\n[4/4] Verifying source file integrity...")

def file_fingerprint(path: Path) -> dict:
    """Fingerprint ``path`` by content hash, size and modification time.

    Returns a dict with keys, in order: "md5" (hex digest of the file bytes),
    "modified" (mtime in UTC, formatted "YYYY-MM-DD HH:MM:SS UTC") and
    "size" (number of bytes).
    """
    payload = path.read_bytes()
    modified_utc = datetime.fromtimestamp(path.stat().st_mtime, tz=timezone.utc)
    return dict(
        md5=hashlib.md5(payload).hexdigest(),
        modified=modified_utc.strftime("%Y-%m-%d %H:%M:%S UTC"),
        size=len(payload),
    )

# Print a fingerprint for each tracked source file; missing files are reported, not fatal.
for fpath in VERIFY_FILES:
    if not fpath.exists():
        print(f"  ⚠ {fpath} — NOT FOUND (skipped)")
        continue
    info = file_fingerprint(fpath)
    rel = fpath.relative_to(REPO_ROOT)
    print(f"  ✓ {rel}")
    print(f"    Modified : {info['modified']}")
    print(f"    MD5      : {info['md5']}")
    print(f"    Size     : {info['size']:,} bytes")

# ── STEP 5: Fresh Imports ──────────────────────────────────────────────────
# The repo root was inserted at sys.path[0] in step 3, so it is searched first.
# src.* entries were removed from sys.modules in step 2, so these imports
# execute the module sources again instead of returning cached modules.
print("\n" + "-" * 70)
print("Importing modules (fresh)...")

from src.config import BenchmarkConfig
from src.utils.seeds import set_seeds
from src.utils.checkpoints import (
    ensure_drive_mounted,
    checkpoint_exists,
    save_checkpoint,
    load_checkpoint,
    create_checkpoint_dirs,
)
from src.data.dataset_loader import RoGWebQSPLoader
from src.data.graph_builder import GraphBuilder

# Reached only if every import above succeeded (a failure raises before this line).
print("✓ All imports successful (clean load)")

print("\n" + "=" * 70)
print("CLEAN ROOM COMPLETE")
print("=" * 70)

## Cell 3: Configuration

Initialize benchmark configuration with hyperparameters.

In [None]:
# Benchmark configuration: fixed seed, deterministic mode and the Drive root directory.
config = BenchmarkConfig(
    drive_root="/content/drive/MyDrive/arcOS_benchmark",
    seed=42,
    deterministic=True,
)
config.print_summary()

## Cell 4: Seed Initialization

Set random seeds for reproducibility.

In [None]:
# Set seeds for reproducibility.
# The seed and determinism flag come from the config built in Cell 3
# (seed=42, deterministic=True). Which RNGs are seeded is defined in
# src/utils/seeds.py — presumably Python/NumPy/torch; TODO confirm.
set_seeds(config.seed, config.deterministic)

## Cell 5: Google Drive Setup

Mount Google Drive and create checkpoint/results directories.

In [None]:
# Mount Google Drive; checkpoint/results directories are created only when the mount succeeds.
drive_mounted = ensure_drive_mounted()

if not drive_mounted:
    print("⚠ Warning: Drive not mounted. Checkpointing will not work.")
    print("  Continuing with local /content/ storage (temporary)")
else:
    create_checkpoint_dirs(config.checkpoint_dir, config.results_dir)

## Cell 6: Dataset Loading

Load RoG-WebQSP dataset from HuggingFace with Drive caching.

In [None]:
# Initialize dataset loader
import pickle

# Slice sizes for this benchmark run. The test split is kept whole, so its
# expected size is the original (unsliced) RoG-WebQSP test split size.
TRAIN_SIZE = 900
VAL_SIZE = 90
EXPECTED_TEST_SIZE = 1628

cache_dir = config.checkpoint_dir / "huggingface_cache"
loader = RoGWebQSPLoader(cache_dir=cache_dir)

# The checkpoint stores the full dataset as downloaded; slicing happens below,
# so changing the slice sizes never requires re-downloading.
dataset_checkpoint_path = config.get_checkpoint_path("dataset.pkl")

dataset = None  # stays None unless the checkpoint loads cleanly

if checkpoint_exists(dataset_checkpoint_path):
    print("Loading dataset from checkpoint...")
    try:
        dataset = load_checkpoint(dataset_checkpoint_path, format="pickle")
        print("✓ Dataset loaded from checkpoint.")
    except (FileNotFoundError, EOFError, pickle.UnpicklingError) as e:
        # Missing files, or a truncated/corrupted pickle (EOFError /
        # UnpicklingError, assuming load_checkpoint uses the stdlib pickle
        # module): fall through to a fresh download instead of aborting.
        print(f"⚠ Warning: Failed to load dataset from checkpoint ({type(e).__name__}): {e}")
        print("  Falling back to downloading dataset from HuggingFace...")
        dataset = None

if dataset is None:  # no checkpoint, or it could not be loaded
    print("Downloading dataset from HuggingFace...")
    dataset = loader.load(dataset_name=config.dataset_name)
    save_checkpoint(dataset, dataset_checkpoint_path, format="pickle")
    print("✓ Dataset downloaded and saved to checkpoint.")

# Slice dataset to desired sizes
dataset = loader.slice_dataset(
    dataset,
    train_size=TRAIN_SIZE,
    val_size=VAL_SIZE,
    test_size=None,  # Keep all test examples
)

# Inspect dataset schema
loader.inspect_schema(dataset, num_examples=1)

# Compute statistics
loader.compute_statistics(dataset)

# Validate split counts against the same constants used for slicing
split_valid = loader.validate_split_counts(
    dataset,
    expected_train=TRAIN_SIZE,
    expected_val=VAL_SIZE,
    expected_test=EXPECTED_TEST_SIZE,
)

## Cell 7: Graph Construction

Build NetworkX graphs from dataset triples.

In [None]:
# Initialize graph builder
graph_builder = GraphBuilder(directed=config.graph_directed)

# Unified graph over the (sliced) training split, cached as a pickle checkpoint.
# NOTE(review): the checkpoint filename does not encode the training slice size.
# After changing the slice in Cell 6, delete unified_graph.pkl, or the graph
# built from the old slice is reused.
unified_graph_path = config.get_checkpoint_path("unified_graph.pkl")
have_cached_graph = checkpoint_exists(unified_graph_path)

if not have_cached_graph:
    print("Building unified graph from training split...")
    unified_graph = graph_builder.build_unified_graph(dataset["train"])
    save_checkpoint(unified_graph, unified_graph_path, format="pickle")
else:
    print("Loading unified graph from checkpoint...")
    unified_graph = load_checkpoint(unified_graph_path, format="pickle")

graph_builder.print_graph_info(unified_graph, name="Unified Training Graph")

# Check the graph against the config's min node/edge thresholds
graph_valid = graph_builder.validate_graph_size(
    unified_graph,
    min_nodes=config.unified_graph_min_nodes,
    min_edges=config.unified_graph_min_edges,
)

# Demonstrate per-example graph construction on the first training example
print("\nBuilding sample per-example graph...")
sample_example = dataset["train"][0]
sample_graph = graph_builder.build_from_triples(sample_example["graph"], graph_id=sample_example["id"])
graph_builder.print_graph_info(sample_graph, name="Sample Per-Example Graph")

## Cell 8: Phase 1 Validation

Automated validation of all Phase 1 success criteria.

In [None]:
phase1_rule = "=" * 60
print("\n" + phase1_rule)
print("Phase 1 Success Criteria Validation")
print(phase1_rule)

# Ordered criterion -> pass flag mapping (insertion order is the print order)
validation_results = {
    "GPU Available": torch.cuda.is_available(),
    "All Imports Successful": True,  # reaching this cell means the imports ran
    "Dataset Splits Valid": split_valid,
    "Unified Graph Size Valid": graph_valid,
}

# Write then re-read a small payload to prove checkpoint storage works end to end
test_checkpoint_path = config.get_checkpoint_path("test_roundtrip.pkl")
test_data = {"test": "round-trip", "value": 42}
try:
    save_checkpoint(test_data, test_checkpoint_path, format="pickle")
    loaded_data = load_checkpoint(test_checkpoint_path, format="pickle")
    checkpoint_roundtrip_ok = (loaded_data == test_data)
    validation_results["Checkpoint Round-Trip"] = checkpoint_roundtrip_ok
except Exception as e:
    print(f"Checkpoint round-trip failed: {e}")
    validation_results["Checkpoint Round-Trip"] = False

print("\nValidation Results:")
for criterion, passed in validation_results.items():
    print(f"  {'✓' if passed else '✗'} {criterion}")
all_passed = all(validation_results.values())

print("\n" + phase1_rule)
if not all_passed:
    print("✗ PHASE 1 INCOMPLETE - Some criteria failed")
    print("\nPlease review failed criteria above")
else:
    print("✓ PHASE 1 COMPLETE - All criteria passed!")
    print("\nReady to proceed to Phase 2: Retrieval Pipeline")
print(phase1_rule)

phase1_summary_lines = [
    f"  Dataset: {config.dataset_name}",
    f"  Training examples: {len(dataset['train'])}",
    f"  Validation examples: {len(dataset['validation'])}",
    f"  Test examples: {len(dataset['test'])}",
    f"  Unified graph nodes: {unified_graph.number_of_nodes()}",
    f"  Unified graph edges: {unified_graph.number_of_edges()}",
    f"  Checkpoints saved to: {config.checkpoint_dir}",
]
print("\nPhase 1 Summary:")
print("\n".join(phase1_summary_lines))

## Cell 9: Build Retrieval Pipeline

Initialize retrieval components (embeddings, FAISS index, PCST solver).

In [None]:
!uv pip install pcst-fast

In [None]:
phase2_rule = "=" * 60
print(phase2_rule)
print("PHASE 2: RETRIEVAL PIPELINE")
print(phase2_rule)

from src.retrieval import Retriever

# Reuse checkpointed retrieval artifacts when available; otherwise build them
# from the unified graph produced in Cell 7.
retriever = Retriever.build_from_checkpoint_or_new(
    config=config,
    unified_graph=unified_graph,
)

print("\n✓ Retrieval pipeline initialized")
for detail in (
    f"  - Entity embeddings: {len(retriever.entity_index)} entities",
    f"  - Top-K: {config.top_k_entities}",
    f"  - PCST budget: {config.pcst_budget} nodes",
):
    print(detail)

## Cell 10: Retrieval Validation

Test retrieval pipeline on 10 validation examples.

In [None]:
# Number of validation examples used for this retrieval smoke test
N_VAL_EXAMPLES = 10

# Clamp to the split size: Dataset.select raises on out-of-range indices.
num_val_examples = min(N_VAL_EXAMPLES, len(dataset["validation"]))
if num_val_examples == 0:
    raise ValueError("Validation split is empty; cannot run retrieval validation.")

print("\n" + "=" * 60)
print(f"RETRIEVAL VALIDATION ({num_val_examples} examples)")
print("=" * 60)

val_examples = list(dataset["validation"].select(range(num_val_examples)))

hit_count = 0
total_time_ms = 0
subgraph_sizes = []

for i, example in enumerate(val_examples):
    question = example["question"]
    answer_entities = example.get("a_entity", [])
    if isinstance(answer_entities, str):
        answer_entities = [answer_entities]

    # Extract topic entities from dataset (bare strings are wrapped in a list)
    q_entities = example.get("q_entity", [])
    if isinstance(q_entities, str):
        q_entities = [q_entities]

    # Retrieve subgraph (q_entity used as primary seed when available)
    result = retriever.retrieve(question, q_entity=q_entities)

    # A "hit" means at least one gold answer entity appears in the subgraph
    subgraph_nodes = set(result.subgraph.nodes())
    hit = any(ans in subgraph_nodes for ans in answer_entities)
    if hit:
        hit_count += 1

    total_time_ms += result.retrieval_time_ms
    subgraph_sizes.append(result.num_nodes)

    print(f"\n[{i+1}/{num_val_examples}] Q: {question[:60]}...")
    print(f"  Topic entities (q_entity): {q_entities}")
    print(f"  Answer entities: {answer_entities}")
    print(f"  Seeds used: {result.seed_entities[:5]}{'...' if len(result.seed_entities) > 5 else ''}")
    print(f"  Subgraph: {result.num_nodes} nodes, {result.num_edges} edges")
    print(f"  Hit: {'✓' if hit else '✗'}")
    print(f"  Time: {result.retrieval_time_ms:.1f}ms")

# Summary metrics (num_val_examples > 0 is guaranteed above)
hit_rate = hit_count / num_val_examples * 100
avg_time = total_time_ms / num_val_examples
avg_size = sum(subgraph_sizes) / len(subgraph_sizes)

print("\n" + "=" * 60)
print("VALIDATION SUMMARY")
print("=" * 60)
print(f"Hit rate: {hit_rate:.1f}% ({hit_count}/{num_val_examples})")
print(f"Avg retrieval time: {avg_time:.1f}ms")
print(f"Avg subgraph size: {avg_size:.1f} nodes")
print(f"Max subgraph size: {max(subgraph_sizes)} nodes")

## Cell 11: Phase 2 Success Criteria

Validate Phase 2 completion criteria.

In [None]:
import networkx as nx


def _normalize_entities(value):
    """Wrap a bare string in a list and pass other values through, as Cell 10 does."""
    return [value] if isinstance(value, str) else value


def _subgraph_is_connected(graph):
    """Return True when `graph` is non-empty and connected.

    Uses weak connectivity for directed graphs and plain connectivity for
    undirected ones (graph direction is configurable). networkx raises on the
    null graph, so an empty retrieval is reported as not connected instead.
    """
    if graph.number_of_nodes() == 0:
        return False
    if graph.is_directed():
        return nx.is_weakly_connected(graph)
    return nx.is_connected(graph)


print("\n" + "=" * 60)
print("PHASE 2 SUCCESS CRITERIA")
print("=" * 60)

# Criterion 1: Retrieval speed < 1 second
speed_pass = avg_time < 1000  # ms
print(f"[{'✓' if speed_pass else '✗'}] Retrieval completes in <1 second: {avg_time:.1f}ms")

# Criterion 2: Hit rate of at least 60%
hit_pass = hit_rate >= 60.0
print(f"[{'✓' if hit_pass else '✗'}] Subgraph contains answer entity ≥60%: {hit_rate:.1f}%")

# Criterion 3: Subgraphs for the first few examples are connected. q_entity is
# normalised exactly as in Cell 10, so these calls use the same seeds.
N_CONNECTIVITY_CHECKS = 5
all_connected = all(
    _subgraph_is_connected(
        retriever.retrieve(
            example["question"],
            q_entity=_normalize_entities(example.get("q_entity", [])),
        ).subgraph
    )
    for example in val_examples[:N_CONNECTIVITY_CHECKS]
)
print(f"[{'✓' if all_connected else '✗'}] All subgraphs are connected")

# Criterion 4: Subgraph size respects budget
size_pass = max(subgraph_sizes) <= config.pcst_budget
print(f"[{'✓' if size_pass else '✗'}] Subgraph size ≤ budget ({max(subgraph_sizes)} ≤ {config.pcst_budget})")

# Overall pass
all_pass = speed_pass and hit_pass and all_connected and size_pass
print("\n" + "=" * 60)
if all_pass:
    print("✓ PHASE 2 COMPLETE - All criteria met!")
    print("\nReady to proceed to Phase 3: GNN Encoder")
else:
    print("⚠ PHASE 2 INCOMPLETE - Review failed criteria above")
print("=" * 60)

# Phase 3: GNN Encoder
## Cell 12: Build/Load GNN Model

In [None]:
!uv pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0

In [None]:
# Install torch_geometric and its dependencies
print("Installing torch_geometric...")
!uv pip install torch_geometric torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.8.0+cu128.html
print("✓ torch_geometric installed.")

In [None]:
# ============================================================
# PHASE 3: GNN Encoder
# ============================================================

gnn_intro_lines = (
    "Building GNN Model...",
    "This will either:",
    "  1. Load pre-trained model from checkpoint, OR",
    "  2. Prepare data and train from scratch (~30 min)",
    "",
)
for intro_line in gnn_intro_lines:
    print(intro_line)

from src.gnn import GNNModel

# Loads a checkpointed model if one exists, otherwise trains one.
# Alternatives: encoder_type "graphsage"; pooling_type "mean" or "max".
gnn_model = GNNModel.build_from_checkpoint_or_train(
    config=config,
    retriever=retriever,
    train_data=dataset["train"],
    val_data=dataset["validation"],
    encoder_type="gatv2",
    pooling_type="attention",
)

gnn_rule = "=" * 60
print("\n" + gnn_rule)
print("GNN Model Ready")
print(gnn_rule)

## Cell 13: Test GNN Inference

In [None]:
# Smoke-test the GNN on one hand-written question (no q_entity seeds are passed).
test_question = "Who is Justin Bieber's brother?"
print("Testing GNN inference on example query...\n")
print(f"Question: {test_question}")

retrieved = retriever.retrieve(test_question)
print(f"Retrieved subgraph: {retrieved.num_nodes} nodes, {retrieved.num_edges} edges")

gnn_output = gnn_model.encode(retrieved, test_question)
print("\nGNN Output:")
print(f"  Node embeddings shape: {gnn_output.node_embeddings.shape}")
print(f"  Graph embedding shape: {gnn_output.graph_embedding.shape}")
print(f"  Attention scores: {len(gnn_output.attention_scores)} nodes")

# Highest-attention nodes as (node, score) pairs
top_nodes = gnn_model.get_top_attention_nodes(gnn_output, top_k=10)
print("\nTop 10 nodes by attention score:")
for rank, (node, score) in enumerate(top_nodes, start=1):
    print(f"  {rank}. {node}: {score:.4f}")

## Cell 14: Validate GNN Metrics

In [None]:
# Load training history and display metrics
import json
import matplotlib.pyplot as plt

history_path = config.get_checkpoint_path("gnn_training_history.json")
with open(history_path, "r") as f:
    history = json.load(f)

print("="*60)
print("PHASE 3 VALIDATION: GNN Metrics")
print("="*60)

# max()/min() and the plots below need non-empty series; fail with a clear message.
for metric_key in ("train_loss", "val_loss", "train_f1", "val_f1"):
    if not history.get(metric_key):
        raise ValueError(f"Training history at {history_path} has no '{metric_key}' entries")

# Best metrics
best_val_f1 = max(history["val_f1"])
best_val_loss = min(history["val_loss"])
final_val_f1 = history["val_f1"][-1]

print(f"\nTraining Summary:")
print(f"  Epochs trained: {len(history['train_loss'])}")
print(f"  Best validation F1: {best_val_f1:.3f}")
print(f"  Best validation loss: {best_val_loss:.4f}")
print(f"  Final validation F1: {final_val_f1:.3f}")

# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss curve
axes[0].plot(history["train_loss"], label="Train Loss")
axes[0].plot(history["val_loss"], label="Val Loss")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training and Validation Loss")
axes[0].legend()
axes[0].grid(True)

# F1 curve
axes[1].plot(history["train_f1"], label="Train F1")
axes[1].plot(history["val_f1"], label="Val F1")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("F1 Score")
axes[1].set_title("Answer Node Prediction F1")
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.show()

# Success criteria check
print(f"\n{'='*60}")
print("Success Criteria:")
print(f"{'='*60}")

# `passed` is None for criteria this notebook does not measure. They are shown
# as manual checks and excluded from the automated verdict, not reported as PASS.
criteria = [
    ("Validation F1 > 0.5", best_val_f1 > 0.5, best_val_f1),
    ("Training completed < 30 min", None, "N/A"),  # User observation
    ("No OOM errors", None, "N/A"),  # User observation
]

for criterion, passed, value in criteria:
    if passed is None:
        status = "? MANUAL"
    else:
        status = "✓ PASS" if passed else "✗ FAIL"
    print(f"{status} - {criterion}: {value}")

if all(passed for _, passed, _ in criteria if passed is not None):
    print(f"\n{'='*60}")
    print("SUCCESS: Phase 3 Complete")
    print(f"{'='*60}")
else:
    print(f"\n{'='*60}")
    print("FAILED: Some criteria not met")
    print(f"{'='*60}")

## Cell 15: Visualize Attention on Example

In [None]:
# Visualize GNN attention on a subgraph
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

def visualize_gnn_attention(
    subgraph: nx.DiGraph,
    attention_scores: dict,
    answer_entities: list,
    question: str,
    top_k: int = 20,
    num_labels: int = 10,
):
    """
    Visualize GNN attention on a subgraph.

    Draws the ``top_k`` highest-attention nodes of ``subgraph``. Answer
    entities are red; all other nodes are blue, with opacity and size growing
    with their attention score. Only the ``num_labels`` highest-attention
    nodes get text labels, and their scores are printed afterwards.

    Args:
        subgraph: NetworkX graph (directed or undirected) the scores refer to
        attention_scores: Dict[node_name, float]
        answer_entities: List of ground truth answer entities
        question: Question text (shown in the figure title)
        top_k: Show only top-K nodes by attention
        num_labels: Number of highest-attention nodes to label and print
    """
    answer_set = set(answer_entities)

    # Get top-K nodes by attention
    sorted_nodes = sorted(
        attention_scores.items(), key=lambda x: x[1], reverse=True
    )[:top_k]
    top_nodes = [node for node, _ in sorted_nodes]

    # Induced subgraph on the top nodes (scored nodes absent from `subgraph` are dropped)
    G_viz = subgraph.subgraph(top_nodes).copy()

    if G_viz.number_of_nodes() == 0:
        print("No attention-scored nodes are present in the subgraph; nothing to draw.")
    else:
        node_colors = []
        node_sizes = []
        for node in G_viz.nodes():
            attn = attention_scores.get(node, 0.0)
            if node in answer_set:
                node_colors.append((1.0, 0.0, 0.0, 0.9))  # red = answer
            else:
                # Darker = higher attention (x10 scale for visibility). Clamp
                # to [0, 1] so a negative score cannot produce an invalid alpha.
                intensity = max(0.0, min(attn * 10, 1.0))
                node_colors.append((0.2, 0.4, 0.8, 0.3 + 0.7 * intensity))
            # Size proportional to attention; negative scores get the base size.
            node_sizes.append(300 + 2000 * max(attn, 0.0))

        # Layout (fixed seed for reproducible figures)
        pos = nx.spring_layout(G_viz, k=0.5, iterations=50, seed=42)

        fig, ax = plt.subplots(figsize=(14, 10))
        ax.set_title(f"GNN Attention Visualization\nQ: {question}", fontsize=12)

        nx.draw_networkx_edges(
            G_viz, pos, ax=ax, alpha=0.3, arrows=True, arrowsize=10, width=1.0
        )
        # No scalar `alpha=` here: it would override the per-node alpha channel
        # in node_colors and erase the attention shading.
        nx.draw_networkx_nodes(
            G_viz, pos, ax=ax, node_color=node_colors, node_size=node_sizes
        )

        # Label the highest-attention drawn nodes (graph iteration order is
        # unrelated to attention); str() guards against non-string node ids.
        labels = {
            node: str(node)[:20] for node in top_nodes[:num_labels] if node in G_viz
        }
        nx.draw_networkx_labels(G_viz, pos, labels, font_size=8, ax=ax)

        ax.axis("off")
        fig.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure; this helper may be called repeatedly

    # Print attention scores
    print(f"Top {num_labels} attention scores:")
    for i, (node, score) in enumerate(sorted_nodes[:num_labels], 1):
        is_answer = "✓ ANSWER" if node in answer_set else ""
        print(f"  {i}. {str(node)[:30]}: {score:.4f} {is_answer}")


# Demo query; the gold answer is given both by name and by Freebase ID.
test_question = "Who is Barack Obama's spouse?"
test_answer = ["Michelle Obama", "m.025s5v9"]  # Freebase ID

retrieved = retriever.retrieve(test_question)
gnn_output = gnn_model.encode(retrieved, test_question)

attention_plot_kwargs = dict(
    subgraph=retrieved.subgraph,
    attention_scores=gnn_output.attention_scores,
    answer_entities=test_answer,
    question=test_question,
    top_k=20,
)
visualize_gnn_attention(**attention_plot_kwargs)

## Cell 16: Memory Check

In [None]:
# Check GPU memory usage (values reported in decimal GB)
if not torch.cuda.is_available():
    print("GPU not available")
else:
    bytes_per_gb = 1e9
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    print("GPU Memory Summary:")
    print(f"  Allocated: {allocated_gb:.2f} GB")
    print(f"  Reserved: {torch.cuda.memory_reserved() / bytes_per_gb:.2f} GB")
    print(f"  Max allocated: {torch.cuda.max_memory_allocated() / bytes_per_gb:.2f} GB")

    # Treat more than 14 GB still allocated at this point as a leak
    assert allocated_gb < 14.0, "Memory leak detected!"
    print("\n✓ Memory usage within acceptable range")