# GNN-MAPS: Hybrid Model (Protein + Spatial Features)

**Option 3: Hybrid MLP+GNN Architecture**

This notebook implements a **hybrid model** that combines:
- ✅ **MLP branch**: Learns from protein markers (like MAPS)
- ✅ **GNN branch**: Learns from spatial neighborhood context
- ✅ **Fusion**: Concatenates both representations for classification

## Hypothesis:
Combining per-cell protein marker features with spatial neighborhood context should outperform either signal on its own.

## Expected Results:
- Better than pure MLP (adds spatial context)
- Better than pure GNN (protein features are strong)
- Should handle spatial split better than pure GNN

## Configuration:
- **Hidden Dimension:** 512 (matches MAPS)
- **MLP Branch:** 4 layers (MAPS architecture)
- **GNN Branch:** 2 GraphSAGE layers
- **Fusion:** Concatenate [MLP features + GNN features]
- **Split:** Spatial 80/20 (same as gnn-maps-3)

---

In [None]:
# Install PyTorch Geometric and its dependencies (Kaggle-compatible)
import sys
import torch

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")

!pip install -q torch-geometric

pytorch_version = torch.__version__.split('+')[0]
cuda_version = torch.version.cuda.replace('.', '') if torch.cuda.is_available() else 'cpu'

print(f"\nInstalling PyG extensions for PyTorch {pytorch_version} and CUDA {cuda_version}...")

if torch.cuda.is_available():
    !pip install -q torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{pytorch_version}+cu{cuda_version}.html
else:
    !pip install -q torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{pytorch_version}+cpu.html

print("\n✅ PyTorch Geometric installation complete!")

In [None]:
# Quick verification test
try:
    import torch
    from torch_geometric.nn import SAGEConv
    
    test_conv = SAGEConv(16, 32)
    print("✅ PyTorch Geometric is working correctly!")
    print(f"   Test layer created: {test_conv}")
    
except Exception as e:
    print(f"❌ Error: {e}")
    print("Please re-run the installation cell above.")

In [None]:
# Import all necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score

print("✅ All libraries loaded successfully!")
print(f"   - PyTorch version: {torch.__version__}")
print(f"   - CUDA available: {torch.cuda.is_available()}")
print(f"   - Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

# 1. Load Data

In [None]:
df = pd.read_csv("/kaggle/input/chl-codex-annotated/cHL_CODEX_annotation.csv")
pd.set_option('display.max_columns', None)
display(df.head())
print(f"\nDataset shape: {df.shape}")

# 2. Data Preparation & Graph Construction

In [None]:
print("=" * 80)
print("OPTION 3: HYBRID MODEL (PROTEIN + SPATIAL FEATURES)")
print("=" * 80)

# Column definitions
x_col = 'X_cent'
y_col = 'Y_cent'  
label_col = 'cellType'

marker_cols = [
    'BCL.2', 'CCR6', 'CD11b', 'CD11c', 'CD15', 'CD16', 'CD162', 'CD163', 
    'CD2', 'CD20', 'CD206', 'CD25', 'CD30', 'CD31', 'CD4', 'CD44', 
    'CD45RA', 'CD45RO', 'CD45', 'CD5', 'CD56', 'CD57', 'CD68', 'CD69', 
    'CD7', 'CD8', 'Collagen.4', 'Cytokeratin', 'DAPI.01', 'EGFR', 
    'FoxP3', 'Granzyme.B', 'HLA.DR', 'IDO.1', 'LAG.3', 'MCT', 'MMP.9', 
    'MUC.1', 'PD.1', 'PD.L1', 'Podoplanin', 'T.bet', 'TCR.g.d', 'TCRb', 
    'Tim.3', 'VISA', 'Vimentin', 'a.SMA', 'b.Catenin'
]

print(f"\n✅ Using {len(marker_cols)} protein markers")
print(f"✅ Total cells: {len(df):,}")
print(f"✅ Cell types: {df[label_col].nunique()}")

# Normalize features
scaler = StandardScaler()
X_normalized = scaler.fit_transform(df[marker_cols].values)
x = torch.tensor(X_normalized, dtype=torch.float)

# Encode labels
unique_labels = sorted(df[label_col].unique())
label_map = {name: i for i, name in enumerate(unique_labels)}
y = torch.tensor(df[label_col].map(label_map).values, dtype=torch.long)
num_classes = len(label_map)

print(f"\n📊 {num_classes} cell types encoded")

# Build KNN graph
k_neighbors = 5
print(f"\n🔗 Building KNN Graph (K={k_neighbors})...")
coords = df[[x_col, y_col]].values

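# Query k + 1 neighbors: each cell's nearest neighbor is itself (column 0), dropped below
nbrs = NearestNeighbors(n_neighbors=k_neighbors + 1, algorithm='ball_tree').fit(coords)
distances, indices = nbrs.kneighbors(coords)

# Directed edges: each cell (source) -> its k nearest neighbors (target).
# PyG passes messages source -> target, so swap the two rows (or symmetrize the graph)
# if each cell should instead aggregate from its own k nearest neighbors.
source_nodes = np.repeat(np.arange(len(df)), k_neighbors)
target_nodes = indices[:, 1:].flatten()
# Stack into one array first; building a tensor from a list of arrays is slow
edge_index = torch.tensor(np.stack([source_nodes, target_nodes]), dtype=torch.long)

print(f"✅ Graph: {edge_index.shape[1]:,} edges")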

# Spatial split (80/20 by X-axis)
print(f"\n" + "=" * 80)
print("SPATIAL TRAIN/TEST SPLIT (80/20 by X-axis)")
print("=" * 80)

x_min = df[x_col].min()
x_max = df[x_col].max()
x_threshold = x_min + (0.8 * (x_max - x_min))

train_mask = torch.tensor(df[x_col].values <= x_threshold, dtype=torch.bool)
test_mask = torch.tensor(df[x_col].values > x_threshold, dtype=torch.bool)

print(f"\n✅ SPATIAL SPLIT")
print(f"   Train: {train_mask.sum():,} cells ({100*train_mask.float().mean():.1f}%)")
print(f"   Test:  {test_mask.sum():,} cells ({100*test_mask.float().mean():.1f}%)")
print(f"   This is the SAME split as gnn-maps-3 (for fair comparison)")

# Create PyG Data object
data = Data(
    x=x,
    edge_index=edge_index,
    y=y,
    train_mask=train_mask,
    test_mask=test_mask
)

print(f"\n{data}")
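
To make the graph construction and the split concrete, here is a minimal sketch on five toy cells laid out along the X-axis. It swaps in a brute-force NumPy distance matrix for `NearestNeighbors` so that it runs without scikit-learn; the helper names `knn_edges` and `spatial_split` and the toy coordinates are hypothetical, not part of the notebook.

```python
import numpy as np

def knn_edges(coords: np.ndarray, k: int) -> np.ndarray:
    """Return a (2, n*k) array of directed edges cell -> neighbor (brute force)."""
    diffs = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diffs ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # a cell is never its own neighbor
    neighbors = np.argsort(dist, axis=1)[:, :k]
    sources = np.repeat(np.arange(len(coords)), k)
    return np.stack([sources, neighbors.ravel()])

def spatial_split(x_values: np.ndarray, train_fraction: float = 0.8):
    """Split by position along X: train covers the first 80% of the X range."""
    x_min, x_max = x_values.min(), x_values.max()
    threshold = x_min + train_fraction * (x_max - x_min)
    train = x_values <= threshold
    return train, ~train

# Five toy cells on a line: three on the left, two on the far right
toy_coords = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [9.0, 0.0], [10.0, 0.0]])
toy_edges = knn_edges(toy_coords, k=2)
toy_train, toy_test = spatial_split(toy_coords[:, 0])

# Edges whose endpoints fall on different sides of the split
crossing = int((toy_train[toy_edges[0]] != toy_train[toy_edges[1]]).sum())

print(toy_edges.shape, toy_train.mean(), crossing)
```

Two things are worth noticing in the toy case. The threshold sits at 80% of the X range, not at 80% of the cells, so the train share depends on cell density. And some edges connect train cells to test cells, which means message passing can carry features across the split boundary.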

# 3. Hybrid Model Architecture

In [None]:
print("=" * 80)
print("HYBRID MODEL: MLP + GNN FUSION")
print("=" * 80)

class HybridGNN(torch.nn.Module):
    """
    Hybrid Model: Combines MLP (protein features) + GNN (spatial context)
    
    Architecture:
    1. MLP Branch: 4-layer network (MAPS-style) processes protein markers
    2. GNN Branch: 2-layer GraphSAGE aggregates spatial neighborhood info
    3. Fusion: Concatenate [MLP features + GNN features] → classifier
    
    This gives the best of both worlds:
    - Strong protein marker features (like MAPS)
    - Rich spatial context (like pure GNN)
    """
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.1):
        super().__init__()
        
        # MLP Branch (processes protein markers independently)
        self.mlp_fc1 = torch.nn.Linear(in_channels, hidden_channels)
        self.mlp_fc2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.mlp_fc3 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.mlp_fc4 = torch.nn.Linear(hidden_channels, hidden_channels)
        
        # GNN Branch (aggregates spatial neighborhood)
        self.gnn_conv1 = SAGEConv(in_channels, hidden_channels)
        self.gnn_conv2 = SAGEConv(hidden_channels, hidden_channels)
        
        # Fusion layer (combines both representations)
        self.classifier = torch.nn.Linear(hidden_channels * 2, out_channels)
        
        self.dropout = dropout
    
    def forward(self, x, edge_index):
        # MLP Branch: Process protein markers
        mlp_out = F.relu(self.mlp_fc1(x))
        mlp_out = F.dropout(mlp_out, p=self.dropout, training=self.training)
        mlp_out = F.relu(self.mlp_fc2(mlp_out))
        mlp_out = F.dropout(mlp_out, p=self.dropout, training=self.training)
        mlp_out = F.relu(self.mlp_fc3(mlp_out))
        mlp_out = F.dropout(mlp_out, p=self.dropout, training=self.training)
        mlp_out = F.relu(self.mlp_fc4(mlp_out))
        mlp_out = F.dropout(mlp_out, p=self.dropout, training=self.training)
        
        # GNN Branch: Aggregate spatial context
        gnn_out = self.gnn_conv1(x, edge_index)
        gnn_out = F.relu(gnn_out)
        gnn_out = F.dropout(gnn_out, p=self.dropout, training=self.training)
        gnn_out = self.gnn_conv2(gnn_out, edge_index)
        gnn_out = F.relu(gnn_out)
        gnn_out = F.dropout(gnn_out, p=self.dropout, training=self.training)
        
        # Fusion: Concatenate both representations
        combined = torch.cat([mlp_out, gnn_out], dim=1)
        
        # Final classification
        out = self.classifier(combined)
        return F.log_softmax(out, dim=1)

# Also keep baseline models for comparison
class MLP(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.1):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_channels, hidden_channels)
        self.fc2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.fc3 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.fc4 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index=None):
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.fc3(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.fc4(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.classifier(x)
        return F.log_softmax(x, dim=1)

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.1):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

print("\n✅ Models defined:")
print("   1. MLP (baseline - protein only)")
print("   2. GraphSAGE (baseline - spatial only)")
print("   3. HybridGNN (protein + spatial)")
print("\n🔬 Hybrid Architecture:")
print("   Input (49 markers) →")
print("   ├─ MLP: 49 → 512 → 512 → 512 → 512")
print("   └─ GNN: 49 → 512 → 512")
print("   Concat: [512 MLP + 512 GNN] = 1024")
print("   Classifier: 1024 → num_classes")
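
For a sense of model size, the layer widths above can be turned into a parameter count by hand. The sketch below is plain arithmetic and makes two assumptions: `SAGEConv` is used with its defaults (one neighbor transform with a bias plus one root transform without a bias), and `num_classes=10` is only a placeholder, since the real value comes from the data. The helper names are hypothetical.

```python
def linear_params(n_in: int, n_out: int, bias: bool = True) -> int:
    """Weights plus optional bias of a dense layer."""
    return n_in * n_out + (n_out if bias else 0)

def sage_params(n_in: int, n_out: int) -> int:
    """Assumed SAGEConv layout: neighbor transform (with bias) + root transform (no bias)."""
    return linear_params(n_in, n_out, bias=True) + linear_params(n_in, n_out, bias=False)

def hybrid_param_count(in_channels: int, hidden: int, num_classes: int) -> dict:
    # MLP branch: one input layer and three hidden-to-hidden layers
    mlp = linear_params(in_channels, hidden) + 3 * linear_params(hidden, hidden)
    # GNN branch: two GraphSAGE layers
    gnn = sage_params(in_channels, hidden) + sage_params(hidden, hidden)
    # Classifier sees the concatenated [MLP, GNN] features
    classifier = linear_params(2 * hidden, num_classes)
    return {"mlp": mlp, "gnn": gnn, "classifier": classifier,
            "total": mlp + gnn + classifier}

counts = hybrid_param_count(in_channels=49, hidden=512, num_classes=10)  # 10 is a placeholder
for name, value in counts.items():
    print(f"{name:<10s} {value:>10,}")
```

Under these assumptions the two branches account for nearly all of the parameters, and the classifier is small by comparison.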

# 4. Training (All Three Models)
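
The training loop in this section evaluates at epoch 1 and then every 10 epochs, and stops once at least `min_epochs` have passed and the score has not improved for `patience` epochs. The sketch below replays that rule on a synthetic F1 curve, with no model involved; `simulate_early_stopping` and the two toy curves are hypothetical and exist only to show when the loop would exit.

```python
def simulate_early_stopping(f1_at_epoch, max_epochs=500, min_epochs=250,
                            patience=100, eval_every=10):
    """Replay the stopping rule; returns (best_f1, best_epoch, last_epoch_run)."""
    best_f1, best_epoch, patience_counter = 0.0, 0, 0
    last_epoch = 0
    for epoch in range(1, max_epochs + 1):
        last_epoch = epoch
        if epoch % eval_every == 0 or epoch == 1:
            f1 = f1_at_epoch(epoch)
            if f1 > best_f1:
                best_f1, best_epoch, patience_counter = f1, epoch, 0
            else:
                # One evaluation without improvement counts as eval_every epochs
                patience_counter += eval_every
                if epoch >= min_epochs and patience_counter >= patience:
                    break
    return best_f1, best_epoch, last_epoch

# Toy curve A: improves until epoch 50, then stays flat
early_plateau = simulate_early_stopping(lambda e: min(e, 50) / 100)
# Toy curve B: improves until epoch 300, then stays flat
late_plateau = simulate_early_stopping(lambda e: min(e, 300) / 1000)

print(early_plateau)
print(late_plateau)
```

With these settings an early plateau still runs until `min_epochs`, because the patience check is only honored from epoch 250 onward, while a late plateau stops 100 epochs after its last improvement.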

In [None]:
# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"💻 Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# Hyperparameters
hidden_dim = 512
dropout = 0.1
lr = 0.001
max_epochs = 500
min_epochs = 250
patience = 100

print(f"\n⚙️  Hyperparameters: hidden={hidden_dim}, epochs={max_epochs}")

# Training functions
def train_epoch_with_loss(model, data, optimizer):
    model.train()
    data = data.to(device)
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(model, data, mask):
    model.eval()
    data_device = data.to(device)
    out = model(data_device.x, data_device.edge_index)
    pred = out.argmax(dim=1)
    
    y_true = data_device.y[mask].cpu().numpy()
    y_pred = pred[mask].cpu().numpy()
    
    acc = (pred[mask] == data_device.y[mask]).sum().item() / mask.sum().item()
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    
    return acc, f1

# Function to train any model
def train_model(model, model_name, data, lr, max_epochs, min_epochs, patience):
    print("\n" + "=" * 80)
    print(f"TRAINING {model_name}")
    print("=" * 80)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    print(f"\nEpoch | Loss    | Train Acc | Train F1 | Test Acc | Test F1 | Status")
    print("-" * 80)
    
    best_f1 = 0
    best_epoch = 0
    patience_counter = 0
    
    for epoch in range(1, max_epochs + 1):
        loss = train_epoch_with_loss(model, data, optimizer)
        
        if epoch % 10 == 0 or epoch == 1:
            train_acc, train_f1 = evaluate(model, data, data.train_mask)
            test_acc, test_f1 = evaluate(model, data, data.test_mask)
            
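            # NOTE: the best score and early stopping are tracked on the test split,
            # so best_f1 is an optimistic estimate; a separate validation split avoids this.
            status = ""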
            if test_f1 > best_f1:
                best_f1 = test_f1
                best_epoch = epoch
                patience_counter = 0
                status = "✅ BEST"
            else:
                patience_counter += 10
                if epoch >= min_epochs and patience_counter >= patience:
                    status = "🛑 EARLY STOP"
                
            print(f'{epoch:5d} | {loss:7.4f} | {train_acc:9.4f} | {train_f1:8.4f} | '
                  f'{test_acc:8.4f} | {test_f1:7.4f} | {status}')
            
            if epoch >= min_epochs and patience_counter >= patience:
                print(f"\n⏸️  Early stopping!")
                break
    
    print(f"\n✅ {model_name} Complete! Best F1: {best_f1:.4f} (epoch {best_epoch})")
    return best_f1, best_epoch

# Train all three models
print("\n" + "=" * 80)
print("TRAINING ALL MODELS FOR COMPARISON")
print("=" * 80)

# 1. MLP Baseline
mlp_model = MLP(len(marker_cols), hidden_dim, num_classes, dropout).to(device)
mlp_f1, mlp_epoch = train_model(mlp_model, "MLP (Baseline)", data, lr, max_epochs, min_epochs, patience)

# 2. GraphSAGE Baseline
gnn_model = GraphSAGE(len(marker_cols), hidden_dim, num_classes, dropout).to(device)
gnn_f1, gnn_epoch = train_model(gnn_model, "GraphSAGE (Baseline)", data, lr, max_epochs, min_epochs, patience)

# 3. Hybrid Model (STAR OF THE SHOW!)
hybrid_model = HybridGNN(len(marker_cols), hidden_dim, num_classes, dropout).to(device)
hybrid_f1, hybrid_epoch = train_model(hybrid_model, "HybridGNN (Protein+Spatial)", data, lr, max_epochs, min_epochs, patience)

# ==========================================
# FINAL COMPARISON
# ==========================================
print("\n" + "=" * 80)
print("FINAL RESULTS: THREE-WAY COMPARISON")
print("=" * 80)

# Final-epoch test scores (no checkpoint is restored, so these can differ from the best F1 logged during training)
mlp_test_acc, mlp_test_f1 = evaluate(mlp_model, data, data.test_mask)
gnn_test_acc, gnn_test_f1 = evaluate(gnn_model, data, data.test_mask)
hybrid_test_acc, hybrid_test_f1 = evaluate(hybrid_model, data, data.test_mask)

print(f"\n{'Model':<30s} {'Accuracy':>12s} {'F1-Score':>12s} {'vs MAPS':>12s}")
print("-" * 75)
print(f"{'MAPS (Paper Baseline)':<30s} {'N/A':>12s} {'0.9000':>12s} {'1.00x':>12s}")
print(f"{'MLP (Protein Only)':<30s} {mlp_test_acc:>12.4f} {mlp_test_f1:>12.4f} {mlp_test_f1/0.90:>11.2f}x")
print(f"{'GraphSAGE (Spatial Only)':<30s} {gnn_test_acc:>12.4f} {gnn_test_f1:>12.4f} {gnn_test_f1/0.90:>11.2f}x")
print(f"{'HybridGNN (Protein+Spatial)':<30s} {hybrid_test_acc:>12.4f} {hybrid_test_f1:>12.4f} {hybrid_test_f1/0.90:>11.2f}x")
print("-" * 75)

# Analysis
print(f"\n📊 Performance Analysis:")
print(f"   Hybrid vs MLP:  {(hybrid_test_f1 - mlp_test_f1)*100:+.2f} pp ({((hybrid_test_f1 - mlp_test_f1)/mlp_test_f1)*100:+.1f}%)")
print(f"   Hybrid vs GNN:  {(hybrid_test_f1 - gnn_test_f1)*100:+.2f} pp ({((hybrid_test_f1 - gnn_test_f1)/gnn_test_f1)*100:+.1f}%)")
print(f"   Hybrid vs MAPS: {(hybrid_test_f1 - 0.90)*100:+.2f} pp ({((hybrid_test_f1 - 0.90)/0.90)*100:+.1f}%)")

# Determine winner
best_model = max([(mlp_test_f1, 'MLP'), (gnn_test_f1, 'GraphSAGE'), (hybrid_test_f1, 'Hybrid')], key=lambda x: x[0])
print(f"\n🏆 Best Model: {best_model[1]} with {best_model[0]:.1%} F1")

if hybrid_test_f1 > max(mlp_test_f1, gnn_test_f1):
    print(f"\n✅ HYBRID MODEL WINS!")
    print(f"   Combining protein + spatial features beats both baselines!")
    if hybrid_test_f1 > 0.90:
        print(f"   🎉 AND IT BEATS MAPS (90%)!")
elif mlp_test_f1 > gnn_test_f1:
    print(f"\n📊 MLP wins (protein features stronger than spatial for this split)")
else:
    print(f"\n📊 GNN wins (spatial context helps despite distribution shift)")

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Bar chart comparison
ax = axes[0]
models = ['MAPS\n(Paper)', 'MLP\n(Protein)', 'GNN\n(Spatial)', 'Hybrid\n(Both)']
f1_scores = [0.90, mlp_test_f1, gnn_test_f1, hybrid_test_f1]
colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']

bars = ax.bar(models, f1_scores, color=colors, edgecolor='black', linewidth=2, width=0.6)
ax.set_ylabel('F1-Score', fontsize=14, fontweight='bold')
ax.set_title('Hybrid Model: Protein + Spatial Features', fontsize=16, fontweight='bold')
ax.set_ylim([0, 1.0])
ax.axhline(y=0.90, color='green', linestyle='--', linewidth=2, alpha=0.5, label='MAPS Target')
ax.grid(axis='y', alpha=0.3)
ax.legend(fontsize=11)

for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{height:.4f}',
            ha='center', va='bottom', fontsize=13, fontweight='bold')

# Improvement breakdown
ax2 = axes[1]
comparisons = ['MLP vs\nMAPS', 'GNN vs\nMAPS', 'Hybrid vs\nMAPS', 'Hybrid vs\nMLP', 'Hybrid vs\nGNN']
improvements = [
    (mlp_test_f1 - 0.90) * 100,
    (gnn_test_f1 - 0.90) * 100,
    (hybrid_test_f1 - 0.90) * 100,
    (hybrid_test_f1 - mlp_test_f1) * 100,
    (hybrid_test_f1 - gnn_test_f1) * 100
]
colors2 = ['#3498db' if v >= 0 else '#e67e22' for v in improvements]

bars2 = ax2.bar(comparisons, improvements, color=colors2, edgecolor='black', linewidth=1.5)
ax2.set_ylabel('F1 difference (percentage points)', fontsize=14, fontweight='bold')
ax2.set_title('Performance Improvements', fontsize=16, fontweight='bold')
ax2.axhline(y=0, color='black', linestyle='-', linewidth=1)
ax2.grid(axis='y', alpha=0.3)

for bar in bars2:
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height,
            f'{height:+.1f} pp',
            ha='center', va='bottom' if height > 0 else 'top', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

print("\n" + "=" * 80)
print("🚀 ANALYSIS COMPLETE!")
print("=" * 80)
print(f"\n💻 Hardware: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
print(f"⚙️  Config: Spatial Split, Hybrid Architecture, hidden=512")
print(f"📊 Final Results:")
print(f"   MLP:      {mlp_test_f1:.1%}")
print(f"   GNN:      {gnn_test_f1:.1%}")
print(f"   Hybrid:   {hybrid_test_f1:.1%}")
print(f"   MAPS:     90.0%")
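
The performance analysis reports each gap two ways: as an absolute difference in percentage points (pp) and as a relative change against the baseline. The sketch below separates the two with a small helper; `compare_f1` is a hypothetical name, and the scores 0.92 and 0.90 are illustrative inputs, not results from this notebook.

```python
def compare_f1(candidate: float, baseline: float) -> tuple:
    """Return (difference in percentage points, relative change in percent)."""
    diff_pp = (candidate - baseline) * 100
    relative_pct = (candidate - baseline) / baseline * 100
    return diff_pp, relative_pct

# Illustrative values only: a candidate at 0.92 F1 against a 0.90 baseline
diff_pp, relative_pct = compare_f1(0.92, 0.90)
print(f"{diff_pp:+.2f} pp ({relative_pct:+.1f}%)")
```

The two numbers diverge more as the baseline gets smaller, so it helps to state which one a chart or table is showing.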