# Graph Transformer Cancer Classification Demo

This notebook demonstrates the Graph Transformer model for cancer type classification:
- Model architecture and training
- Model evaluation
- Full interpretability analysis (SHAP, attention, PCA baseline)
- Comparative analysis and result generation


In [7]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import json
import os

# Determine project root (notebook is in notebooks/ subdirectory)
# In Jupyter, __file__ doesn't exist, so use getcwd()
cwd = Path(os.getcwd())
if (cwd / "src").exists() and (cwd / "interpretability").exists():
    PROJECT_ROOT = cwd
elif (cwd.parent / "src").exists() and (cwd.parent / "interpretability").exists():
    PROJECT_ROOT = cwd.parent
else:
    # Fallback: assume we're in notebooks/ directory
    PROJECT_ROOT = Path().resolve().parent

sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src"))
sys.path.insert(0, str(PROJECT_ROOT / "interpretability"))

import config
from model import GraphTransformerClassifier
from graph_prior import load_graph_prior, get_graph_features_as_tensors
from dataset import load_and_preprocess_data, create_dataloaders
from utils import load_trained_model, load_data, get_output_dirs, DEFAULT_PATHS

plt.rcParams['figure.dpi'] = 100
sns.set_style('whitegrid')

print(f"Project root: {PROJECT_ROOT}")
print(f"PyTorch version: {torch.__version__}")
print(f"Working directory: {cwd}")


Project root: /Users/leophelan/Code/projects/graph_transformer_proteomics/CleanedProject/Executable_Project_Code
PyTorch version: 2.8.0
Working directory: /Users/leophelan/Code/projects/graph_transformer_proteomics/CleanedProject/Executable_Project_Code/notebooks


## 1. Data Availability Check


In [8]:
def check_data_availability():
    """Check if required data files are available."""
    data_dir = PROJECT_ROOT / "data"
    csv_path = data_dir / "processed_datasets" / "tcga_pancan_rppa_compiled.csv"
    prior_path = data_dir / "priors" / "tcga_string_prior.npz"
    
    # Also check alternative locations
    if not csv_path.exists():
        csv_path = data_dir / "tcga_pancan_rppa_compiled.csv"
    if not prior_path.exists():
        prior_path = data_dir / "tcga_string_prior.npz"
    
    csv_available = csv_path.exists()
    prior_available = prior_path.exists()
    
    return csv_available and prior_available, csv_path if csv_available else None, prior_path if prior_available else None

DATA_AVAILABLE, CSV_PATH, PRIOR_PATH = check_data_availability()

if DATA_AVAILABLE:
    print("✓ Data files found - model can be trained")
    print(f"  CSV: {CSV_PATH}")
    print(f"  Prior: {PRIOR_PATH}")
else:
    print("⚠ Data files not found - will display model architecture only")
    print("  Place data files in data/processed_datasets/ and data/priors/")
    print("  See data/README.md for instructions")


✓ Data files found - model can be trained
  CSV: /Users/leophelan/Code/projects/graph_transformer_proteomics/CleanedProject/Executable_Project_Code/data/processed_datasets/tcga_pancan_rppa_compiled.csv
  Prior: /Users/leophelan/Code/projects/graph_transformer_proteomics/CleanedProject/Executable_Project_Code/data/priors/tcga_string_prior.npz


## 2. Model Architecture Display / Training


In [9]:
def display_model_architecture():
    """Display model architecture as formatted table."""
    print("=" * 70)
    print("Graph Transformer Classifier Architecture")
    print("=" * 70)
    print(f"\n{'Parameter':<30} {'Value':<40}")
    print("-" * 70)
    print(f"{'Embedding Dimension':<30} {config.MODEL['embedding_dim']:<40}")
    print(f"{'Number of Layers':<30} {config.MODEL['n_layers']:<40}")
    print(f"{'Attention Heads':<30} {config.MODEL['n_heads']:<40}")
    print(f"{'Feed-forward Dimension':<30} {config.MODEL['ffn_dim']:<40}")
    print(f"{'Dropout Rate':<30} {config.MODEL['dropout']:<40}")
    print(f"{'Graph Bias Scale (learnable)':<30} {config.MODEL['graph_bias_scale']:<40}")
    print(f"{'Positional Encoding Dim':<30} {config.MODEL['pe_dim']:<40}")
    print(f"{'Diffusion Kernel Beta':<30} {config.GRAPH_PRIOR['diffusion_beta']:<40}")
    print("\nTraining Parameters:")
    print("-" * 70)
    print(f"{'Learning Rate':<30} {config.TRAINING['learning_rate']:<40}")
    print(f"{'Weight Decay':<30} {config.TRAINING['weight_decay']:<40}")
    print(f"{'Batch Size':<30} {config.TRAINING['batch_size']:<40}")
    print(f"{'Max Epochs':<30} {config.TRAINING['max_epochs']:<40}")
    print(f"{'Early Stopping Patience':<30} {config.TRAINING['patience']:<40}")
    print(f"{'Gradient Clipping':<30} {config.TRAINING['grad_clip']:<40}")
    print("=" * 70)

