# GATWithTransformerFusion Training

This notebook trains the GATWithTransformerFusion model for end-to-end mapping from RNA expression to ADT (antibody-derived tag) protein measurements. The model is a single architecture that combines Graph Attention Network (GAT) layers with Transformer fusion layers.

## Overview
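1. Setup and imports
2. Load and preprocess data
3. Convert to PyTorch Geometric format
4. Initialize the GATWithTransformerFusion model
5. Train the model with end-to-end optimization
6. Evaluate performance
7. Visualize results (including UMAP views of the embeddings)
8. Save the trained model
9. Make predictions on new data
10. Load a saved model
11. Summary and next steps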


## 1. Setup and Imports


In [None]:
import sys
import os
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Add project root to Python path
current_dir = os.getcwd()
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print("=== Setup Information ===")
print(f"Current directory: {current_dir}")
print(f"Project root: {project_root}")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Check that the expected project layout is present
print("=== Path Debugging ===")

# Check if scripts directory exists
scripts_path = os.path.join(project_root, 'scripts')
print(f"Scripts directory exists: {os.path.exists(scripts_path)}")

if os.path.exists(scripts_path):
    print("Scripts directory contents:")
    for item in os.listdir(scripts_path):
        print(f"  - {item}")

# Check if model directory exists
model_path = os.path.join(scripts_path, 'model')
print(f"Model directory exists: {os.path.exists(model_path)}")

if os.path.exists(model_path):
    print("Model directory contents:")
    for item in os.listdir(model_path):
        print(f"  - {item}")

# Import project modules
try:
    from scripts.data_provider.graph_data_builder import build_pyg_data
    from scripts.trainer.gat_trainer import train_gat_transformer_fusion
    from scripts.model.doNET import GATWithTransformerFusion
    from scripts.data_provider.data_preprocessing import prepare_train_test_anndata
    print("✅ All imports successful!")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Retrying with the current directory on sys.path...")

    # The first attempt already added the parent directory, so repeating it cannot help.
    # Fall back to the current working directory in case the notebook runs from the project root.
    if current_dir not in sys.path:
        sys.path.insert(0, current_dir)
    
    # Try importing again
    try:
        from scripts.data_provider.graph_data_builder import build_pyg_data
        from scripts.trainer.gat_trainer import train_gat_transformer_fusion
        from scripts.model.doNET import GATWithTransformerFusion
        from scripts.data_provider.data_preprocessing import prepare_train_test_anndata
        print("✅ Alternative imports successful!")
    except ImportError as e2:
        print(f"❌ Alternative import also failed: {e2}")
        print("Please check that all required files exist and the project structure is correct.")


In [None]:
# Test imports and model creation
print("=== Testing Imports and Model Creation ===")

try:
    # Test transformer models import
    from scripts.model.transformer_models import TransformerMapping
    print("✅ Transformer models imported successfully")
    
    # Test doNET import
    from scripts.model.doNET import GATWithTransformerFusion, TransformerFusion
    print("✅ GATWithTransformerFusion imported successfully")
    
    # Test model creation
    print("\n=== Testing Model Creation ===")
    
    # Test TransformerMapping
    transformer_mapping = TransformerMapping(
        input_dim=50, 
        output_dim=10, 
        d_model=64
    )
    print(f"✅ TransformerMapping created: {transformer_mapping}")
    
    # Test GATWithTransformerFusion
    gat_transformer = GATWithTransformerFusion(
        in_channels=50, 
        hidden_channels=32, 
        out_channels=10
    )
    print(f"✅ GATWithTransformerFusion created: {gat_transformer}")
    
    print("\n🎉 All imports and model creation tests passed!")
    
except Exception as e:
    print(f"❌ Error during testing: {e}")
    import traceback
    traceback.print_exc()


## 2. Load and Preprocess Data


In [None]:
# Load training data
print("=== Loading Training Data ===")

# Load RNA and ADT data
rna_adata = sc.read_h5ad("/projects/vanaja_lab/satya/Datasets/GSMControlRNA.h5ad")
adt_adata = sc.read_h5ad("/projects/vanaja_lab/satya/Datasets/ControlADT.h5ad")

print(f"RNA data shape: {rna_adata.shape}")
print(f"ADT data shape: {adt_adata.shape}")

# Ensure both modalities contain the same cells in the same order
# (equal cell counts alone do not guarantee matching barcodes)
if not rna_adata.obs_names.equals(adt_adata.obs_names):
    print("Warning: RNA and ADT cell barcodes differ; restricting to shared cells")
    common_cells = rna_adata.obs_names.intersection(adt_adata.obs_names)
    rna_adata = rna_adata[common_cells].copy()
    adt_adata = adt_adata[common_cells].copy()
    print(f"Using {len(common_cells)} common cells")

print(f"Final RNA data shape: {rna_adata.shape}")
print(f"Final ADT data shape: {adt_adata.shape}")

# Basic preprocessing
print("\n=== Basic Preprocessing ===")
print("RNA data preprocessing...")
sc.pp.filter_genes(rna_adata, min_cells=10)
sc.pp.filter_cells(rna_adata, min_genes=200)
sc.pp.normalize_total(rna_adata, target_sum=1e4)
sc.pp.log1p(rna_adata)
sc.pp.highly_variable_genes(rna_adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
rna_adata.raw = rna_adata
# Copy so later steps work on a real AnnData object rather than a view
rna_adata = rna_adata[:, rna_adata.var.highly_variable].copy()

print("ADT data preprocessing...")
sc.pp.filter_genes(adt_adata, min_cells=10)
sc.pp.filter_cells(adt_adata, min_genes=5)
sc.pp.normalize_total(adt_adata, target_sum=1e4)
sc.pp.log1p(adt_adata)

print(f"After preprocessing - RNA: {rna_adata.shape}, ADT: {adt_adata.shape}")

# Filtering can drop different cells from each modality, so realign them
common_cells = rna_adata.obs_names.intersection(adt_adata.obs_names)
rna_adata = rna_adata[common_cells].copy()
adt_adata = adt_adata[common_cells].copy()

print(f"Final common cells: {len(common_cells)}")
print("✅ Data loading and preprocessing complete!")
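
The alignment step matters because row i of the RNA matrix and row i of the ADT matrix are treated as the same cell downstream. The following is a minimal sketch of that idea on toy data, using a plain `pandas.Index` in place of `obs_names`; the barcodes and matrices are hypothetical and only illustrate the row bookkeeping.

```python
import numpy as np
import pandas as pd

# Toy barcodes (hypothetical); in the notebook these would be adata.obs_names.
rna_names = pd.Index(["cell_a", "cell_b", "cell_c", "cell_d"])
adt_names = pd.Index(["cell_d", "cell_b", "cell_x", "cell_a"])

# Toy matrices: one row per cell, in the order given by the barcodes above.
rna_x = np.arange(8, dtype=float).reshape(4, 2)
adt_x = np.arange(4, dtype=float).reshape(4, 1)

# Shared barcodes, kept in RNA order so both modalities end up row-aligned.
common = rna_names[rna_names.isin(adt_names)]

# Row positions of the shared barcodes in each matrix.
rna_rows = rna_names.get_indexer(common)
adt_rows = adt_names.get_indexer(common)

rna_aligned = rna_x[rna_rows]
adt_aligned = adt_x[adt_rows]

print(list(common))
print(rna_aligned.shape, adt_aligned.shape)
```

Note that the ADT rows are reordered, not just filtered: equal cell counts in both modalities would not reveal a mismatch like this one.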


## 3. Convert to PyTorch Geometric Format


In [None]:
# Convert AnnData to PyTorch Geometric format
print("=== Converting to PyTorch Geometric Format ===")

# Convert RNA data
print("Converting RNA data...")
rna_pyg_data = build_pyg_data(rna_adata)
print(f"RNA PyG data: {rna_pyg_data}")

# Convert ADT data
print("Converting ADT data...")
adt_pyg_data = build_pyg_data(adt_adata)
print(f"ADT PyG data: {adt_pyg_data}")

# Verify data compatibility
print(f"\n=== Data Compatibility Check ===")
print(f"RNA nodes: {rna_pyg_data.num_nodes}")
print(f"ADT nodes: {adt_pyg_data.num_nodes}")
print(f"RNA features: {rna_pyg_data.x.shape[1]}")
print(f"ADT features: {adt_pyg_data.x.shape[1]}")
print(f"RNA edges: {rna_pyg_data.num_edges}")
print(f"ADT edges: {adt_pyg_data.num_edges}")

if rna_pyg_data.num_nodes != adt_pyg_data.num_nodes:
    print("⚠️  Warning: RNA and ADT data have different number of nodes!")
else:
    print("✅ RNA and ADT data have same number of nodes")

print("✅ PyTorch Geometric conversion complete!")
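
`build_pyg_data` is project code that is not shown in this notebook, so its graph construction is not known here. As an illustration only, the sketch below assumes a Euclidean k-nearest-neighbour graph, a common choice for single-cell data, and produces an `edge_index` array in the `[2, num_edges]` layout that PyTorch Geometric uses. The helper `knn_edge_index` is hypothetical and may differ from what the project actually does.

```python
import numpy as np

def knn_edge_index(x, k):
    """Return a [2, n*k] array of directed edges from each row to its k nearest rows.

    Hypothetical helper for illustration; not the project's build_pyg_data.
    """
    n = x.shape[0]
    # Pairwise squared Euclidean distances, shape [n, n].
    sq_norms = (x ** 2).sum(axis=1)
    dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * x @ x.T
    # Exclude self-loops by making the diagonal infinitely far.
    np.fill_diagonal(dist, np.inf)
    # Indices of the k closest neighbours per row.
    neighbours = np.argsort(dist, axis=1)[:, :k]
    sources = np.repeat(np.arange(n), k)
    targets = neighbours.reshape(-1)
    return np.stack([sources, targets])

rng = np.random.default_rng(0)
features = rng.normal(size=(6, 3))   # 6 toy cells, 3 features
edge_index = knn_edge_index(features, k=2)
print(edge_index.shape)              # (2, 12)
```

The dense distance matrix needs memory quadratic in the number of cells, so this form is only suitable for small examples.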


## 4. Initialize GATWithTransformerFusion Model


In [None]:
# Initialize the GATWithTransformerFusion model
print("=== Initializing GATWithTransformerFusion Model ===")

# Get data dimensions
rna_input_dim = rna_pyg_data.x.shape[1]
adt_output_dim = adt_pyg_data.x.shape[1]

print(f"RNA input dimension: {rna_input_dim}")
print(f"ADT output dimension: {adt_output_dim}")

# Model configuration
model_config = {
    'in_channels': rna_input_dim,
    'hidden_channels': 64,
    'out_channels': adt_output_dim,
    'heads': 8,
    'dropout': 0.6,
    'nhead': 4,
    'num_layers': 2
}

print(f"Model configuration: {model_config}")

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = GATWithTransformerFusion(**model_config).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Model architecture summary
print(f"\n=== Model Architecture ===")
print(f"GAT RNA Encoder: {rna_input_dim} -> {model_config['hidden_channels']} (2 layers)")
print(f"Transformer Fusion: {model_config['hidden_channels']} dim, {model_config['nhead']} heads, {model_config['num_layers']} layers")
print(f"GAT ADT Predictor: {model_config['hidden_channels']} -> {adt_output_dim} (1 layer)")

print("✅ Model initialization complete!")


## 5. Train the Model


In [None]:
# Train the GATWithTransformerFusion model
print("=== Training GATWithTransformerFusion Model ===")

# Training parameters
training_config = {
    'epochs': 100,
    'use_cpu_fallback': True,  # Set to False if you have sufficient GPU memory
    'seed': 42
}

print(f"Training configuration: {training_config}")

# Start training
start_time = datetime.now()
print(f"Training started at: {start_time}")

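# Train the model
# Note: the `model` instance from section 4 is not passed in; the trainer returns its own
# model, so the parameter counts and `model_config` above describe the trained model only
# if the trainer builds it with the same settings.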
trained_model, rna_data_with_masks, adt_data_with_masks = train_gat_transformer_fusion(
    rna_data=rna_pyg_data,
    adt_data=adt_pyg_data,
    **training_config
)

end_time = datetime.now()
training_duration = end_time - start_time

print(f"\n=== Training Complete ===")
print(f"Training finished at: {end_time}")
print(f"Total training time: {training_duration}")
print(f"Training time per epoch: {training_duration / training_config['epochs']}")

print("✅ Model training complete!")
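
The trainer returns data objects that carry boolean node masks (the evaluation below uses `test_mask`). How the trainer builds them is not shown in this notebook. The sketch below assumes a seeded random split; the helper name and the 70/15/15 fractions are illustrative assumptions, not the project's settings.

```python
import numpy as np

def make_split_masks(num_nodes, train_frac=0.7, val_frac=0.15, seed=42):
    """Randomly partition node indices into boolean train/val/test masks.

    Hypothetical helper; the split fractions are illustrative assumptions.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(num_nodes)
    n_train = int(train_frac * num_nodes)
    n_val = int(val_frac * num_nodes)

    train_mask = np.zeros(num_nodes, dtype=bool)
    val_mask = np.zeros(num_nodes, dtype=bool)
    test_mask = np.zeros(num_nodes, dtype=bool)
    train_mask[order[:n_train]] = True
    val_mask[order[n_train:n_train + n_val]] = True
    test_mask[order[n_train + n_val:]] = True
    return train_mask, val_mask, test_mask

train_mask, val_mask, test_mask = make_split_masks(100)
print(train_mask.sum(), val_mask.sum(), test_mask.sum())
```

Because the masks index nodes of a single graph, every cell stays in the graph during message passing; the masks only decide which cells contribute to the loss or to the reported metrics.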


## 6. Evaluate Model Performance


In [None]:
# Evaluate model performance
print("=== Model Performance Evaluation ===")

from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr

# Make predictions on test set
trained_model.eval()
with torch.no_grad():
    test_predictions, test_fused_embeddings = trained_model(
        x=rna_data_with_masks.x,
        edge_index_rna=rna_data_with_masks.edge_index,
        edge_index_adt=adt_data_with_masks.edge_index
    )

# Get test set predictions and ground truth
test_mask = rna_data_with_masks.test_mask
y_true = adt_data_with_masks.x[test_mask].cpu().numpy()
y_pred = test_predictions[test_mask].cpu().numpy()

print(f"Test set size: {y_true.shape[0]} cells")
print(f"Prediction shape: {y_pred.shape}")

# Calculate metrics
mse = mean_squared_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)

print(f"\n=== Performance Metrics ===")
print(f"MSE: {mse:.6f}")
print(f"R² Score: {r2:.4f}")

# Calculate per-marker correlations, skipping markers with zero variance
pearson_corrs = []
spearman_corrs = []
evaluated_markers = []  # column indices that produced a valid correlation

for i in range(y_true.shape[1]):
    if y_true[:, i].std() > 0 and y_pred[:, i].std() > 0:
        pearson_r, _ = pearsonr(y_true[:, i], y_pred[:, i])
        spearman_r, _ = spearmanr(y_true[:, i], y_pred[:, i])
        pearson_corrs.append(pearson_r)
        spearman_corrs.append(spearman_r)
        evaluated_markers.append(i)

mean_pearson = np.mean(pearson_corrs)
mean_spearman = np.mean(spearman_corrs)

print(f"Mean Pearson Correlation: {mean_pearson:.4f}")
print(f"Mean Spearman Correlation: {mean_spearman:.4f}")

# Show top and bottom performing markers
if len(pearson_corrs) > 0:
    # Index by the evaluated columns so names stay aligned when markers are skipped
    marker_names = adt_adata.var_names[evaluated_markers]
    corr_df = pd.DataFrame({
        'marker': marker_names,
        'pearson': pearson_corrs,
        'spearman': spearman_corrs
    })
    
    print(f"\n=== Top 5 Performing Markers (Pearson) ===")
    top_markers = corr_df.nlargest(5, 'pearson')
    for _, row in top_markers.iterrows():
        print(f"  {row['marker']}: {row['pearson']:.4f}")
    
    print(f"\n=== Bottom 5 Performing Markers (Pearson) ===")
    bottom_markers = corr_df.nsmallest(5, 'pearson')
    for _, row in bottom_markers.iterrows():
        print(f"  {row['marker']}: {row['pearson']:.4f}")

print("✅ Performance evaluation complete!")
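
To make the metrics above concrete, here is a small worked sketch on toy arrays using only NumPy. It computes the overall MSE, the R² of one marker from its definition, and per-marker Pearson correlations with a NaN placeholder for constant columns. The helper `per_marker_pearson` and the toy values are hypothetical and are not results from this model.

```python
import numpy as np

def per_marker_pearson(y_true, y_pred):
    """Pearson r for each column; NaN where either column is constant."""
    n_markers = y_true.shape[1]
    corrs = np.full(n_markers, np.nan)
    for j in range(n_markers):
        t, p = y_true[:, j], y_pred[:, j]
        if t.std() > 0 and p.std() > 0:
            corrs[j] = np.corrcoef(t, p)[0, 1]
    return corrs

# Toy data: 4 cells, 3 markers; the last marker is constant in the ground truth.
toy_true = np.array([[1.0, 2.0, 5.0],
                     [2.0, 1.0, 5.0],
                     [3.0, 4.0, 5.0],
                     [4.0, 3.0, 5.0]])
toy_pred = np.array([[1.5, 2.0, 4.0],
                     [2.5, 1.0, 5.0],
                     [3.5, 4.0, 6.0],
                     [4.5, 3.0, 5.0]])

# MSE: mean squared difference over every cell and marker.
toy_mse = np.mean((toy_true - toy_pred) ** 2)

# R^2 for a single marker: 1 - residual sum of squares / total sum of squares.
ss_res = np.sum((toy_true[:, 0] - toy_pred[:, 0]) ** 2)
ss_tot = np.sum((toy_true[:, 0] - toy_true[:, 0].mean()) ** 2)
toy_r2_marker0 = 1.0 - ss_res / ss_tot

toy_corrs = per_marker_pearson(toy_true, toy_pred)
print(toy_mse, toy_r2_marker0, toy_corrs)
print(np.nanmean(toy_corrs))   # mean over markers with a defined correlation
```

Marker 0 shows why correlation and R² should be read together: a constant offset of 0.5 leaves the Pearson correlation at 1 while R² drops to 0.8. Keeping a NaN placeholder for skipped markers also keeps the result aligned with the marker order. For a 2-D target, scikit-learn's `r2_score` averages the per-marker R² values uniformly by default.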


## 7. Visualize Results


In [None]:
# Create visualizations
print("=== Creating Visualizations ===")

# Set up plotting
plt.style.use('default')
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('GATWithTransformerFusion Training Results', fontsize=16)

# 1. Prediction vs Ground Truth scatter plot (sample of markers)
ax1 = axes[0, 0]
sample_markers = min(5, y_true.shape[1])
for i in range(sample_markers):
    ax1.scatter(y_true[:, i], y_pred[:, i], alpha=0.6, s=10, label=f'Marker {i+1}')
ax1.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--', lw=2)
ax1.set_xlabel('Ground Truth')
ax1.set_ylabel('Predicted')
ax1.set_title('Prediction vs Ground Truth')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Correlation distribution
ax2 = axes[0, 1]
ax2.hist(pearson_corrs, bins=20, alpha=0.7, edgecolor='black')
ax2.axvline(mean_pearson, color='red', linestyle='--', label=f'Mean: {mean_pearson:.3f}')
ax2.set_xlabel('Pearson Correlation')
ax2.set_ylabel('Frequency')
ax2.set_title('Correlation Distribution')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Top performing markers
ax3 = axes[0, 2]
if len(pearson_corrs) > 0:
    top_10_markers = corr_df.nlargest(10, 'pearson')
    ax3.barh(range(len(top_10_markers)), top_10_markers['pearson'])
    ax3.set_yticks(range(len(top_10_markers)))
    ax3.set_yticklabels(top_10_markers['marker'], fontsize=8)
    ax3.set_xlabel('Pearson Correlation')
    ax3.set_title('Top 10 Performing Markers')
    ax3.grid(True, alpha=0.3)

# 4. Residuals plot
ax4 = axes[1, 0]
residuals = y_true - y_pred
ax4.scatter(y_pred, residuals, alpha=0.6, s=10)
ax4.axhline(y=0, color='red', linestyle='--')
ax4.set_xlabel('Predicted Values')
ax4.set_ylabel('Residuals')
ax4.set_title('Residuals Plot')
ax4.grid(True, alpha=0.3)

# 5. Fused embeddings visualization (PCA)
ax5 = axes[1, 1]
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
fused_embeddings_2d = pca.fit_transform(test_fused_embeddings[test_mask].cpu().numpy())
ax5.scatter(fused_embeddings_2d[:, 0], fused_embeddings_2d[:, 1], alpha=0.6, s=10)
ax5.set_xlabel('PC1')
ax5.set_ylabel('PC2')
ax5.set_title('Fused Embeddings (PCA)')
ax5.grid(True, alpha=0.3)

# 6. Performance summary
ax6 = axes[1, 2]
ax6.axis('off')
summary_text = f"""
Performance Summary:

MSE: {mse:.6f}
R² Score: {r2:.4f}
Mean Pearson: {mean_pearson:.4f}
Mean Spearman: {mean_spearman:.4f}

Model Architecture:
• GAT RNA Encoder: 2 layers
• Transformer Fusion: 2 layers
• GAT ADT Predictor: 1 layer

Training:
• Epochs: {training_config['epochs']}
• Parameters: {total_params:,}
• Device: {device}
"""
ax6.text(0.1, 0.9, summary_text, transform=ax6.transAxes, fontsize=10,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

plt.tight_layout()
plt.show()

print("✅ Visualizations complete!")


## 7.1 UMAP Visualizations of Embeddings and Markers

In [None]:
# Import required libraries
import umap
import seaborn as sns
from sklearn.preprocessing import StandardScaler

print("=== Creating UMAP Visualizations ===")

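# Get embeddings and predictions for all cells, using the same call as in section 6
trained_model.eval()
with torch.no_grad():
    predictions, fused_embeddings = trained_model(
        x=rna_data_with_masks.x,
        edge_index_rna=rna_data_with_masks.edge_index,
        edge_index_adt=adt_data_with_masks.edge_index
    )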

# Convert to numpy arrays
embeddings_np = fused_embeddings.cpu().numpy()
predictions_np = predictions.cpu().numpy()
true_values_np = adt_data_with_masks.x.cpu().numpy()

# Standardize the embeddings
scaler = StandardScaler()
embeddings_scaled = scaler.fit_transform(embeddings_np)

# Create UMAP reducer
reducer = umap.UMAP(
    n_neighbors=15,
    min_dist=0.1,
    n_components=2,
    metric='euclidean',
    random_state=42
)

# Get UMAP embeddings
umap_embeddings = reducer.fit_transform(embeddings_scaled)

# Create figure with multiple subplots
n_markers = min(6, true_values_np.shape[1])  # Show up to 6 markers
fig = plt.figure(figsize=(20, 15))
fig.suptitle('UMAP Visualizations of Embeddings and Markers', fontsize=16)

# Plot UMAP colored by different markers
for idx in range(n_markers):
    # True Values
    plt.subplot(3, n_markers, idx + 1)
    scatter = plt.scatter(
        umap_embeddings[:, 0],
        umap_embeddings[:, 1],
        c=true_values_np[:, idx],
        cmap='viridis',
        s=5,
        alpha=0.7
    )
    plt.colorbar(scatter)
    plt.title(f'True ADT Marker {idx+1}')
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    
    # Predicted Values
    plt.subplot(3, n_markers, n_markers + idx + 1)
    scatter = plt.scatter(
        umap_embeddings[:, 0],
        umap_embeddings[:, 1],
        c=predictions_np[:, idx],
        cmap='viridis',
        s=5,
        alpha=0.7
    )
    plt.colorbar(scatter)
    plt.title(f'Predicted ADT Marker {idx+1}')
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    
    # Difference (Error)
    plt.subplot(3, n_markers, 2*n_markers + idx + 1)
    scatter = plt.scatter(
        umap_embeddings[:, 0],
        umap_embeddings[:, 1],
        c=np.abs(true_values_np[:, idx] - predictions_np[:, idx]),
        cmap='viridis',
        s=5,
        alpha=0.7
    )
    plt.colorbar(scatter)
    plt.title(f'Prediction Error Marker {idx+1}')
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

print("✅ UMAP visualizations complete!")

# Print statistics for each plotted marker (computed over all cells, not only the test set)
print("\n=== Marker-wise Statistics (all cells) ===")
for idx in range(n_markers):
    true_vals = true_values_np[:, idx]
    pred_vals = predictions_np[:, idx]
    marker_pearson = pearsonr(true_vals, pred_vals)[0]
    marker_spearman = spearmanr(true_vals, pred_vals)[0]
    # Use a separate name so the test-set `mse` from section 6 is not overwritten
    marker_mse = mean_squared_error(true_vals, pred_vals)
    
    print(f"\nMarker {idx+1}:")
    print(f"Pearson correlation: {marker_pearson:.4f}")
    print(f"Spearman correlation: {marker_spearman:.4f}")
    print(f"MSE: {marker_mse:.6f}")

## 8. Save Trained Model


In [None]:
# Save the trained model
print("=== Saving Trained Model ===")

# Create save directory
save_dir = "trained_models"
os.makedirs(save_dir, exist_ok=True)

# Generate timestamp for unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = f"gat_transformer_fusion_{timestamp}.pth"
model_path = os.path.join(save_dir, model_filename)

# Save model checkpoint
checkpoint = {
    'model_state_dict': trained_model.state_dict(),
    'model_config': model_config,
    'training_config': training_config,
    'performance_metrics': {
        # Cast to plain floats so the checkpoint holds no NumPy scalar objects
        'mse': float(mse),
        'r2_score': float(r2),
        'mean_pearson': float(mean_pearson),
        'mean_spearman': float(mean_spearman)
    },
    'data_info': {
        'rna_input_dim': rna_input_dim,
        'adt_output_dim': adt_output_dim,
        'num_nodes': rna_pyg_data.num_nodes,
        'num_edges': rna_pyg_data.num_edges
    },
    'training_time': str(training_duration),
    'timestamp': timestamp
}

torch.save(checkpoint, model_path)

print(f"Model saved to: {model_path}")
print(f"Checkpoint contains:")
for key in checkpoint.keys():
    print(f"  • {key}")

# Also save a simple state dict for easy loading
simple_model_path = os.path.join(save_dir, f"gat_transformer_fusion_simple_{timestamp}.pth")
torch.save(trained_model.state_dict(), simple_model_path)
print(f"Simple model state dict saved to: {simple_model_path}")

print("✅ Model saving complete!")
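
The checkpoint bundles configuration and metrics with the weights, which means PyTorch is needed just to read them. One optional complement, sketched below with the standard library only, is a JSON file written next to the checkpoint. The helper `write_metadata_sidecar` and the example values are hypothetical: the numbers are placeholders, not real results.

```python
import json
import os
import tempfile

def write_metadata_sidecar(checkpoint_path, metadata):
    """Write metadata as JSON next to a checkpoint file and return the JSON path.

    Hypothetical helper; the caller is expected to pass plain Python types
    (dict, str, int, float), since json cannot encode tensors or NumPy scalars.
    """
    sidecar_path = os.path.splitext(checkpoint_path)[0] + ".json"
    with open(sidecar_path, "w", encoding="utf-8") as fh:
        json.dump(metadata, fh, indent=2, sort_keys=True)
    return sidecar_path

# Placeholder values standing in for the notebook's config and metrics.
example_metadata = {
    "model_config": {"hidden_channels": 64, "heads": 8, "dropout": 0.6},
    "performance_metrics": {"mse": 0.0, "r2_score": 0.0},
    "timestamp": "YYYYMMDD_HHMMSS",
}

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "gat_transformer_fusion_example.pth")
    json_path = write_metadata_sidecar(ckpt_path, example_metadata)
    with open(json_path, encoding="utf-8") as fh:
        restored = json.load(fh)

print(restored["model_config"]["heads"])
```

In the notebook, the same function could be called with the checkpoint's path and the non-tensor entries of the checkpoint dictionary, after casting any NumPy values to `float`.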


## 9. Make Predictions on New Data


In [None]:
# Function to make predictions on new data
def predict_with_gat_transformer_fusion(model, rna_adata, device=None):
    """
    Make predictions using trained GATWithTransformerFusion model
    
    Args:
        model: Trained GATWithTransformerFusion model
        rna_adata: AnnData object with RNA data
        device: Device to use for inference
    
    Returns:
        predicted_adt_embeddings, fused_embeddings
    """
    
    print("=== Making Predictions with GATWithTransformerFusion ===")
    
    if device is None:
        device = next(model.parameters()).device
    
    # Convert to PyTorch Geometric format
    rna_pyg_data = build_pyg_data(rna_adata)
    rna_pyg_data = rna_pyg_data.to(device)
    
    print(f"Input RNA data shape: {rna_adata.shape}")
    print(f"PyG data: {rna_pyg_data}")
    
    # Make predictions
    model.eval()
    with torch.no_grad():
        predicted_adt, fused_embeddings = model(
            x=rna_pyg_data.x,
            edge_index_rna=rna_pyg_data.edge_index
        )
    
    # Convert to numpy
    predicted_adt_np = predicted_adt.cpu().numpy()
    fused_embeddings_np = fused_embeddings.cpu().numpy()
    
    print(f"Predicted ADT shape: {predicted_adt_np.shape}")
    print(f"Fused embeddings shape: {fused_embeddings_np.shape}")
    
    return predicted_adt_np, fused_embeddings_np

# Example: reuse the preprocessed RNA data from section 2, which includes the cells seen in training (for demonstration only)
print("Making predictions on the preprocessed dataset used above (for demonstration)...")
demo_predictions, demo_fused = predict_with_gat_transformer_fusion(trained_model, rna_adata)

print(f"\n=== Prediction Summary ===")
print(f"Predicted ADT embeddings shape: {demo_predictions.shape}")
print(f"Fused embeddings shape: {demo_fused.shape}")
print(f"Predicted ADT range: [{demo_predictions.min():.4f}, {demo_predictions.max():.4f}]")
print(f"Fused embeddings range: [{demo_fused.min():.4f}, {demo_fused.max():.4f}]")

print("✅ Predictions complete!")


## 10. Load Saved Model (Example)


In [None]:
# Example: How to load a saved model
def load_gat_transformer_fusion_model(model_path, device=None):
    """
    Load a saved GATWithTransformerFusion model
    
    Args:
        model_path: Path to the saved model checkpoint
        device: Device to load the model on
    
    Returns:
        loaded_model, checkpoint_info
    """
    
    print(f"=== Loading GATWithTransformerFusion Model ===")
    print(f"Loading from: {model_path}")
    
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    
    # Extract model configuration
    model_config = checkpoint['model_config']
    print(f"Model configuration: {model_config}")
    
    # Initialize model
    model = GATWithTransformerFusion(**model_config).to(device)
    
    # Load state dict
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    print(f"Model loaded successfully on {device}")
    
    # Print checkpoint info
    print(f"\n=== Checkpoint Information ===")
    for key, value in checkpoint.items():
        if key != 'model_state_dict':
            print(f"  {key}: {value}")
    
    return model, checkpoint

# Example usage (left inside a string so it does not run; the trained model is already in memory)
"""
# Load the model we just saved
loaded_model, checkpoint_info = load_gat_transformer_fusion_model(model_path)

# Make predictions with loaded model
new_predictions, new_fused = predict_with_gat_transformer_fusion(loaded_model, rna_adata)

print("✅ Model loading example complete!")
"""

print("✅ Model loading function ready!")
print("Uncomment the example code above to test loading the saved model.")


## 11. Summary and Next Steps

### What We've Accomplished:
1. ✅ **Loaded and preprocessed** RNA and ADT data
2. ✅ **Converted data** to PyTorch Geometric format
3. ✅ **Initialized** GATWithTransformerFusion model
4. ✅ **Trained** the model with end-to-end optimization
5. ✅ **Evaluated** performance with comprehensive metrics
6. ✅ **Visualized** results with multiple plots
7. ✅ **Saved** the trained model with full checkpoint
8. ✅ **Created** prediction functions for new data
9. ✅ **Provided** model loading examples

### Key Advantages of GATWithTransformerFusion:
- **End-to-end training**: Single model for RNA→ADT mapping
- **Graph structure preservation**: Maintains graph relationships throughout
- **Transformer benefits**: Self-attention for better feature fusion
- **Unified architecture**: Combines GAT and Transformer in one model

### Performance Metrics:
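- **MSE**: Mean squared error between predicted and measured ADT values on the test cells (lower is better)
- **R² Score**: Fraction of variance in the measured ADT values explained by the predictions
- **Pearson/Spearman Correlations**: Computed per marker, then averaged across markers
- **Training time**: Total wall-clock time and average time per epoch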

### Next Steps:
1. **Compare** with existing GAT + Transformer pipeline
2. **Tune hyperparameters** for better performance
3. **Test on different datasets** for generalization
4. **Integrate** with existing workflow
5. **Deploy** for production use

### Usage in Other Notebooks:
```python
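# Import the trainer and model class
from scripts.trainer.gat_trainer import train_gat_transformer_fusion
from scripts.model.doNET import GATWithTransformerFusion

# Train a new model (rna_pyg_data / adt_pyg_data come from build_pyg_data, as in section 3)
trained_model, rna_data, adt_data = train_gat_transformer_fusion(
    rna_data=rna_pyg_data,
    adt_data=adt_pyg_data,
    epochs=100
)

# Make predictions; predict_with_gat_transformer_fusion is defined in section 9
# of this notebook, so copy it or move it into a module before calling it elsewhere
predicted_adt, fused_embeddings = predict_with_gat_transformer_fusion(
    trained_model, new_rna_adata
)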
```

**🎉 GATWithTransformerFusion training pipeline is now ready for use!**
