# GNN Benchmarking Analysis - Google Colab Version

This notebook runs the same benchmarking analysis as `benchmarking.py`, adapted for Google Colab.

**Features**:
- Multi-seed training (GCN, GAT, MLP, MeanMedian)
- Statistical analysis with confidence intervals
- Pairwise comparisons (Wilcoxon test + rank-biserial effect sizes)
- Motif-specific performance metrics
- Sensitivity analysis (optional)
- Complete visualizations

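- **Setup time**: ~5-10 minutes
- **Runtime**: ~30 min (5 seeds) to ~3 hours (10 seeds + sensitivity analysis); the default `CONFIG` in Section 3 (20 seeds + sensitivity analysis) will take longer than these estimates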

## 1. Setup: Mount Google Drive and Install Dependencies

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set working directory to your project folder
import os
os.chdir('/content/drive/My Drive/182-GNN_SAE')  # MODIFY THIS PATH
!pwd

In [None]:
# Install required packages
!pip install -q torch torch-geometric pytorch-lightning
!pip install -q pandas numpy scipy matplotlib seaborn networkx scikit-learn
!pip install -q tqdm

print("✓ All packages installed")

In [None]:
# Check GPU availability
import torch

print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("WARNING: No GPU detected. Training will be slow on CPU.")

## 2. Import Required Modules

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from collections import defaultdict
from typing import Dict, List
import warnings
warnings.filterwarnings('ignore')

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from scipy import stats
from tqdm import tqdm

# Import from the local project modules (gnn_train.py and benchmarking.py)
from gnn_train import (
    GraphDataset, GCNModel, GATModel, GNNTrainer,
    load_all_graphs, split_data, collate_fn, MOTIF_LABELS, MOTIF_TO_ID
)
from benchmarking import (
    MeanMedianBaseline, MLPBaseline, DataVariationDataset, BaselineTrainer,
    BenchmarkExperiment
)

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print("✓ All imports successful")

## 3. Configuration

In [None]:
# Configuration parameters
CONFIG = {
    'n_seeds': 20,                            # Number of random seeds (default: 20 for statistical rigor)
    'num_epochs': 50,                         # Maximum training epochs
    'batch_size': 128,                        # Batch size
    'learning_rate': 1e-3,                    # Learning rate
    'run_sensitivity': True,                  # Run sensitivity analysis? (adds ~1-2 hours)
    'output_dir': 'outputs/benchmark_colab',  # Output directory
}

# Print configuration
print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Create output directory
Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)
print(f"\n✓ Output directory created: {CONFIG['output_dir']}")
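
Section 4 reads every key in `CONFIG` directly, so a missing key (for example, dropping `learning_rate` when pasting a shorter config) only fails after the earlier cells have already run. A minimal sketch of an up-front check follows. `REQUIRED_KEYS` and `check_config` are hypothetical helpers, not part of `benchmarking.py`; the key list is taken from the cells in this notebook.

```python
# Hypothetical helper: validate CONFIG before starting a long run.
REQUIRED_KEYS = {
    'n_seeds': int,
    'num_epochs': int,
    'batch_size': int,
    'learning_rate': float,
    'run_sensitivity': bool,
    'output_dir': str,
}


def check_config(config: dict) -> None:
    """Raise if a key used later in the notebook is missing or has the wrong type."""
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise KeyError(f"CONFIG is missing keys: {missing}")
    for key, expected_type in REQUIRED_KEYS.items():
        value = config[key]
        # bool is a subclass of int, so reject it explicitly for integer fields
        if expected_type is int and isinstance(value, bool):
            raise TypeError(f"CONFIG['{key}'] should be int, got bool")
        if not isinstance(value, expected_type):
            raise TypeError(
                f"CONFIG['{key}'] should be {expected_type.__name__}, "
                f"got {type(value).__name__}"
            )
```

Calling `check_config(CONFIG)` at the end of this cell would surface a bad config before any training starts.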

## 4. Phase 1: Multi-Seed Training

In [None]:
# Initialize benchmark experiment
benchmark = BenchmarkExperiment(output_dir=CONFIG['output_dir'])

# Models to test
models_to_test = ['GCN', 'GAT', 'MLP', 'MeanMedian']

# Storage for results
all_results = {}
all_detailed_stats = {}

print("="*60)
print("PHASE 1: MULTI-SEED TRAINING")
print("="*60)

# Train each model
for model in models_to_test:
    print(f"\n{'='*60}")
    print(f"Training {model}")
    print(f"{'='*60}")
    
    try:
        # Run multi-seed training
        results = benchmark.run_multi_seed_training(
            model_type=model,
            n_seeds=CONFIG['n_seeds'],
            num_epochs=CONFIG['num_epochs'],
            batch_size=CONFIG['batch_size'],
            learning_rate=CONFIG['learning_rate']
        )
        
        all_results[model] = results
        
        # Save individual results
        benchmark.save_results(results, f"{model.lower()}_results.json")
        
        # Generate statistics
        detailed_stats = benchmark.generate_detailed_statistics(results, model)
        all_detailed_stats[model] = detailed_stats
        
        # Print results table
        results_df = benchmark.generate_results_table(results, model)
        print(f"\n{model} Results:")
        print(results_df.to_string(index=False))
        
    except Exception as e:
        print(f"Error training {model}: {e}")
        import traceback
        traceback.print_exc()

print("\n" + "="*60)
print("✓ Multi-seed training complete")
print("="*60)
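
`run_multi_seed_training` repeats training once per seed; its internals live in `benchmarking.py` and are not shown here. As a sketch of what per-seed reproducibility usually involves, the hypothetical `set_seed` helper below seeds Python's `random`, NumPy, and (if installed) PyTorch on the CPU and all CUDA devices. It illustrates common practice and is not a description of `BenchmarkExperiment`. Even with fixed seeds, GPU runs may not be bit-for-bit identical, because some CUDA kernels (for example, atomic scatter-add used in message passing) are nondeterministic.

```python
import random

import numpy as np


def set_seed(seed: int) -> None:
    """Seed the common RNGs so a single training run can be repeated (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not installed: Python and NumPy are still seeded
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable


# Same seed -> same random draws
set_seed(0)
first = np.random.rand(3)
set_seed(0)
second = np.random.rand(3)
print(np.allclose(first, second))  # True
```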

## 5. Phase 1b: Aggregate Statistics and Save

In [None]:
# Save detailed statistics
if all_detailed_stats:
    print("Saving comprehensive statistical summary...")
    
    detailed_stats_dict = {
        model: {k: v for k, v in stats.items() if k != 'test_loss' or isinstance(v, dict)}
        for model, stats in all_detailed_stats.items()
    }
    benchmark.save_results(detailed_stats_dict, "detailed_statistics.json")
    print("✓ Saved detailed_statistics.json")

# Save multi-seed summary
benchmark.save_results(all_results, "multi_seed_summary.json")
print("✓ Saved multi_seed_summary.json")
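
Section 13 reads `mean`, `median`, `std`, `ci_95`, `min`, and `max` from each model's `test_loss` entry. `generate_detailed_statistics` is defined in `benchmarking.py`, so its exact formulas are not visible here. The sketch below assumes `ci_95` is the half-width of a Student-t 95% confidence interval for the mean across seeds and that `std` is the sample standard deviation (`ddof=1`); check `benchmarking.py` before relying on either assumption. SciPy's `stats` module is imported under an alias so it does not collide with variables named `stats` elsewhere in the notebook.

```python
import numpy as np
from scipy import stats as scipy_stats


def summarize_seed_losses(losses) -> dict:
    """Summarize per-seed test losses (assumes a t-based 95% CI half-width)."""
    x = np.asarray(losses, dtype=float)
    n = x.size
    std = float(x.std(ddof=1)) if n > 1 else 0.0
    sem = std / np.sqrt(n)
    # Two-sided 95% critical value with n-1 degrees of freedom
    t_crit = float(scipy_stats.t.ppf(0.975, df=n - 1)) if n > 1 else 0.0
    return {
        'mean': float(x.mean()),
        'median': float(np.median(x)),
        'std': std,
        'ci_95': t_crit * sem,
        'min': float(x.min()),
        'max': float(x.max()),
    }


summary = summarize_seed_losses([1.0, 2.0, 3.0, 4.0, 5.0])
print(summary['mean'], summary['median'])  # 3.0 3.0
```

The interval reported in Section 13 would then read `mean ± ci_95`.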

## 6. Phase 1c: Motif-Specific Analysis

In [None]:
# Aggregate motif metrics across seeds
print("Aggregating motif metrics across seeds...")

motif_metrics_all_models = {}
for model, results in all_results.items():
    if 'motif_metrics' in results and results['motif_metrics']:
        # Average motif metrics across seeds
        aggregated_motif_metrics = {}
        for motif_label in results['motif_metrics'][0].keys():
            mses = [seed_metrics[motif_label]['mean_mse'] 
                    for seed_metrics in results['motif_metrics'] 
                    if motif_label in seed_metrics]
            maes = [seed_metrics[motif_label]['mean_mae'] 
                    for seed_metrics in results['motif_metrics'] 
                    if motif_label in seed_metrics]
            num_graphs = results['motif_metrics'][0][motif_label]['num_graphs']
            
            aggregated_motif_metrics[motif_label] = {
                'num_graphs': num_graphs,
                'mean_mse': float(np.mean(mses)),
                'std_mse': float(np.std(mses)) if len(mses) > 1 else 0.0,
                'mean_mae': float(np.mean(maes)),
                'std_mae': float(np.std(maes)) if len(maes) > 1 else 0.0
            }
        motif_metrics_all_models[model] = aggregated_motif_metrics

# Save aggregated motif metrics
if motif_metrics_all_models:
    benchmark.save_results(motif_metrics_all_models, "motif_metrics_summary.json")
    print("✓ Saved motif_metrics_summary.json")
else:
    print("⚠ No motif metrics found")
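
The loop above averages each motif's MSE and MAE across seeds. The same aggregation can be written as a pandas `groupby` over a long-form table, sketched below. The motif names (`chain`, `cycle`) and numbers are made-up placeholders, not values from `MOTIF_LABELS` or real results. One detail matters when porting: `np.std` defaults to the population standard deviation (`ddof=0`), while pandas' `.std()` defaults to the sample standard deviation (`ddof=1`), so the sketch passes `ddof=0` explicitly to match the loop.

```python
import pandas as pd

# Toy per-seed motif metrics in the same shape as results['motif_metrics']:
# one dict per seed, mapping motif label -> metrics. Values are illustrative only.
seed_motif_metrics = [
    {'chain': {'mean_mse': 0.10, 'mean_mae': 0.20, 'num_graphs': 40},
     'cycle': {'mean_mse': 0.30, 'mean_mae': 0.40, 'num_graphs': 35}},
    {'chain': {'mean_mse': 0.14, 'mean_mae': 0.22, 'num_graphs': 40},
     'cycle': {'mean_mse': 0.26, 'mean_mae': 0.38, 'num_graphs': 35}},
]

# Flatten to one row per (seed, motif)
rows = [
    {'seed': seed_idx, 'motif': motif, **metrics}
    for seed_idx, per_motif in enumerate(seed_motif_metrics)
    for motif, metrics in per_motif.items()
]
long_df = pd.DataFrame(rows)


def pop_std(series: pd.Series) -> float:
    """Population std (ddof=0), matching np.std's default used in the loop above."""
    return series.std(ddof=0)


aggregated = long_df.groupby('motif').agg(
    num_graphs=('num_graphs', 'first'),
    mean_mse=('mean_mse', 'mean'),
    std_mse=('mean_mse', pop_std),
    mean_mae=('mean_mae', 'mean'),
    std_mae=('mean_mae', pop_std),
)
print(aggregated)
```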

## 7. Phase 2: Pairwise Statistical Comparisons

In [None]:
print("\n" + "="*60)
print("PHASE 2: PAIRWISE STATISTICAL COMPARISONS")
print("="*60)

# Compute pairwise comparisons
pairwise_comparisons = benchmark.compute_pairwise_comparisons(all_results)

if pairwise_comparisons:
    # Display results
    print("\nPairwise Comparison Results:")
    print()
    
    for comparison_label, results in pairwise_comparisons.items():
        print(f"\n{comparison_label}:")
        print(f"  Wilcoxon p-value: {results['p_value']:.6f}")
        print(f"  Rank-biserial (r): {results['rank_biserial']:.4f}")
        print(f"  95% CI: [{results['rank_biserial_ci_lower']:.4f}, {results['rank_biserial_ci_upper']:.4f}]")
        print(f"  Mean Loss {comparison_label.split(' vs ')[0]}: {results['mean_loss_a']:.4f}")
        print(f"  Mean Loss {comparison_label.split(' vs ')[1]}: {results['mean_loss_b']:.4f}")
        print(f"  Better Model: {results['better_model']}")
        print(f"  Significant (p<0.05): {results['is_significant']}")
    
    # Save results
    benchmark.save_results(pairwise_comparisons, "pairwise_comparisons.json")
    print("\n✓ Saved pairwise_comparisons.json")
else:
    print("⚠ No pairwise comparisons computed")
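
`BenchmarkExperiment.compute_pairwise_comparisons` is defined in `benchmarking.py`, so its exact conventions (paired vs. unpaired test, sign of `r`, how the CI is built) are not visible in this notebook. The sketch below shows one common way to produce the same kind of numbers, assuming each seed yields one test loss per model so the losses can be paired by seed: a Wilcoxon signed-rank test via `scipy.stats.wilcoxon`, the matched-pairs rank-biserial correlation (rank sum of positive differences minus rank sum of negative differences, divided by their total), and a percentile-bootstrap 95% CI for `r`. `rank_biserial` and `compare_models` are hypothetical helpers. With differences defined as `loss_a - loss_b`, a negative `r` means model A tends to have the lower loss.

```python
import numpy as np
from scipy.stats import rankdata, wilcoxon


def rank_biserial(diffs) -> float:
    """Matched-pairs rank-biserial correlation for paired differences (zeros dropped)."""
    d = np.asarray(diffs, dtype=float)
    d = d[d != 0]
    if d.size == 0:
        return 0.0
    ranks = rankdata(np.abs(d))  # average ranks for tied magnitudes
    r_pos = ranks[d > 0].sum()
    r_neg = ranks[d < 0].sum()
    return float((r_pos - r_neg) / (r_pos + r_neg))


def compare_models(loss_a, loss_b, n_boot=2000, seed=0) -> dict:
    """Paired comparison of two models' per-seed test losses (hypothetical helper)."""
    a = np.asarray(loss_a, dtype=float)
    b = np.asarray(loss_b, dtype=float)
    diffs = a - b  # negative -> model A has lower loss on that seed
    p_value = float(wilcoxon(a, b).pvalue)
    r = rank_biserial(diffs)

    # Percentile bootstrap: resample seed pairs with replacement
    rng = np.random.default_rng(seed)
    boot = [
        rank_biserial(rng.choice(diffs, size=diffs.size, replace=True))
        for _ in range(n_boot)
    ]
    ci_lower, ci_upper = np.percentile(boot, [2.5, 97.5])
    return {
        'p_value': p_value,
        'rank_biserial': r,
        'rank_biserial_ci_lower': float(ci_lower),
        'rank_biserial_ci_upper': float(ci_upper),
        'is_significant': p_value < 0.05,
    }


# Illustrative losses for six seeds (not real results)
result = compare_models(
    [0.10, 0.11, 0.12, 0.13, 0.14, 0.15],
    [0.20, 0.22, 0.24, 0.26, 0.28, 0.30],
)
print(result['rank_biserial'])  # -1.0: model A is lower on every seed
```

Resampling whole seed pairs (rather than each model's losses independently) keeps the pairing that the signed-rank test relies on.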

## 8. Generate Visualizations

In [None]:
print("\n" + "="*60)
print("GENERATING VISUALIZATIONS")
print("="*60)

if all_detailed_stats:
    # Baseline comparison
    try:
        print("\nGenerating baseline comparison plot...")
        benchmark.plot_baseline_comparison(all_detailed_stats)
    except Exception as e:
        print(f"Error: {e}")
    
    # Seed variance
    try:
        print("Generating seed variance plot...")
        benchmark.plot_seed_variance(all_detailed_stats)
    except Exception as e:
        print(f"Error: {e}")
    
    # Statistical summary table
    try:
        print("Generating statistical summary table...")
        benchmark.plot_statistical_summary_table(all_detailed_stats)
    except Exception as e:
        print(f"Error: {e}")
    
    # Train/val/test progression
    try:
        print("Generating train/val/test progression plot...")
        benchmark.plot_train_val_test_progression(all_detailed_stats)
    except Exception as e:
        print(f"Error: {e}")

# Pairwise comparisons
if pairwise_comparisons:
    try:
        print("Generating pairwise comparisons plot...")
        benchmark.plot_pairwise_comparisons(pairwise_comparisons)
    except Exception as e:
        print(f"Error: {e}")

print("\n✓ Visualizations complete")

## 9. Motif-Specific Visualizations

In [None]:
if motif_metrics_all_models:
    print("\nGenerating motif-specific visualizations...")
    
    try:
        print("Generating motif comparison plot...")
        benchmark.plot_motif_comparison(motif_metrics_all_models)
    except Exception as e:
        print(f"Error: {e}")
    
    try:
        print("Generating motif heatmap...")
        benchmark.plot_motif_heatmap(motif_metrics_all_models)
    except Exception as e:
        print(f"Error: {e}")
    
    print("✓ Motif visualizations complete")
else:
    print("No motif metrics to visualize")
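
`plot_motif_heatmap` is implemented in `benchmarking.py`. For ad-hoc inspection, `motif_metrics_all_models` (model to motif to aggregated metrics, as built in Section 6) can also be turned into a motif × model table with pandas and drawn with Matplotlib's `imshow`. `motif_table` and `plot_motif_table` are hypothetical helpers, and the motif names in the example are placeholders.

```python
import pandas as pd


def motif_table(motif_metrics_all_models: dict, metric: str = 'mean_mse') -> pd.DataFrame:
    """Rows = motifs, columns = models, values = the chosen aggregated metric."""
    return pd.DataFrame({
        model: {motif: m[metric] for motif, m in per_motif.items()}
        for model, per_motif in motif_metrics_all_models.items()
    })


def plot_motif_table(table: pd.DataFrame, path: str, label: str = 'mean MSE') -> None:
    """Save a simple heatmap of the table (with 'Reds', darker = higher error)."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(1.5 * len(table.columns) + 2, 0.5 * len(table) + 2))
    im = ax.imshow(table.values, cmap='Reds', aspect='auto')
    ax.set_xticks(range(len(table.columns)))
    ax.set_xticklabels(table.columns, rotation=45, ha='right')
    ax.set_yticks(range(len(table.index)))
    ax.set_yticklabels(table.index)
    fig.colorbar(im, ax=ax, label=label)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)


# Placeholder structure mirroring motif_metrics_all_models
toy = {
    'GCN': {'motif_a': {'mean_mse': 0.1}, 'motif_b': {'mean_mse': 0.3}},
    'MLP': {'motif_a': {'mean_mse': 0.2}, 'motif_b': {'mean_mse': 0.5}},
}
print(motif_table(toy))
```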

## 10. Sensitivity Analysis (Optional)

In [None]:
# Controlled by CONFIG['run_sensitivity'] in Section 3; set it to True there to run the sensitivity analysis
RUN_SENSITIVITY = CONFIG['run_sensitivity']

if RUN_SENSITIVITY:
    print("\n" + "="*60)
    print("PHASE 4: SENSITIVITY ANALYSIS")
    print("="*60)
    print("\nThis will test GCN and GAT with different:")
    print("  - Timesteps: [25, 50, 75]")
    print("  - Noise levels: [0.005, 0.01, 0.05]")
    print("\nEstimated time: 1-2 hours")
    
    sensitivity_results = {}
    
    for model in ['GCN', 'GAT']:
        try:
            print(f"\nRunning sensitivity analysis for {model}...")
            results = benchmark.run_sensitivity_analysis(model_type=model)
            sensitivity_results[model] = results
            benchmark.save_results(results, f"{model.lower()}_sensitivity.json")
            print(f"✓ Saved {model.lower()}_sensitivity.json")
        except Exception as e:
            print(f"Error: {e}")
    
    # Generate sensitivity visualizations
    if sensitivity_results:
        try:
            print("\nGenerating combined sensitivity plot...")
            benchmark.plot_sensitivity_analysis(sensitivity_results)
        except Exception as e:
            print(f"Error: {e}")
        
        try:
            print("Generating individual sensitivity plots...")
            benchmark.plot_individual_sensitivity_analysis(sensitivity_results)
        except Exception as e:
            print(f"Error: {e}")
        
        print("\n✓ Sensitivity analysis complete")
else:
    print("\n⊘ Sensitivity analysis skipped (set CONFIG['run_sensitivity'] = True in Section 3 to enable)")
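
The printout above lists the settings the sensitivity analysis explores: timesteps [25, 50, 75] and noise levels [0.005, 0.01, 0.05]. Whether `run_sensitivity_analysis` sweeps the full grid or varies one factor at a time is not visible in this notebook. The sketch below assumes the full 3 x 3 grid and enumerates it with `itertools.product`; `sweep` and `evaluate_fn` are hypothetical, with `evaluate_fn` standing in for whatever trains and scores a model at one setting.

```python
from itertools import product
from typing import Callable, Dict, Tuple

TIMESTEPS = [25, 50, 75]
NOISE_LEVELS = [0.005, 0.01, 0.05]


def sweep(evaluate_fn: Callable[[int, float], float]) -> Dict[Tuple[int, float], float]:
    """Evaluate every (timesteps, noise) combination; return losses keyed by setting."""
    results = {}
    for timesteps, noise in product(TIMESTEPS, NOISE_LEVELS):
        results[(timesteps, noise)] = evaluate_fn(timesteps, noise)
    return results


# Dummy evaluator for illustration only: loss grows with noise
losses = sweep(lambda timesteps, noise: 0.1 + noise)
print(len(losses))  # 9
```

A model whose loss barely changes along the noise axis of such a grid is the "flat = robust" pattern described in the notes at the end of this notebook.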

## 11. Summary and Results

In [None]:
print("\n" + "="*60)
print("BENCHMARK COMPLETE")
print("="*60)

print(f"\nAll results saved to: {CONFIG['output_dir']}")

print("\nGenerated files:")
print("\n  Visualizations:")
print("    - baseline_comparison.png")
print("    - seed_variance.png")
print("    - statistical_summary_table.png")
print("    - train_val_test_progression.png")
print("    - pairwise_comparisons.png")

if motif_metrics_all_models:
    print("    - motif_comparison.png")
    print("    - motif_heatmap.png")

if RUN_SENSITIVITY:
    print("    - sensitivity_analysis.png")
    print("    - gcn_sensitivity_detailed.png")
    print("    - gat_sensitivity_detailed.png")

print("\n  Data Files:")
print("    - detailed_statistics.json")
print("    - gcn_results.json")
print("    - gat_results.json")
print("    - mlp_results.json")
print("    - meanmedian_results.json")
print("    - pairwise_comparisons.json")
print("    - motif_metrics_summary.json")
print("    - multi_seed_summary.json")

if RUN_SENSITIVITY:
    print("    - gcn_sensitivity.json")
    print("    - gat_sensitivity.json")

print("\n" + "="*60)

## 12. Download Results

In [None]:
# Create a zip archive of all results. It is written to the current working
# directory, i.e. the Drive project folder set in Section 1.
import shutil

output_path = Path(CONFIG['output_dir'])
zip_name = "benchmark_results"

print(f"Creating zip file: {zip_name}.zip")
archive_path = os.path.abspath(shutil.make_archive(zip_name, 'zip', output_path))

print(f"✓ Created {archive_path}")
print("\nTo download:")
print("  1. Click the folder icon on the left (Files)")
print(f"  2. Navigate to {archive_path}")
print(f"  3. Right-click and select Download")
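
Because Section 1 changes the working directory to the Drive project folder, `shutil.make_archive` writes the archive there rather than under `/content`. Inside Colab, `google.colab.files.download` can also trigger the browser download directly. The sketch below wraps both steps; `zip_results` is a hypothetical helper, and the download step is skipped when `google.colab` cannot be imported.

```python
import os
import shutil


def zip_results(output_dir: str, zip_name: str = 'benchmark_results') -> str:
    """Zip output_dir into the current directory and return the archive's absolute path."""
    archive_path = os.path.abspath(shutil.make_archive(zip_name, 'zip', output_dir))
    try:
        from google.colab import files  # only importable inside Colab
    except ImportError:
        print(f"Not running in Colab; archive saved to {archive_path}")
    else:
        files.download(archive_path)  # prompts the browser to download the file
    return archive_path
```

In this notebook it would be called as `zip_results(CONFIG['output_dir'])`.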

## 13. Inspect Key Results

In [None]:
# Display detailed statistics as a nice table
if all_detailed_stats:
    print("\n" + "="*80)
    print("DETAILED STATISTICS SUMMARY")
    print("="*80)
    
    # Create summary dataframe
    summary_data = []
    # Use model_stats (not stats) so scipy.stats imported in Section 2 is not overwritten
    for model_name, model_stats in all_detailed_stats.items():
        test_loss = model_stats['test_loss']
        summary_data.append({
            'Model': model_name,
            'Mean': f"{test_loss['mean']:.4f}",
            'Median': f"{test_loss['median']:.4f}",
            'Std': f"{test_loss['std']:.4f}",
            '95% CI': f"±{test_loss['ci_95']:.4f}",
            'Min': f"{test_loss['min']:.4f}",
            'Max': f"{test_loss['max']:.4f}"
        })
    
    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_string(index=False))
    print()

In [None]:
# Display pairwise comparison results
if pairwise_comparisons:
    print("\n" + "="*80)
    print("PAIRWISE COMPARISON RESULTS (Wilcoxon Test)")
    print("="*80)
    
    comparison_data = []
    for comparison_label, results in pairwise_comparisons.items():
        comparison_data.append({
            'Comparison': comparison_label,
            'p-value': f"{results['p_value']:.6f}",
            'Rank-biserial (r)': f"{results['rank_biserial']:.4f}",
            '95% CI': f"[{results['rank_biserial_ci_lower']:.4f}, {results['rank_biserial_ci_upper']:.4f}]",
            'Significant': '✓ Yes' if results['is_significant'] else '✗ No',
            'Better Model': results['better_model']
        })
    
    comparison_df = pd.DataFrame(comparison_data)
    print(comparison_df.to_string(index=False))
    print()
    print("Legend:")
    print("  p-value < 0.05: Statistically significant difference")
    print("  Rank-biserial |r| > 0.3: Medium to large effect size")
    print("  95% CI: Confidence interval - narrow = precise, wide = uncertain")
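
The legend's |r| > 0.3 cut-off follows the usual rule-of-thumb benchmarks for correlation-type effect sizes (about 0.1 small, 0.3 medium, 0.5 large). These are conventions rather than fixed standards. A hypothetical `effect_size_label` helper can turn `r` into a verbal label:

```python
def effect_size_label(r: float) -> str:
    """Verbal label for a rank-biserial correlation using rule-of-thumb cut-offs."""
    magnitude = abs(r)
    if magnitude < 0.1:
        return 'negligible'
    if magnitude < 0.3:
        return 'small'
    if magnitude < 0.5:
        return 'medium'
    return 'large'


print(effect_size_label(-0.42))  # medium
```

For example, `comparison_df['Effect'] = [effect_size_label(res['rank_biserial']) for res in pairwise_comparisons.values()]` adds a label column; this relies on `comparison_df` rows having been built in the same order as the dictionary, as they are in the cell above.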

## 14. Display Generated Plots

In [None]:
# Display the main benchmark plots
from IPython.display import Image, display

output_dir = Path(CONFIG['output_dir'])

plot_files = [
    'baseline_comparison.png',
    'seed_variance.png',
    'statistical_summary_table.png',
    'pairwise_comparisons.png',
    'train_val_test_progression.png',
]

for plot_file in plot_files:
    plot_path = output_dir / 'visualizations' / plot_file
    if plot_path.exists():
        print(f"\n{'='*60}")
        print(f"{plot_file}")
        print(f"{'='*60}")
        display(Image(str(plot_path)))
    else:
        print(f"Plot not found: {plot_file}")

In [None]:
# Display motif plots if available
if motif_metrics_all_models:
    motif_plots = ['motif_comparison.png', 'motif_heatmap.png']
    
    for plot_file in motif_plots:
        plot_path = output_dir / 'visualizations' / plot_file
        if plot_path.exists():
            print(f"\n{'='*60}")
            print(f"{plot_file}")
            print(f"{'='*60}")
            display(Image(str(plot_path)))
        else:
            print(f"Plot not found: {plot_file}")

In [None]:
# Display sensitivity plots if available
if RUN_SENSITIVITY:
    sensitivity_plots = [
        'sensitivity_analysis.png',
        'gcn_sensitivity_detailed.png',
        'gat_sensitivity_detailed.png'
    ]
    
    for plot_file in sensitivity_plots:
        plot_path = output_dir / 'visualizations' / plot_file
        if plot_path.exists():
            print(f"\n{'='*60}")
            print(f"{plot_file}")
            print(f"{'='*60}")
            display(Image(str(plot_path)))
        else:
            print(f"Plot not found: {plot_file}")
else:
    print("Sensitivity analysis not run. Set CONFIG['run_sensitivity'] = True in Section 3 to enable.")

## Notes & Tips

### Configuration Options

To modify the analysis, edit the `CONFIG` dictionary in Section 3:

```python
CONFIG = {
    'n_seeds': 5,              # More seeds = more rigorous (5-10 recommended)
    'num_epochs': 100,         # Higher = longer training but potentially better
    'batch_size': 32,          # Reduce if out of memory
    'learning_rate': 1e-3,     # Learning rate (required by Section 4)
    'run_sensitivity': False,  # Set to True to run sensitivity analysis (adds 1-2 hours)
    'output_dir': 'outputs/benchmark_colab',
}
```

### Recommended Configurations

**Quick test (15-30 min)**:
```python
CONFIG.update({'n_seeds': 3, 'num_epochs': 50, 'run_sensitivity': False})
```

**Standard evaluation (45 min - 1 hour)**:
```python
CONFIG.update({'n_seeds': 5, 'num_epochs': 100, 'run_sensitivity': False})
```

**Full analysis (2-3 hours)**:
```python
CONFIG.update({'n_seeds': 5, 'num_epochs': 100, 'run_sensitivity': True})
```

**Publication quality (3-4 hours)**:
```python
CONFIG.update({'n_seeds': 10, 'num_epochs': 150, 'run_sensitivity': True})
```
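
The seed count also bounds what the significance test can show. Assuming the pairwise comparisons use an exact two-sided Wilcoxon signed-rank test on seed-paired losses with no ties or zero differences (check `benchmarking.py`), the smallest achievable p-value with n seeds is 2 / 2^n, reached only when every seed favors the same model. With the 3 or 5 seeds of the quick and standard settings, the test therefore cannot reach p < 0.05 however large the gap; 6 seeds is the minimum. The hypothetical helper below computes this bound.

```python
def min_wilcoxon_p(n_seeds: int) -> float:
    """Smallest two-sided exact Wilcoxon signed-rank p-value for n paired seeds (no ties)."""
    # Under H0 each of the 2**n sign patterns is equally likely; the most extreme
    # outcome (all differences share one sign) covers 2 of them in a two-sided test.
    return min(1.0, 2 / 2 ** n_seeds)


for n in (3, 5, 6, 10, 20):
    print(n, min_wilcoxon_p(n))
```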

### Troubleshooting

**Out of Memory Error**:
- Reduce `batch_size` to 16
- Reduce `n_seeds` to 3
- Reduce `num_epochs` to 50

**Timeout (Colab disconnects)**:
- Colab sessions run for at most 12 hours, and idle sessions are disconnected earlier
- Reduce number of seeds or epochs
- Don't run sensitivity analysis with many seeds

**Plots not showing**:
- Wait for training to complete
- Check that `all_detailed_stats` is not empty
- Rerun the visualization cells
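
The out-of-memory advice above (reduce `batch_size`) can be automated by retrying with a halved batch size. The sketch assumes a hypothetical `train_fn(batch_size)` callable and relies on PyTorch reporting CUDA OOM as a `RuntimeError` whose message contains "out of memory" (recent PyTorch versions raise `torch.cuda.OutOfMemoryError`, a `RuntimeError` subclass). `train_with_backoff` is a hypothetical helper, not part of `benchmarking.py`.

```python
def _release_cuda_cache() -> None:
    """Free cached CUDA blocks if PyTorch and a GPU are available."""
    try:
        import torch
    except ImportError:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def train_with_backoff(train_fn, batch_size: int, min_batch_size: int = 8):
    """Call train_fn(batch_size), halving batch_size after each out-of-memory error.

    Returns (result, batch_size_used).
    """
    while True:
        try:
            return train_fn(batch_size), batch_size
        except RuntimeError as err:
            if 'out of memory' not in str(err).lower():
                raise  # unrelated error: do not hide it
            if batch_size // 2 < min_batch_size:
                raise RuntimeError(f"Out of memory even at batch_size={batch_size}") from err
        # Outside the except block, so the traceback no longer pins GPU tensors
        _release_cuda_cache()
        batch_size //= 2
        print(f"Out of memory; retrying with batch_size={batch_size}")
```

Clearing the cache after leaving the `except` block matters because the exception's traceback can keep references to the tensors allocated by the failed attempt.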

### Output Files

All results are saved to the `outputs/benchmark_colab/` folder:
- **visualizations/**: PNG files of all plots
- **Statistics/**: JSON files with raw data

Download the `benchmark_results.zip` file to get everything at once.
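
After downloading, the JSON files can be reloaded for further analysis outside this notebook. Assuming `pairwise_comparisons.json` keeps the structure printed in Section 7 (comparison label mapped to a dict of `p_value`, `rank_biserial`, `better_model`, and so on), `pandas.DataFrame.from_dict(..., orient='index')` turns it into one row per comparison. `load_pairwise` is a hypothetical helper, and the path in the example is a placeholder for wherever the file ends up after unzipping.

```python
import json
from pathlib import Path

import pandas as pd


def load_pairwise(path) -> pd.DataFrame:
    """Load pairwise_comparisons.json into a DataFrame, one row per comparison."""
    with Path(path).open() as f:
        data = json.load(f)
    df = pd.DataFrame.from_dict(data, orient='index')
    df.index.name = 'comparison'
    return df.sort_values('p_value')  # most significant comparisons first


# Example (placeholder path):
# df = load_pairwise('benchmark_results/pairwise_comparisons.json')
# print(df[['p_value', 'rank_biserial', 'better_model']])
```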

### Interpreting Results

See the `BENCHMARKING_WORKFLOW.md` and `SENSITIVITY_ANALYSIS_GUIDE.md` files for detailed interpretation guides.

Key questions to ask:
1. **Do GNNs outperform baselines?** (Check baseline_comparison.png and pairwise_comparisons.png)
2. **Are results stable across seeds?** (Check seed_variance.png - tight box = stable)
3. **Which model is best?** (Check statistical_summary_table.png - lowest mean test loss)
4. **Are results significant?** (Check pairwise_comparisons.png - p < 0.05?)
5. **Which motif types are hardest?** (Check motif_heatmap.png - darker = harder)
6. **Is the model robust to noise?** (Check sensitivity_analysis.png - flat = robust)