if DATA_AVAILABLE:
    print("Training model from scratch...")
    # Import training functions
    import torch.nn as nn
    import torch.optim as optim
    from tqdm import tqdm
    from sklearn.metrics import accuracy_score
    
    # Define training functions inline (or import from training module)
    def train_epoch_simple(model, loader, criterion, optimizer, device):
        model.train()
        total_loss = 0
        all_preds = []
        all_labels = []
        for x, y in tqdm(loader, desc='Training'):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAINING['grad_clip'])
            optimizer.step()
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
        return total_loss / len(loader), accuracy_score(all_labels, all_preds)
    
    def evaluate_simple(model, loader, criterion, device):
        model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                loss = criterion(logits, y)
                total_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(y.cpu().numpy())
        accuracy = accuracy_score(all_labels, all_preds)
        return {'loss': total_loss / len(loader), 'accuracy': accuracy}
    
    # Load data and prior
    graph_prior = load_graph_prior(str(PRIOR_PATH))
    graph_tensors = get_graph_features_as_tensors(graph_prior, device='cpu')
    
    data_splits, label_info, scaler = load_and_preprocess_data(str(CSV_PATH), graph_prior['protein_cols'])
    train_loader, val_loader, test_loader = create_dataloaders(data_splits)
    
    # Initialize model
    model = GraphTransformerClassifier(
        n_proteins=graph_prior['A'].shape[0],
        n_classes=label_info['n_classes'],
        diffusion_kernel=graph_tensors['K'],
        positional_encodings=graph_tensors['PE'],
    )
    
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model parameters: {n_params:,}")
    
    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.TRAINING['learning_rate'],
        weight_decay=config.TRAINING['weight_decay']
    )
    
    # Train for a few epochs (full training would take longer)
    print("\nTraining (limited epochs for demo)...")
    model.train()
    for epoch in range(min(5, config.TRAINING['max_epochs'])):
        train_loss, train_acc = train_epoch_simple(model, train_loader, criterion, optimizer, 'cpu')
        val_metrics = evaluate_simple(model, val_loader, criterion, 'cpu')
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Acc={val_metrics['accuracy']:.4f}")
    
    # Save model
    checkpoint_dir = PROJECT_ROOT / "pretrained"
    checkpoint_dir.mkdir(exist_ok=True)
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': {'MODEL': config.MODEL, 'TRAINING': config.TRAINING},
        'label_info': label_info,
    }
    torch.save(checkpoint, checkpoint_dir / "best_model.pt")
    print(f"\nModel saved to {checkpoint_dir / 'best_model.pt'}")
else:
    display_model_architecture()


Training model from scratch...
Loaded prior: 198 proteins, 1184 edges
Loading data from /Users/leophelan/Code/projects/graph_transformer_proteomics/CleanedProject/Executable_Project_Code/data/processed_datasets/tcga_pancan_rppa_compiled.csv...
Loaded 7523 samples
Found 198 protein columns in CSV
Filtering samples: 7523/7523 have ≤50.0% missing values
After filtering: 7523 samples across 32 cancer types
Cancer type distribution:
CANCER_TYPE_ACRONYM
ACC      45
BLCA    343
BRCA    876
CESC    166
CHOL     30
COAD    346
DLBC     33
ESCA    125
GBM     231
HNSC    211
KICH     62
KIRC    455
KIRP    209
LGG     428
LIHC    179
LUAD    360
LUSC    317
MESO     63
OV      414
PAAD    122
PCPG     79
PRAD    350
READ    118
SARC    218
SKCM    329
STAD    354
TGCT    118
THCA    369
THYM     90
UCEC    423
UCS      48
UVM      12
Name: count, dtype: int64

Data splits:
  Train: 6394 samples (85.0%)
  Val:   751 samples (10.0%)
  Test:  378 samples (5.0%)

Handling missing values...
  train: 

Training: 100%|██████████| 100/100 [01:12<00:00,  1.37it/s]


Epoch 1: Train Loss=3.2258, Val Acc=0.1651


Training:  72%|███████▏  | 72/100 [00:53<00:20,  1.35it/s]


KeyboardInterrupt: 

## 3. Load Pretrained Model


In [None]:
# Load pretrained model (or use newly trained one)
try:
    model, graph_prior, label_info = load_trained_model(device='cpu')
    print("✓ Loaded pretrained model")
    print(f"  Classes: {label_info.get('n_classes', 'unknown')}")
