# GCN Hyperparameter Sweep with Optuna

This notebook performs systematic hyperparameter optimization for the GCN model using Optuna on Google Colab.

## Setup Instructions

1. **Mount Google Drive** (to access your project files)
2. **Install Dependencies** (Optuna and PyTorch Geometric)
3. **Run the Sweep** (50 trials by default, ~3-5 hours)
4. **View Results** (visualizations and metrics)

## Step 1: Mount Google Drive and Set Up Paths

In [10]:
import os
from pathlib import Path

from google.colab import drive

# Colab exposes Google Drive as a regular filesystem once it is mounted here.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Project folder on Drive -- edit this if your copy lives somewhere else.
project_dir = '/content/drive/My Drive/182-GNN_SAE'
# Later cells use relative paths ('virtual_graphs/data', 'outputs/...'), and
# presumably rely on this directory to find gnn_train.py, so chdir into it.
os.chdir(project_dir)

current_dir = os.getcwd()
print(f"Working directory: {current_dir}")
print("\nDirectory contents:")
print(os.listdir('.'))

ValueError: mount failed

## Step 2: Install Dependencies

In [None]:
# Install required packages
!pip install optuna -q
!pip install torch-geometric -q

print("✓ Dependencies installed successfully")

## Step 3: Import Libraries and Set Up

In [None]:
import json
import os
import random
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import optuna
from optuna.trial import Trial
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

# Report which accelerator (if any) this runtime provides for the sweep.
gpu_available = torch.cuda.is_available()
print(f"GPU Available: {gpu_available}")
if gpu_available:
    first_gpu = 0
    print(f"GPU Name: {torch.cuda.get_device_name(first_gpu)}")
    gpu_props = torch.cuda.get_device_properties(first_gpu)
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

## Step 4: Import Functions from gnn_train.py

In [None]:
# Import the dataset, model, trainer and data helpers from the project's
# gnn_train.py. NOTE(review): this presumably relies on the working directory
# set in Step 1 (the project folder) containing gnn_train.py -- confirm.
from gnn_train import (
    GraphDataset,
    GCNModel,
    GNNTrainer,
    collate_fn,
    load_all_graphs,
    split_data
)

print("✓ Successfully imported all functions from gnn_train.py")

## Step 5: Define Objective Function

In [None]:
def objective(trial: Trial, train_loader: DataLoader, val_loader: DataLoader,
              test_loader: DataLoader, device: str, num_epochs: int = 100) -> float:
    """
    Optuna objective: train one GCN configuration and return its best validation loss.

    Samples this trial's hyperparameters, builds a fresh ``GCNModel`` and
    ``GNNTrainer``, rebuilds the train/val DataLoaders with the sampled batch
    size (reusing the datasets of the loaders passed in), and trains with early
    stopping on validation loss. Each epoch's validation loss is reported to
    Optuna so the study's pruner can stop unpromising trials.

    Args:
        trial: Optuna trial object used to sample hyperparameters.
        train_loader: Training DataLoader; only its ``dataset`` is reused.
        val_loader: Validation DataLoader; only its ``dataset`` is reused.
        test_loader: Test DataLoader. Currently unused; kept so existing
            callers keep working.
        device: Device to train on (e.g. ``'cuda'`` or ``'cpu'``).
        num_epochs: Maximum number of epochs.

    Returns:
        Lowest validation loss observed. ``float('inf')`` if no epoch ran or
        no epoch's loss compared lower (e.g. every loss was NaN).

    Raises:
        optuna.TrialPruned: If the pruner decides to stop this trial early.
    """
    hidden_dim = trial.suggest_int('hidden_dim', 16, 256, step=8)
    dropout = trial.suggest_float('dropout', 0.0, 0.5, step=0.05)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
    # Log-scaled distributions require low > 0: a lower bound of 0.0 makes
    # Optuna raise ValueError, which fails the trial and aborts study.optimize.
    # 1e-6 serves as a near-zero floor.
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
    # NOTE(review): mask_prob is sampled (and so recorded in the study) but never
    # applied -- the datasets behind the loaders were built with a fixed
    # mask_prob=0.3 before the sweep. Either rebuild the datasets per trial with
    # this value or drop it from the search space; until then, any apparent
    # effect of mask_prob in the plots is noise.
    trial.suggest_float('mask_prob', 0.1, 0.5, step=0.05)
    early_stopping_patience = trial.suggest_int('early_stopping_patience', 5, 30, step=5)

    # Fresh model per trial. input_dim=2 / output_dim=1 presumably match the node
    # features and targets produced by GraphDataset -- confirm in gnn_train.py.
    model = GCNModel(input_dim=2, hidden_dim=hidden_dim, output_dim=1, dropout=dropout)
    model = model.to(device)

    trainer = GNNTrainer(model, device=device, learning_rate=learning_rate,
                         weight_decay=weight_decay)

    # Recreate the loaders so the sampled batch size takes effect.
    train_loader_new = DataLoader(
        train_loader.dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    val_loader_new = DataLoader(
        val_loader.dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        trainer.train_epoch(train_loader_new)
        val_loss = trainer.validate(val_loader_new)

        # Early stopping on validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                break

        # Report the intermediate value so the pruner can compare trials per epoch.
        trial.report(val_loss, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return best_val_loss

print("✓ Objective function defined")

## Step 6: Define Visualization Functions

In [None]:
def plot_optimization_history(trials_df: pd.DataFrame, output_dir: Path):
    """Plot optimization history (loss over trials).

    Draws each trial's validation loss and the running best (cumulative
    minimum), then saves the figure as ``optimization_history.png``.

    Args:
        trials_df: Trials table such as ``study.trials_dataframe()``; needs a
            ``value`` column. When a ``number`` column is present it is used
            for the x-axis, so filtered frames keep their real trial numbers.
        output_dir: Directory the PNG is written to.

    The caller's DataFrame is not modified.
    """
    losses = trials_df['value']
    # Running best so far; NaN entries (e.g. failed trials) stay NaN and do
    # not reset the running minimum.
    best_losses = losses.cummin()
    if 'number' in trials_df.columns:
        trial_numbers = trials_df['number']
    else:
        trial_numbers = range(len(trials_df))

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(trial_numbers, losses, 'o-', alpha=0.6, label='Trial Loss')
    ax.plot(trial_numbers, best_losses, 'r-', linewidth=2, label='Best Loss')

    ax.set_xlabel('Trial Number', fontsize=12)
    ax.set_ylabel('Validation Loss', fontsize=12)
    ax.set_title('Optimization History', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    output_path = output_dir / 'optimization_history.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {output_path.name}")
    plt.show()
    plt.close(fig)


def plot_parameter_distributions(trials_df: pd.DataFrame, output_dir: Path):
    """Plot validation loss against each hyperparameter, three panels per row.

    * Parameters with at most 10 distinct values are treated as categorical and
      drawn as bars of the mean validation loss per value (error bars show the
      standard deviation).
    * Other parameters are drawn as a scatter of loss vs. value, colored by loss.

    Saves ``parameter_distributions.png`` to ``output_dir``. Returns without
    plotting when ``trials_df`` has no ``params_*`` columns.
    """
    param_cols = [col for col in trials_df.columns if col.startswith('params_')]
    if not param_cols:
        # plt.subplots(0, n_cols) would raise ValueError; nothing to plot anyway.
        return
    param_names = [col.replace('params_', '') for col in param_cols]

    n_params = len(param_cols)
    n_cols = 3
    n_rows = (n_params + n_cols - 1) // n_cols  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
    # subplots returns a 1-D array for a single row and 2-D otherwise; flatten both.
    axes = axes.flatten()

    for idx, (param_col, param_name) in enumerate(zip(param_cols, param_names)):
        ax = axes[idx]

        unique_vals = trials_df[param_col].nunique()
        is_categorical = unique_vals <= 10

        if is_categorical:
            data = trials_df.groupby(param_col)['value'].agg(['mean', 'std', 'count'])
            x_labels = [str(x) for x in data.index]
            y_vals = data['mean'].values
            y_errs = data['std'].values

            ax.bar(range(len(x_labels)), y_vals, yerr=y_errs, capsize=5, alpha=0.7)
            ax.set_xticks(range(len(x_labels)))
            ax.set_xticklabels(x_labels, rotation=45)
            ax.set_ylabel('Mean Validation Loss', fontsize=10)
            ax.set_title(f'{param_name}', fontsize=11, fontweight='bold')

        else:
            scatter = ax.scatter(
                trials_df[param_col],
                trials_df['value'],
                c=trials_df['value'],
                cmap='viridis',
                alpha=0.6,
                s=50
            )
            ax.set_xlabel(param_name, fontsize=10)
            ax.set_ylabel('Validation Loss', fontsize=10)
            ax.set_title(f'{param_name}', fontsize=11, fontweight='bold')
            plt.colorbar(scatter, ax=ax, label='Loss')

        ax.grid(True, alpha=0.3)

    # Remove unused panels in the last row.
    for idx in range(n_params, len(axes)):
        fig.delaxes(axes[idx])

    plt.tight_layout()
    output_path = output_dir / 'parameter_distributions.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {output_path.name}")
    plt.show()
    plt.close(fig)


def plot_parameter_correlations(trials_df: pd.DataFrame, output_dir: Path):
    """Plot correlation between parameters and validation loss.

    For each ``params_*`` column with more than 10 distinct values, computes
    its (Pearson) correlation with ``value`` and draws a horizontal bar chart,
    with negative correlations in red and positive ones in green. Saves
    ``parameter_correlations.png`` to ``output_dir``. Returns without plotting
    when no parameter has enough distinct values.
    """
    records = []
    for column in trials_df.columns:
        # Skip non-parameter columns and low-cardinality (categorical-like) params.
        if not column.startswith('params_') or trials_df[column].nunique() <= 10:
            continue
        records.append({
            'Parameter': column.replace('params_', ''),
            'Correlation': trials_df[column].corr(trials_df['value']),
        })

    if len(records) == 0:
        return

    corr_df = pd.DataFrame(records).sort_values('Correlation')
    bar_colors = corr_df['Correlation'].map(lambda c: 'red' if c < 0 else 'green').tolist()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(corr_df['Parameter'], corr_df['Correlation'], color=bar_colors, alpha=0.7)
    ax.set_xlabel('Correlation with Validation Loss', fontsize=12)
    ax.set_title('Parameter Importance (Correlation)', fontsize=14, fontweight='bold')
    ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
    ax.grid(True, alpha=0.3, axis='x')

    plt.tight_layout()
    output_path = output_dir / 'parameter_correlations.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {output_path.name}")
    plt.show()
    plt.close()


def plot_hyperparameter_heatmap(trials_df: pd.DataFrame, output_dir: Path, top_n: int = 10):
    """Plot heatmap of hyperparameters for top trials.

    Shows the ``params_*`` values of the (up to) ``top_n`` trials with the
    lowest validation loss. Each parameter is min-max normalised over all rows
    of ``trials_df`` so rows share a 0-1 scale; parameters with a single value
    are left as-is. Columns are ranks (1 = best). Saves
    ``top_trials_heatmap.png``.

    Returns without plotting if there are no ``params_*`` columns or no trial
    has a value.

    Args:
        trials_df: Trials table with a ``value`` column and ``params_*`` columns.
        output_dir: Directory the PNG is written to.
        top_n: Maximum number of best trials to show (default 10).
    """
    param_cols = [col for col in trials_df.columns if col.startswith('params_')]

    if not param_cols:
        return

    param_names = [col.replace('params_', '') for col in param_cols]

    # Drop valueless rows first: DataFrame.nsmallest can return NaN-valued rows
    # when fewer than top_n rows have a value.
    ranked = trials_df.dropna(subset=['value'])
    top_trials = ranked.nsmallest(top_n, 'value')[param_cols].copy()
    if top_trials.empty:
        return
    top_trials.columns = param_names
    # Index by rank so the heatmap's x tick labels match the axis label.
    top_trials.index = range(1, len(top_trials) + 1)

    for param_col, param_name in zip(param_cols, param_names):
        min_val = trials_df[param_col].min()
        max_val = trials_df[param_col].max()
        if max_val > min_val:
            top_trials[param_name] = (top_trials[param_name] - min_val) / (max_val - min_val)

    n_shown = len(top_trials)
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.heatmap(
        top_trials.T,
        annot=False,
        cmap='RdYlGn',
        cbar_kws={'label': 'Normalized Value'},
        ax=ax
    )
    ax.set_title(f'Top {n_shown} Trials Hyperparameters (Normalized)', fontsize=14, fontweight='bold')
    ax.set_xlabel('Trial Rank (Best → Worst)', fontsize=12)

    plt.tight_layout()
    output_path = output_dir / 'top_trials_heatmap.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {output_path.name}")
    plt.show()
    plt.close(fig)


def plot_parameter_interactions(trials_df: pd.DataFrame, output_dir: Path):
    """Plot interactions between key hyperparameters.

    Draws one scatter panel per hyperparameter pair (two panels per row), with
    points colored by validation loss. Pairs whose ``params_*`` columns are not
    both present in ``trials_df`` are skipped; if none remain, nothing is
    drawn. Saves ``parameter_interactions.png`` to ``output_dir``.
    """
    candidate_pairs = (
        ('params_hidden_dim', 'params_dropout'),
        ('params_learning_rate', 'params_weight_decay'),
        ('params_hidden_dim', 'params_learning_rate'),
        ('params_dropout', 'params_mask_prob'),
    )
    available = set(trials_df.columns)
    valid_pairs = [
        (x_col, y_col) for x_col, y_col in candidate_pairs
        if x_col in available and y_col in available
    ]

    if len(valid_pairs) == 0:
        return

    n_cols = 2
    n_rows = -(-len(valid_pairs) // n_cols)  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 5 * n_rows))
    panels = axes.flatten()

    for ax, (x_col, y_col) in zip(panels, valid_pairs):
        x_name = x_col.replace('params_', '')
        y_name = y_col.replace('params_', '')

        points = ax.scatter(
            trials_df[x_col], trials_df[y_col],
            c=trials_df['value'], cmap='viridis',
            s=100, alpha=0.6, edgecolors='black', linewidth=0.5,
        )

        ax.set_xlabel(x_name, fontsize=11)
        ax.set_ylabel(y_name, fontsize=11)
        ax.set_title(f'{x_name} vs {y_name}', fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.3)

        colorbar = plt.colorbar(points, ax=ax)
        colorbar.set_label('Val Loss', fontsize=10)

    # Remove the trailing empty panel(s) when the grid has spare slots.
    for spare_ax in panels[len(valid_pairs):]:
        fig.delaxes(spare_ax)

    plt.tight_layout()
    output_path = output_dir / 'parameter_interactions.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {output_path.name}")
    plt.show()
    plt.close()


def plot_loss_distribution(trials_df: pd.DataFrame, output_dir: Path):
    """Plot distribution of validation losses.

    Left panel: histogram with the best and mean loss marked. Right panel: box
    plot annotated with min/max/mean/std. Saves ``loss_distribution.png``.

    Rows without a value (NaN, e.g. failed trials) are ignored. ``Axes.boxplot``
    does not drop NaN, so a single missing value would otherwise leave the box
    empty. Returns without plotting if no row has a value.
    """
    losses = trials_df['value'].dropna()
    if losses.empty:
        return

    best_loss = losses.min()
    mean_loss = losses.mean()

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    axes[0].hist(losses, bins=20, alpha=0.7, edgecolor='black')
    axes[0].axvline(best_loss, color='red', linestyle='--', linewidth=2,
                    label=f'Best: {best_loss:.4f}')
    axes[0].axvline(mean_loss, color='green', linestyle='--', linewidth=2,
                    label=f'Mean: {mean_loss:.4f}')
    axes[0].set_xlabel('Validation Loss', fontsize=12)
    axes[0].set_ylabel('Frequency', fontsize=12)
    axes[0].set_title('Loss Distribution', fontsize=13, fontweight='bold')
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3, axis='y')

    # Vertical is the default orientation; the `vert` argument is omitted
    # because newer matplotlib versions are phasing it out.
    bp = axes[1].boxplot(losses.to_numpy(), patch_artist=True)
    bp['boxes'][0].set_facecolor('lightblue')
    axes[1].set_ylabel('Validation Loss', fontsize=12)
    axes[1].set_title('Loss Box Plot', fontsize=13, fontweight='bold')
    axes[1].grid(True, alpha=0.3, axis='y')
    axes[1].set_xticklabels(['All Trials'])

    stats_text = (f"Min: {best_loss:.6f}\n"
                  f"Max: {losses.max():.6f}\n"
                  f"Mean: {mean_loss:.6f}\n"
                  f"Std: {losses.std():.6f}")
    axes[1].text(1.3, mean_loss, stats_text, fontsize=10,
                 verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    output_path = output_dir / 'loss_distribution.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {output_path.name}")
    plt.show()
    plt.close(fig)


def generate_visualizations(trials_df: pd.DataFrame, output_dir: Path):
    """Generate all visualizations from trials.

    Only finished trials are plotted. Rows whose ``state`` is not
    ``'COMPLETE'`` (when a ``state`` column exists) and rows without a
    ``value`` are dropped first. Failed trials have no value. A pruned trial's
    value, if any (this depends on the Optuna version), is an intermediate
    loss, not a final best validation loss. Mixing either in would distort the
    plots, and a NaN value blanks the box plot.

    Args:
        trials_df: Frame such as ``study.trials_dataframe()``. Not modified.
        output_dir: Directory the PNG files are written to.
    """
    print("\nGenerating visualizations...")
    finished = trials_df
    if 'state' in finished.columns:
        finished = finished[finished['state'] == 'COMPLETE']
    # Fresh, positionally indexed frame: plotting helpers may add columns to it
    # without touching the caller's DataFrame.
    finished = finished.dropna(subset=['value']).reset_index(drop=True)

    if finished.empty:
        print("No completed trials with a value; skipping visualizations.")
        return

    plot_optimization_history(finished, output_dir)
    plot_parameter_distributions(finished, output_dir)
    plot_parameter_correlations(finished, output_dir)
    plot_hyperparameter_heatmap(finished, output_dir)
    plot_parameter_interactions(finished, output_dir)
    plot_loss_distribution(finished, output_dir)
    print("Visualization generation complete!")

print("✓ All visualization functions defined")

## Step 7: Set Configuration Parameters

In [None]:
# Sweep configuration -- every knob a reader is expected to tune lives here.
CONFIG = dict(
    data_dir='virtual_graphs/data',
    device='cuda' if torch.cuda.is_available() else 'cpu',
    num_trials=50,     # Increase for more thorough search
    num_epochs=100,    # Max epochs per trial
    output_dir='outputs/hyperparameter_sweep',
    seed=42,
)

print("Configuration:")
for setting, setting_value in CONFIG.items():
    print(f"  {setting}: {setting_value}")

## Step 8: Load Data and Create Datasets

In [None]:
# Seed every global RNG the data pipeline might draw from, so the split and the
# loader shuffling are reproducible. Python's `random` is seeded too, in case
# gnn_train's helpers (not visible here) use it.
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

# Create output directory
Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)

print("Loading graphs...")
graph_paths = load_all_graphs(CONFIG['data_dir'], single_motif_only=True)
print(f"Loaded {len(graph_paths)} graphs")

print("\nSplitting data...")
# 80% train / 10% val; the test share is presumably the remaining 10%.
train_paths, val_paths, test_paths = split_data(
    graph_paths,
    train_ratio=0.8,
    val_ratio=0.1,
    stratify_by_motif=True
)
print(f"Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")

print("\nCreating datasets...")
# NOTE(review): every trial trains on datasets built with this fixed value; the
# 'mask_prob' sampled by the objective function is not applied to them.
DATASET_MASK_PROB = 0.3
train_dataset = GraphDataset(train_paths, mask_prob=DATASET_MASK_PROB, seed=CONFIG['seed'])
val_dataset = GraphDataset(val_paths, mask_prob=DATASET_MASK_PROB, seed=CONFIG['seed'])
test_dataset = GraphDataset(test_paths, mask_prob=DATASET_MASK_PROB, seed=CONFIG['seed'])

print("\nCreating dataloaders...")
# Placeholder batch size: the objective function rebuilds the train/val loaders
# with each trial's sampled batch size and only reuses their datasets.
BASE_BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BASE_BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BASE_BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BASE_BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

print("✓ Data loading complete")

## Step 9: Run Hyperparameter Sweep

**⏱️ This may take 3-5 hours for 50 trials**

In [None]:
banner = "=" * 80
print(banner)
print("GCN HYPERPARAMETER SWEEP WITH OPTUNA")
print(banner)
print(f"Running {CONFIG['num_trials']} trials with up to {CONFIG['num_epochs']} epochs each\n")

# Seeded TPE sampler for reproducible suggestions. The median pruner stays
# inactive until 10 trials have finished, and has no per-trial warm-up after that.
sampler = TPESampler(seed=CONFIG['seed'])
pruner = MedianPruner(n_startup_trials=10, n_warmup_steps=0)

study = optuna.create_study(
    study_name='gcn_optimization',
    direction='minimize',
    sampler=sampler,
    pruner=pruner,
)


def _run_trial(trial):
    """Bind the notebook's data loaders and CONFIG to objective() for study.optimize."""
    return objective(
        trial, train_loader, val_loader, test_loader, CONFIG['device'], CONFIG['num_epochs']
    )


study.optimize(_run_trial, n_trials=CONFIG['num_trials'], show_progress_bar=True)

## Step 10: Display Results

In [None]:
divider = "=" * 80
print("\n" + divider)
print("OPTIMIZATION COMPLETE")
print(divider)

# Lowest-loss completed trial; later cells reuse this name.
best_trial = study.best_trial

print(f"\nBest Trial: {best_trial.number}")
print(f"Best Validation Loss: {best_trial.value:.6f}")
print("\nBest Hyperparameters:")
for param_name, param_value in best_trial.params.items():
    print(f"  {param_name}: {param_value}")

## Step 11: Save Results

In [None]:
output_path = Path(CONFIG['output_dir'])

print(f"\nSaving results to {CONFIG['output_dir']}...")


def _dump_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON."""
    with open(path, 'w') as fh:
        json.dump(payload, fh, indent=2)


# Best hyperparameters of the sweep.
best_params_path = output_path / "best_params.json"
_dump_json(best_params_path, best_trial.params)
print(f"✓ Saved best params to {best_params_path}")

# Headline numbers for the study as a whole.
study_info_path = output_path / "study_info.json"
_dump_json(study_info_path, {
    'best_value': best_trial.value,
    'best_trial': best_trial.number,
    'n_trials': len(study.trials),
    'n_complete_trials': sum(t.state.name == 'COMPLETE' for t in study.trials),
})
print(f"✓ Saved study info to {study_info_path}")

# Full per-trial table (one row per trial) for later analysis.
trials_df = study.trials_dataframe()
trials_csv_path = output_path / "trials.csv"
trials_df.to_csv(trials_csv_path, index=False)
print(f"✓ Saved {len(trials_df)} trials to {trials_csv_path}")

## Step 12: Generate Visualizations

In [None]:
generate_visualizations(trials_df=trials_df, output_dir=output_path)

## Step 13: Create Summary Report

In [None]:
# Create a plain-text summary of the sweep next to the other artifacts.
summary_path = output_path / "sweep_summary.txt"
n_total = len(study.trials)
n_complete = len([t for t in study.trials if t.state.name == 'COMPLETE'])
n_pruned = len([t for t in study.trials if t.state.name == 'PRUNED'])
heavy_rule = "=" * 80 + "\n"

with open(summary_path, 'w') as f:
    f.write(heavy_rule)
    f.write("GCN HYPERPARAMETER SWEEP SUMMARY\n")
    f.write(heavy_rule + "\n")

    f.write(f"Total Trials: {n_total}\n")
    f.write(f"Complete Trials: {n_complete}\n")
    f.write(f"Pruned Trials: {n_pruned}\n\n")

    f.write("BEST TRIAL RESULTS\n")
    f.write("-" * 80 + "\n")
    f.write(f"Trial Number: {best_trial.number}\n")
    f.write(f"Best Validation Loss: {best_trial.value:.6f}\n\n")

    f.write("Best Hyperparameters:\n")
    for key, value in best_trial.params.items():
        f.write(f"  {key}: {value}\n")

    f.write("\n" + heavy_rule)
    f.write("HYPERPARAMETER SEARCH SPACE\n")
    f.write(heavy_rule + "\n")

    # Report the distributions Optuna actually sampled from, instead of a
    # hand-copied list, so this section cannot drift from the suggest_* calls
    # in the objective. Every parameter there is sampled unconditionally, so
    # the best trial's distributions cover the whole search space.
    for name, distribution in best_trial.distributions.items():
        f.write(f"{name}: {distribution}\n")

print(f"✓ Saved summary to {summary_path}")

print("\n" + "=" * 80)
print("SWEEP COMPLETE!")
print("=" * 80)
print(f"\nAll results saved to: {output_path}")

## Step 14: Display Top 5 Trials

In [None]:
print("\nTOP 5 TRIALS:\n")
top_5 = trials_df.nsmallest(5, 'value')
# Every row shares the frame's columns, so select the hyperparameter columns once.
param_cols = [col for col in top_5.columns if col.startswith('params_')]
for rank, (_, row) in enumerate(top_5.iterrows(), start=1):
    print(f"Rank {rank}: Loss = {row['value']:.6f}")
    for col in param_cols:
        print(f"  {col.replace('params_', '')}: {row[col]}")
    print()

## Step 15: Download Results

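Your results have been saved to the project folder on Google Drive (under `outputs/hyperparameter_sweep` by default). The cell below lists the generated files so you can download them.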

In [None]:
# `output_path` is relative to the working directory set in Step 1 (the project
# folder on the mounted Drive), so resolve it to show where the files really live.
# Previously both lines printed the same relative path.
print("Results location:")
print(f"  Local: {output_path}")
print(f"  Google Drive: {output_path.resolve()}")
print("\nGenerated files:")
for result_file in sorted(output_path.iterdir()):
    print(f"  - {result_file.name}")