# GNN Hyperparameter Optimization with WandB Sweeps

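This notebook demonstrates how to use Weights & Biases (WandB) sweeps to grid-search GNN hyperparameters, analyze and visualize the sweep results, and retrain a model with the best configuration found.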

In [None]:
# Import necessary libraries
import torch
import numpy as np
import pandas as pd
import ast
import wandb
from types import SimpleNamespace
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Import custom modules
from GraphBuilder_with_features import create_graph_dataset
from sweep_utils import (
    run_sweep, 
    quick_sweep,
    analyze_sweep_results,
    create_example_config_file
)

In [None]:
# Seed the PyTorch RNG and NumPy's global RNG so repeated runs are reproducible.
# (np.random.seed stays last: it returns None, so the cell prints nothing.)
torch.manual_seed(seed=42)
np.random.seed(seed=42)

## 1. Load and Prepare Data

In [None]:
# Load data function
def load_graph_data(loop):
    """Load graph data from CSV files."""
    edges = []
    y = []
    
    filename = f'../Graph_Edge_Data/den_graph_data_{loop}.csv'
    df = pd.read_csv(filename)
    edges += df['EDGES'].tolist()
    y += df['COEFFICIENTS'].tolist()
    
    edges = [ast.literal_eval(e) for e in edges]
    graphs_data = list(zip(edges, y))
    return graphs_data

In [None]:
# Read the graphs for loop=8 and report how many were found.
graphs_data = load_graph_data(loop=8)
print("Loaded", len(graphs_data), "graphs")

In [None]:
# Build the graph dataset from the loaded data, using the selected feature groups.
feature_config = dict(
    selected_features=['basic', 'face', 'spectral_node', 'centrality'],
    laplacian_pe_k=3,
)

dataset, scaler = create_graph_dataset(graphs_data, feature_config)
print("Dataset created with", len(dataset), "graphs")
print("Feature dimensions:", dataset[0].x.shape[1])
print("Feature names:", dataset[0].feature_names)

## 2. Define Hyperparameter Search Space

In [None]:
# Candidate values for every swept hyperparameter (grid search).
param_ranges = dict(
    hidden_channels=[32, 64, 128],
    num_layers=[2, 3, 4],
    dropout=[0.1, 0.2, 0.3],
    lr=[0.001, 0.003, 0.01],
    weight_decay=[0, 1e-4, 5e-4],
)

# A full grid visits every combination, so the run count is the product of
# the number of options for each hyperparameter.
total_runs = 1
for param in param_ranges:
    values = param_ranges[param]
    print(f"{param}: {len(values)} values - {values}")
    total_runs = total_runs * len(values)

print(f"\nTotal combinations: {total_runs}")

In [None]:
# Settings held constant across every sweep run (they are not part of the grid).
fixed_config = dict(
    model_name='gin',
    epochs=100,
    batch_size=32,
    scheduler_type='onecycle',
    save_models=False,
)

## 3. Run Hyperparameter Sweep

### Option A: Quick Test (Fewer Combinations)

In [None]:
# Reduced grid (two options per hyperparameter) for a fast smoke-test sweep.
quick_param_ranges = dict(
    hidden_channels=[32, 64],
    num_layers=[2, 3],
    dropout=[0.1, 0.2],
    lr=[0.001, 0.01],
    weight_decay=[0, 1e-4],
)

# Number of runs = product of the option counts.
quick_runs = 1
for values in quick_param_ranges.values():
    quick_runs = quick_runs * len(values)
print("Quick test combinations:", quick_runs)

In [None]:
# WandB project and sweep names used by the quick-test sweep.
project_name, sweep_name = "gnn-planar-graphs-sweep", "quick_test"

# Uncomment to run:
# sweep_id = run_sweep(
#     param_ranges=quick_param_ranges,
#     dataset=dataset,
#     project_name=project_name,
#     fixed_config=fixed_config,
#     sweep_name=sweep_name
# )

### Option B: Full Grid Search

In [None]:
# Full sweep - WARNING: This will run many experiments!
# sweep_id = run_sweep(
#     param_ranges=param_ranges,
#     dataset=dataset,
#     project_name=project_name,
#     fixed_config=fixed_config,
#     sweep_name="full_grid_search"
# )

### Option C: Using the Quick Sweep Function

In [None]:
# Even quicker sweep with default parameters
# sweep_id = quick_sweep(
#     dataset=dataset,
#     project_name=project_name,
#     hidden_channels=[32, 64],
#     num_layers=[2, 3],
#     dropout=[0.15, 0.25],
#     lr=[0.001, 0.005],
#     weight_decay=[0, 1e-4],
#     epochs=50  # Fewer epochs for testing
# )

## 4. Analyze Sweep Results

In [None]:
# Replace with your actual sweep ID
# sweep_id = "your-sweep-id-here"
# results = analyze_sweep_results(project_name, sweep_id)

In [None]:
# Display best configuration
# if results['best_config']:
#     print("Best Configuration Found:")
#     print(f"Validation Accuracy: {results['best_config']['best_val_accuracy']:.4f}")
#     print("\nHyperparameters:")
#     for param, value in results['best_config']['config'].items():
#         if param in param_ranges:
#             print(f"  {param}: {value}")

In [None]:
# Show top N configurations
# N = 10
# print(f"\nTop {N} Configurations:")
# for i, config in enumerate(results['all_results'][:N]):
#     print(f"\n{i+1}. Validation Accuracy: {config['best_val_accuracy']:.4f}")
#     print("   Config:", {k: v for k, v in config['config'].items() if k in param_ranges})

## 5. Visualize Results

In [None]:
# Function to visualize hyperparameter importance
def plot_hyperparameter_importance(results, param_name):
    """Plot validation accuracy distribution for different values of a hyperparameter."""
    if not results['all_results']:
        print("No results to plot")
        return
    
    # Extract data
    param_values = []
    accuracies = []
    
    for run in results['all_results']:
        if param_name in run['config']:
            param_values.append(run['config'][param_name])
            accuracies.append(run['best_val_accuracy'])
    
    # Create DataFrame
    df = pd.DataFrame({
        param_name: param_values,
        'validation_accuracy': accuracies
    })
    
    # Plot
    plt.figure(figsize=(10, 6))
    sns.boxplot(x=param_name, y='validation_accuracy', data=df)
    plt.title(f'Validation Accuracy vs {param_name}')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

In [None]:
# Plot for each hyperparameter
# for param in param_ranges.keys():
#     plot_hyperparameter_importance(results, param)

In [None]:
# Create a heatmap for two hyperparameters
def plot_2d_heatmap(results, param1, param2):
    """Create a heatmap showing validation accuracy for two hyperparameters."""
    if not results['all_results']:
        print("No results to plot")
        return
    
    # Extract data
    data = {}
    for run in results['all_results']:
        if param1 in run['config'] and param2 in run['config']:
            key = (run['config'][param1], run['config'][param2])
            if key not in data or run['best_val_accuracy'] > data[key]:
                data[key] = run['best_val_accuracy']
    
    # Create matrix
    param1_values = sorted(set(k[0] for k in data.keys()))
    param2_values = sorted(set(k[1] for k in data.keys()))
    
    matrix = np.zeros((len(param2_values), len(param1_values)))
    for i, p2 in enumerate(param2_values):
        for j, p1 in enumerate(param1_values):
            matrix[i, j] = data.get((p1, p2), 0)
    
    # Plot
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, 
                xticklabels=param1_values, 
                yticklabels=param2_values,
                annot=True, 
                fmt='.3f', 
                cmap='viridis')
    plt.xlabel(param1)
    plt.ylabel(param2)
    plt.title(f'Validation Accuracy: {param1} vs {param2}')
    plt.tight_layout()
    plt.show()

In [None]:
# Plot heatmaps for interesting parameter pairs
# plot_2d_heatmap(results, 'hidden_channels', 'num_layers')
# plot_2d_heatmap(results, 'lr', 'weight_decay')
# plot_2d_heatmap(results, 'hidden_channels', 'dropout')

## 6. Save Results

In [None]:
# Save results to file
# import json
# with open(f'sweep_results_{sweep_id}.json', 'w') as f:
#     json.dump(results, f, indent=2)
# print(f"Results saved to sweep_results_{sweep_id}.json")

In [None]:
# Create a summary DataFrame
# if results['all_results']:
#     summary_data = []
#     for run in results['all_results']:
#         row = {
#             'val_accuracy': run['best_val_accuracy'],
#             'train_accuracy': run['final_train_accuracy'],
#             'best_epoch': run['best_epoch']
#         }
#         # Add hyperparameters
#         for param in param_ranges.keys():
#             if param in run['config']:
#                 row[param] = run['config'][param]
#         summary_data.append(row)
#     
#     summary_df = pd.DataFrame(summary_data)
#     summary_df.to_csv(f'sweep_summary_{sweep_id}.csv', index=False)
#     print("Summary saved to CSV")
#     print(summary_df.head())

## 7. Train Final Model with Best Hyperparameters

In [None]:
# Extract best hyperparameters
# if results['best_config']:
#     best_params = results['best_config']['config']
#     
#     # Create configuration for final training
#     final_config = SimpleNamespace(
#         model_name='gin',
#         hidden_channels=best_params['hidden_channels'],
#         num_layers=best_params['num_layers'],
#         dropout=best_params['dropout'],
#         lr=best_params['lr'],
#         weight_decay=best_params['weight_decay'],
#         epochs=150,  # Train longer for final model
#         batch_size=32,
#         scheduler_type='onecycle',
#         use_wandb=True,
#         project='gnn-planar-graphs-final',
#         experiment_name='best_model_from_sweep',
#         in_channels=dataset[0].x.shape[1]
#     )
#     
#     # Train final model
#     from training_utils import train
#     final_results = train(final_config, dataset)
#     
#     print(f"Final model validation accuracy: {final_results['best_val_acc']:.4f}")
#     
#     # Save the final model
#     torch.save(final_results['model_state'], 'best_model_from_sweep.pt')