# BrainGraphNet - Model Analysis
## Analyze trained model predictions and performance

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch

from src.models.evolve_gcn import EvolveGCN
from src.models.metrics import compute_metrics, connectivity_metrics
from src.utils.config_parser import load_config
from src.utils.helpers import load_checkpoint
from src.visualization.connectivity_plots import plot_prediction_comparison
from src.visualization.training_curves import plot_prediction_scatter

# Notebook-wide plotting setup: apply seaborn's "whitegrid" style to all
# figures, and (IPython magic, not plain Python) render matplotlib figures
# inline below each cell.
sns.set_style('whitegrid')
%matplotlib inline

## 1. Load Trained Model

In [None]:
# Build the model from the project config and restore the best checkpoint.
# Paths are relative to this notebook's directory (presumably one level below
# the repo root, matching the sys.path.append('..') in the import cell).
config = load_config('../config.yaml')
device = torch.device('cpu')

model = EvolveGCN(config).to(device)
checkpoint_path = '../outputs/checkpoints/best_model.pth'

try:
    load_checkpoint(model, checkpoint_path, device)
except FileNotFoundError:
    # Continue with the freshly initialised (untrained) weights so the
    # architecture / parameter-count cell below still runs.
    print("⚠️ No checkpoint found. Please train the model first.")
    print("Run: python train.py --config config.yaml")
else:
    print("✅ Model loaded successfully!")

# This notebook is for analysis only: put any dropout / batch-norm layers into
# inference mode so a forward pass run from here is deterministic.
model.eval()

## 2. Model Architecture

In [None]:
# Summarise the network: its module tree, then parameter counts.
divider = "=" * 60
print("Model Architecture:")
print(divider)
print(model)
print(divider)

# numel() is the number of scalar entries in each parameter tensor;
# a parameter counts as trainable when requires_grad is set.
param_list = list(model.parameters())
total_params = sum(p.numel() for p in param_list)
trainable_params = sum(p.numel() for p in param_list if p.requires_grad)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 3. Load Test Predictions

In [None]:
# Load saved test-set predictions (produced by `test.py --save-predictions`).
# `predictions` and `targets` are bound only on success; later cells check
# `'predictions' in locals()` before using them.
predictions_path = '../outputs/predictions/test_predictions.npz'
try:
    # np.load on an .npz returns an NpzFile holding an open file handle; the
    # context manager closes it. Both arrays are read before either name is
    # bound, so a file missing a key (KeyError) never leaves `predictions`
    # defined without `targets`.
    with np.load(predictions_path) as data:
        predictions, targets = data['predictions'], data['targets']
except FileNotFoundError:
    print("⚠️ No predictions found. Run: python test.py --checkpoint outputs/checkpoints/best_model.pth --save-predictions")
else:
    print(f"Predictions shape: {predictions.shape}")
    print(f"Targets shape: {targets.shape}")

## 4. Compute Metrics

In [None]:
# Aggregate evaluation metrics over all loaded test predictions; skipped when
# the predictions cell found no file.
if 'predictions' in locals():
    metrics = compute_metrics(predictions, targets)

    separator = "=" * 40
    print("Evaluation Metrics:")
    print(separator)
    for name, score in metrics.items():
        # Metric names are shown upper-cased; values are assumed numeric
        # (formatted with 4 decimals).
        print(f"{name.upper()}: {score:.4f}")
    print(separator)

## 5. Visualize Predictions

In [None]:
# Compare ground truth with the model's prediction for a single test sample.
if 'predictions' in locals():
    sample_idx = 0  # change this to inspect a different sample
    true_sample = targets[sample_idx]
    pred_sample = predictions[sample_idx]
    plot_prediction_comparison(true_sample, pred_sample)

## 6. Prediction Scatter Plot

In [None]:
# Predicted-vs-true scatter across all samples (rendering is delegated to
# plot_prediction_scatter); runs only if predictions were loaded.
has_predictions = 'predictions' in locals()
if has_predictions:
    plot_prediction_scatter(targets, predictions)

## 7. Error Analysis

In [None]:
# Absolute-error analysis: a histogram of all errors plus an error map for
# the first test sample.
if 'predictions' in locals():
    errors = np.abs(predictions - targets)

    fig, (hist_ax, heat_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: distribution of every absolute error across all samples.
    hist_ax.hist(errors.ravel(), bins=50, edgecolor='black', alpha=0.7)
    hist_ax.set_xlabel('Absolute Error')
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title('Distribution of Prediction Errors')
    hist_ax.grid(alpha=0.3)

    # Right panel: error map for sample index 0 (labelled "Sample 1").
    # imshow needs errors[0] to be 2-D, presumably a square connectivity
    # matrix (TODO confirm).
    heat_img = heat_ax.imshow(errors[0], cmap='hot')
    fig.colorbar(heat_img, ax=heat_ax, label='Absolute Error')
    heat_ax.set_title('Error Heatmap (Sample 1)')

    fig.tight_layout()
    plt.show()

    print(f"Mean error: {errors.mean():.4f}")
    print(f"Max error: {errors.max():.4f}")

## 8. Connectivity-Specific Metrics

In [None]:
# Connectivity-specific metrics for ONE test sample. This is unlike
# compute_metrics in section 4, which receives all samples, so the printout
# names the sample explicitly (1-based, matching the error heatmap title).
if 'predictions' in locals():
    conn_sample_idx = 0  # change to evaluate a different sample
    conn_metrics = connectivity_metrics(
        predictions[conn_sample_idx], targets[conn_sample_idx]
    )

    print(f"Connectivity Metrics (Sample {conn_sample_idx + 1}):")
    print("="*40)
    for metric, value in conn_metrics.items():
        print(f"{metric}: {value:.4f}")
    print("="*40)

## Summary

This notebook demonstrated:
- Loading the trained model checkpoint
- Inspecting the model architecture and parameter counts
- Evaluating test-set predictions
- Visualizing predictions and prediction errors
- Computing connectivity-specific metrics

For further analysis, see the main project README.