# GNN Model Development and Testing

This notebook tests the pathway-aware Graph Neural Network (GNN) model for drug response prediction. The model itself is implemented in `src/models/gnn_model.py`; this notebook only imports and exercises it.

**Objectives:**
- Test GNN model in fallback mode (without torch_geometric)
- Load and prepare pathway data
- Test on a small batch of dummy (random) inputs
- Verify backward compatibility with existing models
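- Compare architecture (parameter counts and output shapes) with the existing model; performance comparison requires trained models and is left to the next steps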

**Author:** Aaron Yu  
**Date:** November 8, 2025

## 1. Setup Environment and Dependencies

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import json
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from scipy.stats import spearmanr
import warnings
warnings.filterwarnings('ignore')

# Make the project importable. The repository root is taken to be the parent
# of the current working directory; both the root and its `src/` folder are
# placed at the front of sys.path (with `src/` first).
PROJECT_ROOT = Path.cwd().parent
sys.path[:0] = [str(PROJECT_ROOT / "src"), str(PROJECT_ROOT)]

# Report the environment this notebook is running in.
for label, detail in (
    ("Project root", PROJECT_ROOT),
    ("Python version", sys.version),
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {detail}")

In [None]:
# Probe for the optional torch_geometric package and record whether it is
# available in HAS_PYGEOMETRIC. Only ImportError is treated as "not installed".
try:
    import torch_geometric
except ImportError:
    print("PyTorch Geometric not installed - will use fallback mode")
    print("To install: pip install torch-geometric")
    HAS_PYGEOMETRIC = False
else:
    print(f"PyTorch Geometric version: {torch_geometric.__version__}")
    HAS_PYGEOMETRIC = True

# Project-local models and helpers exercised by the rest of the notebook.
from src.models.gnn_model import (
    PathwayAwareGNN,
    GenomicEncoder,
    DrugEncoder,
    MultiTaskHead,
)
from src.models.deep_learning import ImprovedDrugResponseModel
from src.data.pathway_utils import PathwayGraphBuilder

print("\nModules imported successfully!")

## 2. Load Existing Data and Models

In [None]:
# Load the configuration and checkpoint of the previously trained DL model,
# which serves as the comparison baseline for the GNN model.
CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints_stratified" / "previous_treatment"

# Load config. A missing config.json is tolerated in the same way a missing
# checkpoint is: the baseline model is later built with dl_config.get(key,
# default) for every hyperparameter, so an empty dict is a valid fallback.
config_path = CHECKPOINT_DIR / "config.json"
try:
    with open(config_path, 'r', encoding='utf-8') as f:
        dl_config = json.load(f)
except FileNotFoundError:
    print(f"Warning: No config found at {config_path} - using default hyperparameters")
    dl_config = {}

print("Existing DL model config:")
for key, value in dl_config.items():
    print(f"  {key}: {value}")

# Check if trained model exists
checkpoint_path = CHECKPOINT_DIR / "best_model.pt"
if checkpoint_path.exists():
    print(f"\nExisting model found: {checkpoint_path}")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    print(f"Model trained for {checkpoint.get('epoch', 'unknown')} epochs")
    if 'metrics' in checkpoint:
        print(f"Best metrics: {checkpoint['metrics']}")
else:
    print("\nWarning: No trained model found")

## 3. Load Pathway Graph Data

In [None]:
# Locate the cached KEGG pathway graph and load it if present. When the cache
# file is absent, fall back to an empty graph (no edges, empty gene maps) so
# that edge_index, gene_to_idx and idx_to_gene are always defined.
pathway_graph_dir = PROJECT_ROOT / "data" / "pathway_graphs"
pathway_graph_file = pathway_graph_dir / "kegg_human_pathway_graph.pt"

if not pathway_graph_file.exists():
    print("Warning: Pathway graph not found")
    print(f"Expected location: {pathway_graph_file}")
    print("Run scripts/download_pathways.py to create it")
    edge_index, gene_to_idx, idx_to_gene = None, {}, {}
else:
    builder = PathwayGraphBuilder(cache_dir=str(pathway_graph_dir))
    # The builder is given the bare file name; it was constructed with the
    # containing directory as its cache_dir.
    edge_index, gene_to_idx, idx_to_gene = builder.load_graph(pathway_graph_file.name)

    print("\nPathway graph loaded:")
    for description, quantity in (
        ("Number of genes", len(gene_to_idx)),
        ("Number of edges", edge_index.shape[1]),
        ("Edge index shape", edge_index.shape),
    ):
        print(f"  {description}: {quantity}")
    print(f"\nSample genes: {list(gene_to_idx)[:5]}")

## 4. Test GNN Model in Fallback Mode

In [None]:
# Create GNN model in fallback mode (without graph structure) and run a
# smoke-test forward pass on random inputs.
print("Testing GNN model in fallback mode...")

gnn_model_fallback = PathwayAwareGNN(
    genomic_dim=1318,
    drug_fp_dim=8192,
    embed_dim=256,
    use_pathway_graph=False,  # Fallback mode
    dropout_genomic=0.4,
    dropout_drug=0.4,
    dropout_head=0.3
)

print("\nModel created successfully!")
print(f"Total parameters: {sum(p.numel() for p in gnn_model_fallback.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in gnn_model_fallback.parameters() if p.requires_grad):,}")

# Switch to inference mode before any forward pass. torch.no_grad() only turns
# off gradient tracking; it does not disable dropout. The model is configured
# with non-zero dropout rates, so in training mode two forward passes on the
# same input would give different outputs, which would make the printed
# predictions and any exact-match comparison of predictions unreliable.
# (Assumes PathwayAwareGNN is a torch.nn.Module - it exposes .parameters().)
gnn_model_fallback.eval()

# Test forward pass with dummy data
batch_size = 4
dummy_genomic = torch.randn(batch_size, 1318)
dummy_drug = torch.randn(batch_size, 8192)

print("\nTesting forward pass...")
with torch.no_grad():
    pfs_pred = gnn_model_fallback(dummy_genomic, dummy_drug)
    print(f"Input shapes: genomic={dummy_genomic.shape}, drug={dummy_drug.shape}")
    print(f"Output shape: {pfs_pred.shape}")
    print(f"Output values: {pfs_pred.squeeze().numpy()}")

print("\nFallback mode test: PASSED")

## 5. Test Multi-Task Outputs

In [None]:
# Exercise the optional outputs of the fallback model: request attention
# weights and all task heads, then report whichever entries are returned.
print("Testing multi-task predictions with attention...")

with torch.no_grad():
    pfs_pred, additional_outputs = gnn_model_fallback(
        dummy_genomic,
        dummy_drug,
        return_attention=True,
        return_all_tasks=True,
    )

    print(f"\nMain prediction (PFS): {pfs_pred.shape}")
    print(f"PFS values: {pfs_pred.squeeze().numpy()}")

    # An empty or missing extras container simply means nothing more to report.
    extras = additional_outputs if additional_outputs else {}

    if 'resistance_mechanism' in extras:
        resistance = extras['resistance_mechanism']
        print(f"\nResistance mechanism prediction: {resistance.shape}")
        print(f"Resistance classes (logits): {resistance[0].numpy()}")

        # Normalise the logits along dim 1 to obtain probabilities.
        resistance_probs = resistance.softmax(dim=1)
        print(f"Resistance probabilities: {resistance_probs[0].numpy()}")

    if 'pathway_activity' in extras:
        pathway = extras['pathway_activity']
        print(f"\nPathway activity prediction: {pathway.shape}")
        print(f"Top 5 pathway activities: {pathway[0, :5].numpy()}")

    if 'attention_weights' in extras:
        attention = extras['attention_weights']
        print(f"\nAttention weights: {attention.shape}")

print("\nMulti-task test: PASSED")

## 6. Compare with Existing Model Architecture

In [None]:
# Create existing DL model for comparison. Hyperparameters come from the
# loaded config, falling back to the defaults below for any missing key.
print("Creating existing DL model...")

existing_model = ImprovedDrugResponseModel(
    genomic_dim=dl_config.get("genomic_dim", 1318),
    drug_fp_dim=dl_config.get("drug_fp_dim", 8192),
    embed_dim=dl_config.get("embed_dim", 256),
    dropout_genomic=dl_config.get("dropout_genomic", 0.5),
    dropout_drug=dl_config.get("dropout_drug", 0.4),
    dropout_head=dl_config.get("dropout_head", 0.5)
)

existing_params = sum(p.numel() for p in existing_model.parameters())
gnn_params = sum(p.numel() for p in gnn_model_fallback.parameters())

# Report the magnitude of the relative change and let the word carry the
# direction; printing a signed percentage next to "decrease" would read as a
# double negative (e.g. "-12.0% decrease").
param_delta = gnn_params - existing_params
relative_change = abs(param_delta) / existing_params * 100
direction = 'increase' if param_delta > 0 else 'decrease'

print(f"\nModel Comparison:")
print(f"  Existing DL model parameters: {existing_params:,}")
print(f"  New GNN model parameters: {gnn_params:,}")
print(f"  Parameter difference: {param_delta:,} ({relative_change:.1f}% {direction})")

# Put both models in inference mode so dropout is disabled and the outputs
# below are deterministic for the given input. torch.no_grad() alone does not
# do this. (Assumes both models are torch.nn.Module instances - both expose
# .parameters().)
existing_model.eval()
gnn_model_fallback.eval()

# Test both models with same input
print("\nTesting both models with identical input...")
with torch.no_grad():
    existing_pred = existing_model(dummy_genomic, dummy_drug)
    gnn_pred = gnn_model_fallback(dummy_genomic, dummy_drug)

    print(f"Existing model predictions: {existing_pred.squeeze().numpy()}")
    print(f"GNN model predictions: {gnn_pred.squeeze().numpy()}")
    print(f"\nNote: Different predictions expected (models not trained yet)")

## 7. Verify Backward Compatibility

In [None]:
# Verify GNN model can handle same input format as existing model, and that a
# saved checkpoint reloads to a model producing the same predictions.
print("Verifying backward compatibility...")

# Test 1: Same input dimensions
assert dummy_genomic.shape[1] == 1318, "Genomic dimension mismatch"
assert dummy_drug.shape[1] == 8192, "Drug dimension mismatch"
print("Test 1 - Input dimensions: PASSED")

# Test 2: Same output dimensions
assert existing_pred.shape == gnn_pred.shape, "Output shape mismatch"
print(f"Test 2 - Output dimensions: PASSED (both output shape {gnn_pred.shape})")

# Test 3: Save and load GNN model
print("\nTest 3 - Model serialization...")
test_checkpoint_path = PROJECT_ROOT / "checkpoints_stratified" / "test_gnn_model.pt"
test_checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

try:
    gnn_model_fallback.save_checkpoint(str(test_checkpoint_path))
    print(f"  Saved to: {test_checkpoint_path}")

    loaded_model = PathwayAwareGNN.from_checkpoint(str(test_checkpoint_path))
    print("  Loaded successfully")

    # Compare the two models under identical, deterministic conditions: both
    # in inference mode (dropout disabled) and both evaluated here on the same
    # input. Reusing a prediction computed earlier would compare against a
    # forward pass whose dropout state is unknown.
    gnn_model_fallback.eval()
    loaded_model.eval()
    with torch.no_grad():
        reference_pred = gnn_model_fallback(dummy_genomic, dummy_drug)
        loaded_pred = loaded_model(dummy_genomic, dummy_drug)

    assert torch.allclose(reference_pred, loaded_pred, atol=1e-5), "Loaded model predictions differ"
    print("  Predictions match: PASSED")
finally:
    # Clean up test file even when saving, loading or the comparison fails,
    # so a failed run does not leave a stray checkpoint behind.
    test_checkpoint_path.unlink(missing_ok=True)
    print("  Cleaned up test file")

print("\nAll backward compatibility tests: PASSED")

## 8. Summary and Next Steps

**Completed:**
- GNN model architecture created successfully
- Fallback mode (MLP) tested and working
- Multi-task predictions implemented
- Attention mechanism functional
- Backward compatibility verified
- Model serialization working

**Next Steps:**
1. Create training script (train_gnn.py)
2. Train model on full dataset
3. Compare performance with existing models
4. Integrate into app.py with model selector

**Model Features:**
- Pathway-aware architecture (ready for graph input)
- Multi-task learning (PFS + resistance + pathway activity)
- Cross-attention between genomic and drug features
- Compatible with existing infrastructure
- Can run with or without torch_geometric