except FileNotFoundError as e:
    print(f"⚠ Pretrained model not found: {e}")
    print("  Model will need to be trained first (requires data)")
    model = None


## 4. Model Evaluation


In [None]:
if model is not None and DATA_AVAILABLE:
    from sklearn.metrics import accuracy_score, f1_score, classification_report
    
    # Load test data
    data_splits, label_info, dataloaders = load_data(return_dataloaders=True, batch_size=32)
    test_loader = dataloaders[2]
    
    # Evaluate
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for x, y in test_loader:
            logits = model(x)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.numpy())
            all_labels.extend(y.numpy())
    
    accuracy = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    
    print(f"Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Test F1 (macro): {f1_macro:.4f} ({f1_macro*100:.2f}%)")
else:
    print("Skipping evaluation - model or data not available")


## 5. Interpretability Analysis

Run SHAP analysis, attention analysis, PCA baseline, and comparative analysis.


In [None]:
if model is not None and DATA_AVAILABLE:
    # Set up output directories
    plots_dir, results_dir = get_output_dirs(PROJECT_ROOT / "results")
    plots_dir.mkdir(parents=True, exist_ok=True)
    
    print("Running interpretability analyses...")
    print("This may take several minutes...")
    
    # Run SHAP analysis
    print("\n1. Running SHAP analysis (100 test samples)...")
    try:
        from shap_analysis import main as shap_main
        shap_main()
        print("✓ SHAP analysis complete")
    except Exception as e:
        print(f"✗ SHAP analysis failed: {e}")
        import traceback
        traceback.print_exc()
    
    # Run attention analysis
    print("\n2. Running attention analysis...")
    try:
        from attention_analysis import main as attn_main
        attn_main()
        print("✓ Attention analysis complete")
    except Exception as e:
        print(f"✗ Attention analysis failed: {e}")
        import traceback
        traceback.print_exc()
    
    # Run PCA baseline
    print("\n3. Training PCA95 baseline...")
    try:
        from pca_baseline import main as pca_main
        pca_main()
        print("✓ PCA baseline complete")
    except Exception as e:
        print(f"✗ PCA baseline failed: {e}")
        import traceback
        traceback.print_exc()
    
    print("\n✓ All interpretability analyses complete")
    print(f"Results saved to: {results_dir}")
else:
    print("Skipping interpretability - model or data not available")


## 6. Generate Result Figures

Generate all figures matching the results section.


In [None]:
if model is not None and DATA_AVAILABLE:
    # Load pre-computed results
    plots_dir, _ = get_output_dirs(PROJECT_ROOT / "results")
    
    # Check if results exist
    shap_results = plots_dir / "SHAP_Plots" / "top_proteins.json"
    pca_results = plots_dir / "PCA_Cox_Plots" / "top_proteins.json"
    attn_results = plots_dir / "Attention_Plots" / "attention_stats.txt"
    
    if shap_results.exists() and pca_results.exists():
        print("Loading analysis results...")
        
        # Load SHAP results
        with open(shap_results, 'r') as f:
            shap_data = json.load(f)
        shap_proteins = [p['protein'] for p in shap_data]
        shap_importance = np.array([p['importance'] for p in shap_data])
        
        # Load PCA results
        with open(pca_results, 'r') as f:
            pca_data = json.load(f)
        pca_proteins = [p['protein'] for p in pca_data]
        pca_importance = np.array([p['importance'] for p in pca_data])
        
        # Load PPI network
        prior_data = np.load(DEFAULT_PATHS['prior'], allow_pickle=True)
        A = prior_data['A']
        all_proteins = prior_data['protein_cols'].tolist()
        
        print(f"Loaded {len(shap_proteins)} SHAP proteins")
        print(f"Loaded {len(pca_proteins)} PCA proteins")
        print(f"PPI network: {A.shape[0]} proteins")
        
        # Generate comparison plots
        comparison_dir = plots_dir / "Model_Comparison_Plots"
        comparison_dir.mkdir(parents=True, exist_ok=True)
        
        # Top 20 overlap
        shap_top20 = set(shap_proteins[:20])
        pca_top20 = set(pca_proteins[:20])
        overlap = shap_top20 & pca_top20
        
        print(f"\nTop 20 Overlap: {len(overlap)}/{20} ({len(overlap)/20*100:.0f}%)")
        print(f"Overlap proteins: {sorted(overlap)}")
        
        # SHAP vs Attention correlation (if attention results exist)
        if attn_results.exists():
            print("\n✓ All analyses complete - results available in plots/")
        else:
            print("\n⚠ Some analyses may still be running")
    else:
        print("Analysis results not yet available - run interpretability cells first")
else:
    print("Skipping result generation - model or data not available")
