# GCN Hyperparameter Sweep with Optuna + Motif Metrics

This notebook runs a systematic hyperparameter search for the GCN model with Optuna on Google Colab. It then retrains the five best configurations, saves the best model and its activations, and computes motif-specific metrics.

## Setup Instructions

1. **Mount Google Drive** (to access your project files)
2. **Install Dependencies** (Optuna and PyTorch Geometric, if not already installed)
3. **Run the Sweep** (50 trials by default, ~3-5 hours)
4. **Train Top 5 Models and Compute Motif Metrics** (~30 minutes)
5. **View Results** (visualizations and metrics)

## Step 1: Mount Google Drive and Set Up Paths

In [None]:
from google.colab import drive
import os
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

# Set working directory to your project
project_dir = '/content/drive/My Drive/182-GNN_SAE'  # Adjust path if needed
os.chdir(project_dir)

print(f"Working directory: {os.getcwd()}")
print(f"\nDirectory contents:")
print(os.listdir('.'))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Working directory: /content/drive/My Drive/182-GNN_SAE

Directory contents:
['gnn_train.py', 'virtual_graphs', 'train_gnn.py', 'hyperparameter_sweep_colab.ipynb', 'hyperparameter_sweep_colab_v2.ipynb', '__pycache__', 'outputs']


## Step 2: Install Dependencies

In [None]:
# Install required packages
!pip install optuna -q
!pip install torch-geometric -q

print("✓ Dependencies installed successfully")

✓ Dependencies installed successfully


## Step 3: Import Libraries

In [None]:
import json
import os
from pathlib import Path
from typing import List, Tuple
from collections import defaultdict

import numpy as np
import pandas as pd
import optuna
from optuna.trial import Trial
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

# Check GPU availability
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

GPU Available: True
GPU Name: NVIDIA A100-SXM4-80GB
GPU Memory: 85.17 GB


## Step 4: Import Functions from gnn_train.py and train_gnn.py

In [None]:
# Import from gnn_train.py
from gnn_train import (
    GraphDataset,
    GCNModel,
    GNNTrainer,
    collate_fn,
    load_all_graphs,
    split_data,
)

from train_gnn import (
    compute_motif_metrics,
)

print("✓ Successfully imported all functions from gnn_train.py")
print("✓ Successfully imported all functions from train_gnn.py")

✓ Successfully imported all functions from gnn_train.py
✓ Successfully imported all functions from train_gnn.py


## Step 5: Define Objective Function

In [None]:
def objective(trial: Trial, train_loader: DataLoader, val_loader: DataLoader,
              test_loader: DataLoader, device: str, num_epochs: int = 50) -> float:
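    """
    Optuna objective: train a GCN with the sampled hyperparameters and return
    the best validation loss reached (the study minimizes it).

    `test_loader` is accepted but not used, so the test split plays no part
    in hyperparameter selection.
    """
    # Suggest hyperparameters
    hidden_dim = trial.suggest_int('hidden_dim', 16, 256, step=8)
    dropout = trial.suggest_float('dropout', 0.0, 0.5, step=0.05)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
    # Fixed, not tuned. mask_prob is informational only: masking is configured
    # when the datasets are built in Step 8 (GraphDataset(..., mask_prob=0.3)).
    mask_prob = 0.3
    #early_stopping_patience = trial.suggest_int('early_stopping_patience', 5, 30, step=5)
    early_stopping_patience = 20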


    # Create model and trainer
    model = GCNModel(input_dim=2, hidden_dim=hidden_dim, output_dim=1, dropout=dropout)
    model = model.to(device)
    trainer = GNNTrainer(model, device=device, learning_rate=learning_rate)

    # Recreate dataloaders with new batch size
    train_loader_new = DataLoader(train_loader.dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader_new = DataLoader(val_loader.dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        train_loss = trainer.train_epoch(train_loader_new)
        val_loss = trainer.validate(val_loader_new)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                break

        trial.report(val_loss, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return best_val_loss

print("✓ Objective function defined")

✓ Objective function defined
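The same patience rule appears in the objective and again in Step 12. As a minimal sketch, it can be factored into a small helper; `EarlyStopping` below is hypothetical (it is not part of `gnn_train.py`) and only mirrors the logic above: reset the counter on a strict improvement, and stop once `patience` consecutive epochs fail to improve.

```python
from dataclasses import dataclass


@dataclass
class EarlyStopping:
    """Hypothetical helper mirroring the notebook's patience rule."""
    patience: int = 20
    best_loss: float = float("inf")
    best_epoch: int = -1
    counter: int = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss:
            # Strict improvement: remember it and reset the patience counter
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            return False
        # No improvement: count it and stop once patience is exhausted
        self.counter += 1
        return self.counter >= self.patience


# Usage with a made-up loss curve and a short patience
losses = [0.10, 0.08, 0.07, 0.075, 0.072, 0.071, 0.09]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break
print(stopper.best_epoch, stopper.best_loss, stopped_at)  # 2 0.07 5
```

The helper only records `best_epoch`; the model weights from that epoch are not kept unless you also snapshot the model's `state_dict()` whenever the loss improves.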


## Step 6: Define Visualization Functions

In [26]:
def plot_loss_distribution(trials_df: pd.DataFrame, output_dir: Path):
    """Plot distribution of validation losses."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    axes[0].hist(trials_df['value'], bins=20, alpha=0.7, edgecolor='black')
    axes[0].axvline(trials_df['value'].min(), color='red', linestyle='--', linewidth=2,
                    label=f'Best: {trials_df["value"].min():.4f}')
    axes[0].axvline(trials_df['value'].mean(), color='green', linestyle='--', linewidth=2,
                    label=f'Mean: {trials_df["value"].mean():.4f}')
    axes[0].set_xlabel('Validation Loss', fontsize=12)
    axes[0].set_ylabel('Frequency', fontsize=12)
    axes[0].set_title('Loss Distribution', fontsize=13, fontweight='bold')
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3, axis='y')

    box_data = trials_df['value'].values
    bp = axes[1].boxplot(box_data, vert=True, patch_artist=True)
    bp['boxes'][0].set_facecolor('lightblue')
    axes[1].set_ylabel('Validation Loss', fontsize=12)
    axes[1].set_title('Loss Box Plot', fontsize=13, fontweight='bold')
    axes[1].grid(True, alpha=0.3, axis='y')
    axes[1].set_xticklabels(['All Trials'])

    stats_text = f"""Min: {trials_df['value'].min():.6f}
Max: {trials_df['value'].max():.6f}
Mean: {trials_df['value'].mean():.6f}
Std: {trials_df['value'].std():.6f}"""
    axes[1].text(1.3, trials_df['value'].mean(), stats_text, fontsize=10,
                verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    output_path = output_dir / 'loss_distribution.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {output_path.name}")
    plt.close()


def plot_hyperparameter_heatmap(trials_df: pd.DataFrame, output_dir: Path):
    """Plot heatmap of hyperparameters for top trials."""
    top_n = 10
    param_cols = [col for col in trials_df.columns if col.startswith('params_')]

    if not param_cols:
        return

    param_names = [col.replace('params_', '') for col in param_cols]

    top_trials = trials_df.nsmallest(top_n, 'value')[param_cols].copy()
    top_trials.columns = param_names

    for col in top_trials.columns:
        min_val = trials_df[f'params_{col}'].min()
        max_val = trials_df[f'params_{col}'].max()
        if max_val > min_val:
            top_trials[col] = (top_trials[col] - min_val) / (max_val - min_val)

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.heatmap(
        top_trials.T,
        annot=False,
        cmap='RdYlGn',
        cbar_kws={'label': 'Normalized Value'},
        ax=ax
    )
    ax.set_title(f'Top {top_n} Trials Hyperparameters (Normalized)', fontsize=14, fontweight='bold')
    ax.set_xlabel('Trial Number (sorted best → worst)', fontsize=12)

    plt.tight_layout()
    output_path = output_dir / 'top_trials_heatmap.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"  Saved: {output_path.name}")
    plt.close()


def generate_optuna_visualizations(study, output_dir: Path):
    """Generate Optuna's built-in visualization functions."""
    try:
        # Optimization history
        fig = optuna.visualization.plot_optimization_history(study)
        fig.write_html(str(output_dir / 'optuna_optimization_history.html'))
        print(f"  Saved: optuna_optimization_history.html")
    except Exception as e:
        print(f"  Warning: Could not generate optimization history: {e}")

    try:
        # Parameter importances
        fig = optuna.visualization.plot_param_importances(study)
        fig.write_html(str(output_dir / 'optuna_param_importances.html'))
        print(f"  Saved: optuna_param_importances.html")
    except Exception as e:
        print(f"  Warning: Could not generate param importances: {e}")

    try:
        # Slice plot (parameter distributions)
        fig = optuna.visualization.plot_slice(study)
        fig.write_html(str(output_dir / 'optuna_slice_plot.html'))
        print(f"  Saved: optuna_slice_plot.html")
    except Exception as e:
        print(f"  Warning: Could not generate slice plot: {e}")

    try:
        # Parallel coordinates (parameter interactions)
        fig = optuna.visualization.plot_parallel_coordinate(study)
        fig.write_html(str(output_dir / 'optuna_parallel_coordinates.html'))
        print(f"  Saved: optuna_parallel_coordinates.html")
    except Exception as e:
        print(f"  Warning: Could not generate parallel coordinates: {e}")

    try:
        # Contour plot (parameter interactions)
        fig = optuna.visualization.plot_contour(study)
        fig.write_html(str(output_dir / 'optuna_contour_plot.html'))
        print(f"  Saved: optuna_contour_plot.html")
    except Exception as e:
        print(f"  Warning: Could not generate contour plot: {e}")


def generate_visualizations(trials_df: pd.DataFrame, study, output_dir: Path):
    """Generate all visualizations from trials."""
    print("\nGenerating custom visualizations...")
    plot_loss_distribution(trials_df, output_dir)
    plot_hyperparameter_heatmap(trials_df, output_dir)

    print("\nGenerating Optuna visualizations...")
    generate_optuna_visualizations(study, output_dir)

print("✓ Visualization functions defined")

✓ Visualization functions defined
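`plot_hyperparameter_heatmap` scales each hyperparameter of the top trials by the minimum and maximum observed across *all* trials, so a cell of 1.0 means the largest value sampled anywhere in the sweep, not just among the top 10. A stdlib-only sketch of that scaling follows; the function name is hypothetical and the sample values are illustrative. Note that `batch_size` is categorical but is still scaled linearly by this scheme.

```python
from typing import List


def minmax_against_sweep(top: List[float], sweep: List[float]) -> List[float]:
    """Scale top-trial values by the min/max of the whole sweep.

    Mirrors the heatmap code: if every sampled value is identical,
    the values are returned unchanged (no division by zero).
    """
    lo, hi = min(sweep), max(sweep)
    if hi <= lo:
        return list(top)
    return [(v - lo) / (hi - lo) for v in top]


# Illustrative hidden_dim values: every value sampled in a sweep, and those of the top trials
sweep_hidden = [16, 104, 224, 56, 48, 144, 256]
top_hidden = [144, 104, 56]
print(minmax_against_sweep(top_hidden, sweep_hidden))
```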


## Step 7: Configuration

In [None]:
CONFIG = {
    'data_dir': 'virtual_graphs/data',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_trials': 50,
    'num_epochs': 50,
    'output_dir': 'outputs/hyperparameter_sweep_gcn',
    'seed': 42
}

print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

Configuration:
  data_dir: virtual_graphs/data
  device: cuda
  num_trials: 50
  num_epochs: 50
  output_dir: outputs2/hyperparameter_sweep
  seed: 42


## Step 8: Load Data

In [None]:
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])
Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)

print("Loading data...")
graph_paths = load_all_graphs(CONFIG['data_dir'], single_motif_only=True)
print(f"Loaded {len(graph_paths)} graphs")

train_paths, val_paths, test_paths = split_data(graph_paths, train_ratio=0.8, val_ratio=0.1, stratify_by_motif=True)
print(f"Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")

train_dataset = GraphDataset(train_paths, mask_prob=0.3, seed=CONFIG['seed'])
val_dataset = GraphDataset(val_paths, mask_prob=0.3, seed=CONFIG['seed'])
test_dataset = GraphDataset(test_paths, mask_prob=0.3, seed=CONFIG['seed'])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

print("✓ Data loading complete")

Loading data...
Loaded 4000 graphs
Train: 3200, Val: 400, Test: 400
✓ Data loading complete
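`split_data(..., stratify_by_motif=True)` comes from `gnn_train.py`, and its implementation is not shown here. One common way to stratify, sketched below under the assumption that each graph carries exactly one motif label, is to shuffle and split each motif's list separately, then concatenate. Assuming, for illustration, 1,000 graphs for each of the four motifs and 0.8/0.1/0.1 ratios, this reproduces the 3200/400/400 split printed above. It is not necessarily how `split_data` does it, and `stratified_split` and the ids are hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def stratified_split(
    labels: Dict[str, str],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[str], List[str], List[str]]:
    """Split item ids so that each motif keeps roughly the same proportions.

    `labels` maps an item id (for example a graph path) to its motif label.
    """
    by_motif: Dict[str, List[str]] = defaultdict(list)
    for item, motif in labels.items():
        by_motif[motif].append(item)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for motif in sorted(by_motif):  # fixed order for reproducibility
        items = sorted(by_motif[motif])
        rng.shuffle(items)
        n_train = int(len(items) * train_ratio)
        n_val = int(len(items) * val_ratio)
        train += items[:n_train]
        val += items[n_train:n_train + n_val]
        test += items[n_train + n_val:]  # remainder goes to test
    return train, val, test


motifs = ["cascade", "feedback_loop", "feedforward_loop", "single_input_module"]
labels = {f"{m}/graph_{i:04d}": m for m in motifs for i in range(1000)}
train, val, test = stratified_split(labels)
print(len(train), len(val), len(test))  # 3200 400 400
```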


## Step 9: Run Hyperparameter Sweep

**⏱️ This may take 3-5 hours for 50 trials**

In [None]:
print("=" * 80)
print("GCN HYPERPARAMETER SWEEP WITH OPTUNA")
print("=" * 80)
print(f"Running {CONFIG['num_trials']} trials with up to {CONFIG['num_epochs']} epochs each\n")

sampler = TPESampler(seed=CONFIG['seed'])
pruner = MedianPruner(n_startup_trials=10, n_warmup_steps=0)

study = optuna.create_study(
    direction='minimize',
    sampler=sampler,
    pruner=pruner,
    study_name='gcn_optimization'
)

study.optimize(
    lambda trial: objective(trial, train_loader, val_loader, test_loader, CONFIG['device'], CONFIG['num_epochs']),
    n_trials=CONFIG['num_trials'],
    show_progress_bar=True
)

[I 2025-11-22 22:14:45,174] A new study created in memory with name: gcn_optimization


GCN HYPERPARAMETER SWEEP WITH OPTUNA
Running 50 trials with up to 50 epochs each




[I 2025-11-22 22:52:12,047] Trial 0 finished with value: 0.003108785841614008 and parameters: {'hidden_dim': 104, 'dropout': 0.5, 'learning_rate': 0.008471801418819975, 'batch_size': 16}. Best is trial 0 with value: 0.003108785841614008.
[I 2025-11-22 23:01:22,717] Trial 1 finished with value: 0.0031664290357954227 and parameters: {'hidden_dim': 224, 'dropout': 0.30000000000000004, 'learning_rate': 0.006796578090758156, 'batch_size': 32}. Best is trial 0 with value: 0.003108785841614008.
[I 2025-11-22 23:10:38,070] Trial 2 finished with value: 0.09771811962127686 and parameters: {'hidden_dim': 56, 'dropout': 0.1, 'learning_rate': 0.00016480446427978953, 'batch_size': 128}. Best is trial 0 with value: 0.003108785841614008.
[I 2025-11-22 23:20:10,269] Trial 3 finished with value: 0.003149425252698935 and parameters: {'hidden_dim': 48, 'dropout': 0.15000000000000002, 'learning_rate': 0.0002920433847181409, 'batch_size': 32}. Best is trial 0 with value: 0.003108785841614008.
[I 2025-11-22 
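With `MedianPruner(n_startup_trials=10, n_warmup_steps=0)`, pruning stays disabled until 10 trials have completed. After that, each time the objective calls `trial.report(val_loss, epoch)`, the trial is stopped if its best validation loss so far is worse than the median of the losses that completed trials reported at the same epoch. Because `n_warmup_steps=0`, the check applies from the very first epoch, which can prune slow starters aggressively. The function below is a simplified, stdlib-only sketch of that rule (it ignores `interval_steps` and other Optuna details), not Optuna's own code; `should_prune_median` is a hypothetical name.

```python
import statistics
from typing import Dict, List


def should_prune_median(
    trial_history: Dict[int, float],
    completed_histories: List[Dict[int, float]],
    step: int,
    n_startup_trials: int = 10,
    n_warmup_steps: int = 0,
) -> bool:
    """Simplified median-stopping rule for a minimization objective.

    trial_history: epoch -> val loss reported so far by the running trial
    (must contain an entry at or before `step`).
    completed_histories: the same mapping for each completed trial.
    """
    if len(completed_histories) < n_startup_trials or step < n_warmup_steps:
        return False
    # Losses that completed trials reported at this same epoch
    peers = [h[step] for h in completed_histories if step in h]
    if not peers:
        return False
    best_so_far = min(v for s, v in trial_history.items() if s <= step)
    return best_so_far > statistics.median(peers)


# Ten completed trials whose epoch-0 losses are 0.01, 0.02, ..., 0.10
completed = [{0: 0.01 * (i + 1)} for i in range(10)]
print(should_prune_median({0: 0.09}, completed, step=0))      # True
print(should_prune_median({0: 0.03}, completed, step=0))      # False
print(should_prune_median({0: 0.09}, completed[:5], step=0))  # False (startup phase)
```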

## Step 10: Display and Save Results

In [27]:
print("\n" + "=" * 80)
print("OPTIMIZATION COMPLETE")
print("=" * 80)

best_trial = study.best_trial
print(f"\nBest Trial: {best_trial.number}")
print(f"Best Validation Loss: {best_trial.value:.6f}")
print("\nBest Hyperparameters:")
for key, value in best_trial.params.items():
    print(f"  {key}: {value}")

# Save results
output_path = Path(CONFIG['output_dir'])
print(f"\nSaving results to {CONFIG['output_dir']}...")

with open(output_path / "best_params.json", 'w') as f:
    json.dump(best_trial.params, f, indent=2)

with open(output_path / "study_info.json", 'w') as f:
    json.dump({
        'best_value': best_trial.value,
        'best_trial': best_trial.number,
        'n_trials': len(study.trials),
        'n_complete_trials': len([t for t in study.trials if t.state.name == 'COMPLETE'])
    }, f, indent=2)

trials_df = study.trials_dataframe()
trials_df.to_csv(output_path / "trials.csv", index=False)
print(f"✓ Saved results to {output_path}")


OPTIMIZATION COMPLETE

Best Trial: 17
Best Validation Loss: 0.002794

Best Hyperparameters:
  hidden_dim: 144
  dropout: 0.25
  learning_rate: 0.028505322089850224
  batch_size: 64

Saving results to outputs2/hyperparameter_sweep...
✓ Saved results to outputs2/hyperparameter_sweep


## Step 11: Generate Sweep Visualizations

In [28]:
generate_visualizations(trials_df, study, output_path)


Generating custom visualizations...
  Saved: loss_distribution.png
  Saved: top_trials_heatmap.png

Generating Optuna visualizations...
  Saved: optuna_optimization_history.html
  Saved: optuna_param_importances.html
  Saved: optuna_slice_plot.html
  Saved: optuna_parallel_coordinates.html
  Saved: optuna_contour_plot.html


## Step 12: Train Top 5 Models and Compute Motif Metrics

**⏱️ This takes ~2-3 hours for 5 models**

In [32]:
print("\n" + "=" * 80)
print("TRAINING TOP 5 TRIALS AND EXTRACTING BEST MODEL ARTIFACTS")
print("=" * 80)

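# Rank only COMPLETE trials (the same pool study.best_trial uses). Pruned trials
# carry their last reported loss as their value, and failed trials have
# value None, which would make sorted() raise a TypeError.
completed_trials = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
sorted_trials = sorted(completed_trials, key=lambda t: t.value)[:5]
best_trial = sorted_trials[0]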

print(f"\nBest Trial: Trial {best_trial.number} with Val Loss = {best_trial.value:.6f}")
print(f"\nTop 5 Trial Rankings:")
for idx, trial in enumerate(sorted_trials, 1):
    print(f"  {idx}. Trial {trial.number}: Val Loss = {trial.value:.6f}")

# Dictionary to store metrics for top 5 trials
top5_motif_metrics = {}

for trial_idx, trial in enumerate(sorted_trials):
    trial_name = f"Best_Trial_{best_trial.number}" if trial == best_trial else f"Trial_{trial.number}"
    print(f"\n--- Processing {trial_name} ---")

    # Extract hyperparameters for this trial
    trial_hidden_dim = int(trial.params['hidden_dim'])
    trial_dropout = float(trial.params['dropout'])
    trial_learning_rate = float(trial.params['learning_rate'])
    trial_batch_size = int(trial.params['batch_size'])
    trial_patience = 20

    # Create dataloaders with trial's batch size
    trial_train_loader = DataLoader(train_dataset, batch_size=trial_batch_size, shuffle=True, collate_fn=collate_fn)
    trial_val_loader = DataLoader(val_dataset, batch_size=trial_batch_size, shuffle=False, collate_fn=collate_fn)
    trial_test_loader = DataLoader(test_dataset, batch_size=trial_batch_size, shuffle=False, collate_fn=collate_fn)

    # Create and train model with this trial's hyperparameters
    trial_model = GCNModel(input_dim=2, hidden_dim=trial_hidden_dim, output_dim=1, dropout=trial_dropout)
    trial_model = trial_model.to(CONFIG['device'])
    trial_trainer = GNNTrainer(trial_model, device=CONFIG['device'], learning_rate=trial_learning_rate)

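    # Training loop with early stopping; snapshot the weights at the best validation epoch
    trial_val_loss = float('inf')
    best_state = {k: v.detach().clone() for k, v in trial_trainer.model.state_dict().items()}
    patience_counter = 0

    for epoch in range(CONFIG['num_epochs']):
        train_loss = trial_trainer.train_epoch(trial_train_loader)
        val_loss = trial_trainer.validate(trial_val_loader)

        if val_loss < trial_val_loss:
            trial_val_loss = val_loss
            best_state = {k: v.detach().clone() for k, v in trial_trainer.model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= trial_patience:
                break

    # Restore the best-validation weights; otherwise the saved model, test loss and
    # motif metrics would come from the last epoch, up to `trial_patience` epochs later
    trial_trainer.model.load_state_dict(best_state)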

    # Evaluate on test set
    test_loss = trial_trainer.validate(trial_test_loader)
    print(f"  Test Loss: {test_loss:.6f}")

    # Save model and activations only for the best (top-ranked) trial
    if trial == best_trial:
        print(f"  Saving best model and activations...")
        model_save_path = output_path / "best_model.pt"
        trial_trainer.save_model(str(model_save_path))
        print(f"  ✓ Saved best model to {model_save_path}")

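        # Extract and save activations for all splits. The training loader shuffles,
        # so build an unshuffled one to keep the saved train activations in dataset
        # order, as the val/test loaders already do.
        activations_output_dir = output_path
        ordered_train_loader = DataLoader(train_dataset, batch_size=trial_batch_size, shuffle=False, collate_fn=collate_fn)
        trial_trainer.extract_and_save_activations(ordered_train_loader, activations_output_dir, "train")
        trial_trainer.extract_and_save_activations(trial_val_loader, activations_output_dir, "val")
        trial_trainer.extract_and_save_activations(trial_test_loader, activations_output_dir, "test")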
        print(f"  ✓ Saved activations to {activations_output_dir}/activations/")

    # Compute motif metrics for this trial using the imported function
    print(f"  Computing motif metrics...")
    train_motif_metrics = compute_motif_metrics(trial_trainer.model, CONFIG['device'], train_paths, CONFIG['seed'])
    val_motif_metrics = compute_motif_metrics(trial_trainer.model, CONFIG['device'], val_paths, CONFIG['seed'])
    test_motif_metrics = compute_motif_metrics(trial_trainer.model, CONFIG['device'], test_paths, CONFIG['seed'])

    top5_motif_metrics[f"Trial_{trial.number}"] = {
        'train': train_motif_metrics,
        'val': val_motif_metrics,
        'test': test_motif_metrics
    }

    print(f"  ✓ {trial_name} complete")

print("\n✓ All top 5 trials processed")


TRAINING TOP 5 TRIALS AND EXTRACTING BEST MODEL ARTIFACTS

Best Trial: Trial 17 with Val Loss = 0.002794

Top 5 Trial Rankings:
  1. Trial 17: Val Loss = 0.002794
  2. Trial 29: Val Loss = 0.002807
  3. Trial 14: Val Loss = 0.002850
  4. Trial 4: Val Loss = 0.002910
  5. Trial 40: Val Loss = 0.002941

--- Processing Best_Trial_17 ---
  Test Loss: 0.004829
  Saving best model and activations...
Model saved to outputs2/hyperparameter_sweep/best_model.pt
  ✓ Saved best model to outputs2/hyperparameter_sweep/best_model.pt


Extracting train activations: 100%|██████████| 50/50 [00:45<00:00,  1.10it/s]


Saved activations for 3200 graphs to outputs2/hyperparameter_sweep/activations/


Extracting val activations: 100%|██████████| 7/7 [00:05<00:00,  1.27it/s]


Saved activations for 400 graphs to outputs2/hyperparameter_sweep/activations/


Extracting test activations: 100%|██████████| 7/7 [00:05<00:00,  1.21it/s]


Saved activations for 400 graphs to outputs2/hyperparameter_sweep/activations/
  ✓ Saved activations to outputs2/hyperparameter_sweep/activations/
  Computing motif metrics...
  ✓ Best_Trial_17 complete

--- Processing Trial_29 ---
  Test Loss: 0.003626
  Computing motif metrics...
  ✓ Trial_29 complete

--- Processing Trial_14 ---
  Test Loss: 0.003567
  Computing motif metrics...
  ✓ Trial_14 complete

--- Processing Trial_4 ---
  Test Loss: 0.003019
  Computing motif metrics...
  ✓ Trial_4 complete

--- Processing Trial_40 ---
  Test Loss: 0.003409
  Computing motif metrics...
  ✓ Trial_40 complete

✓ All top 5 trials processed
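Sorting `study.trials` directly mixes trial states. In recent Optuna versions a pruned trial keeps its last reported loss as its value, and a failed trial has `value=None`, which makes a plain `sorted` raise `TypeError`. The sketch below ranks completed trials only. `State` and `FakeTrial` are hypothetical stand-ins for `optuna.trial.TrialState` and `FrozenTrial` so that it runs without a study; with a real study, `study.get_trials(states=(TrialState.COMPLETE,))` does the filtering step. The trial values are made up.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class State(Enum):  # hypothetical stand-in for optuna.trial.TrialState
    COMPLETE = "COMPLETE"
    PRUNED = "PRUNED"
    FAIL = "FAIL"


@dataclass
class FakeTrial:  # hypothetical stand-in for optuna.trial.FrozenTrial
    number: int
    value: Optional[float]
    state: State


def top_k_completed(trials: List[FakeTrial], k: int = 5) -> List[FakeTrial]:
    """Return the k completed trials with the lowest value (minimization)."""
    completed = [t for t in trials if t.state is State.COMPLETE and t.value is not None]
    return sorted(completed, key=lambda t: t.value)[:k]


trials = [
    FakeTrial(0, 0.0031, State.COMPLETE),
    FakeTrial(1, 0.0029, State.PRUNED),  # value is its last reported loss
    FakeTrial(2, None, State.FAIL),      # would break a plain sorted()
    FakeTrial(3, 0.0028, State.COMPLETE),
]
print([t.number for t in top_k_completed(trials, k=2)])  # [3, 0]
```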


## Step 13: Display Motif Metrics for Top 5 Trials

In [33]:
print("\n" + "=" * 80)
print("MOTIF-SPECIFIC METRICS (Top 5 Trials)")
print("=" * 80)

for trial_name, trial_data in top5_motif_metrics.items():
    print(f"\n\n{'=' * 80}")
    print(f"{trial_name.upper()}")
    print('=' * 80)

    for split_name, metrics in trial_data.items():
        print(f"\n{split_name.upper()} SET:")
        print("-" * 80)

        if not metrics:
            print("  No metrics computed")
            continue

        for motif_label in sorted(metrics.keys()):
            motif_data = metrics[motif_label]
            print(f"\n  {motif_label.upper()}:")
            print(f"    Mean MSE:      {motif_data['mean_mse']:.6f}")
            print(f"    Std MSE:       {motif_data['std_mse']:.6f}")


MOTIF-SPECIFIC METRICS (Top 5 Trials)


TRIAL_17

TRAIN SET:
--------------------------------------------------------------------------------

  CASCADE:
    Mean MSE:      0.004206
    Std MSE:       0.000689

  FEEDBACK_LOOP:
    Mean MSE:      0.003406
    Std MSE:       0.001056

  FEEDFORWARD_LOOP:
    Mean MSE:      0.004471
    Std MSE:       0.001333

  SINGLE_INPUT_MODULE:
    Mean MSE:      0.003493
    Std MSE:       0.001730

VAL SET:
--------------------------------------------------------------------------------

  CASCADE:
    Mean MSE:      0.003249
    Std MSE:       0.001119

  FEEDBACK_LOOP:
    Mean MSE:      0.004198
    Std MSE:       0.001324

  FEEDFORWARD_LOOP:
    Mean MSE:      0.004681
    Std MSE:       0.001986

  SINGLE_INPUT_MODULE:
    Mean MSE:      0.002414
    Std MSE:       0.001082

TEST SET:
--------------------------------------------------------------------------------

  CASCADE:
    Mean MSE:      0.004470
    Std MSE:       0.001283

  FEEDB
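`top5_motif_metrics` is nested as trial → split → motif → metrics. For side-by-side comparison it can help to flatten it into one row per (trial, split, motif) and write a CSV. A stdlib sketch, assuming only the `mean_mse` and `std_mse` keys printed above; the function name and the CSV filename in the comment are illustrative.

```python
import csv
import io
from typing import Dict, List

Metrics = Dict[str, Dict[str, Dict[str, Dict[str, float]]]]


def flatten_motif_metrics(nested: Metrics) -> List[dict]:
    """Turn trial -> split -> motif -> {metric: value} into flat rows."""
    rows = []
    for trial_name, splits in nested.items():
        for split_name, motifs in splits.items():
            for motif, m in sorted(motifs.items()):
                rows.append({
                    "trial": trial_name,
                    "split": split_name,
                    "motif": motif,
                    "mean_mse": m["mean_mse"],
                    "std_mse": m["std_mse"],
                })
    return rows


# Two entries copied from the Trial 17 TRAIN output above
example = {"Trial_17": {"train": {
    "cascade": {"mean_mse": 0.004206, "std_mse": 0.000689},
    "feedback_loop": {"mean_mse": 0.003406, "std_mse": 0.001056},
}}}
rows = flatten_motif_metrics(example)

# In the notebook, replace the buffer with
# open(output_path / "motif_metrics_top5.csv", "w", newline="")
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=list(rows[0]))
writer.writeheader()
writer.writerows(rows)
print(buf.getvalue())
```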

## Step 14: Save and Visualize Motif Metrics for Top 5 Trials

In [34]:
# Save motif metrics for top 5 trials
motif_metrics_path = output_path / "motif_metrics_top5_trials.json"
with open(motif_metrics_path, 'w') as f:
    json.dump(top5_motif_metrics, f, indent=2)
print(f"✓ Saved top 5 trials motif metrics to {motif_metrics_path}")

# Create visualizations by motif type (averaged over top 5 trials)
splits = ['train', 'val', 'test']

for split_name in splits:
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Get motif types from first trial
    first_trial_key = f"Trial_{sorted_trials[0].number}"
    motif_types = sorted(top5_motif_metrics[first_trial_key][split_name].keys())

    # MSE by Motif Type
    ax = axes[0]
    x_pos = np.arange(len(motif_types))
    mse_means = []
    mse_stds = []

    for motif in motif_types:
        all_mses = []
        for trial in sorted_trials:
            trial_key = f"Trial_{trial.number}"
            if motif in top5_motif_metrics[trial_key][split_name]:
                all_mses.append(top5_motif_metrics[trial_key][split_name][motif]['mean_mse'])
        if all_mses:
            mse_means.append(np.mean(all_mses))
            mse_stds.append(np.std(all_mses))
        else:
            mse_means.append(0)
            mse_stds.append(0)

    ax.bar(x_pos, mse_means, yerr=mse_stds, capsize=5, color='steelblue', alpha=0.8, edgecolor='black')
    ax.set_xlabel('Motif Type', fontsize=12)
    ax.set_ylabel('MSE', fontsize=12)
    ax.set_title(f'MSE by Motif Type ({split_name.upper()}, Averaged over Top 5 Trials)', fontsize=13, fontweight='bold')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(motif_types, rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    # RMSE by Motif Type: sqrt(mean MSE) is a root-mean-square error, not an MAE
    # (a mean absolute error cannot be derived from the MSE)
    ax = axes[1]
    rmse_means = []
    rmse_stds = []

    for motif in motif_types:
        all_rmses = []
        for trial in sorted_trials:
            trial_key = f"Trial_{trial.number}"
            if motif in top5_motif_metrics[trial_key][split_name]:
                motif_data = top5_motif_metrics[trial_key][split_name][motif]
                if 'mean_mse' in motif_data:
                    all_rmses.append(np.sqrt(motif_data['mean_mse']))
        if all_rmses:
            rmse_means.append(np.mean(all_rmses))
            rmse_stds.append(np.std(all_rmses))  # spread across the top 5 trials
        else:
            rmse_means.append(0)
            rmse_stds.append(0)

    ax.bar(x_pos, rmse_means, yerr=rmse_stds, capsize=5, color='steelblue', alpha=0.8, edgecolor='black')
    ax.set_xlabel('Motif Type', fontsize=12)
    ax.set_ylabel('RMSE', fontsize=12)
    ax.set_title(f'RMSE by Motif Type ({split_name.upper()}, Averaged over Top 5 Trials)', fontsize=13, fontweight='bold')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(motif_types, rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    motif_viz_path = output_path / f'motif_comparison_{split_name}.png'
    plt.savefig(motif_viz_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved {split_name} motif comparison to {motif_viz_path}")
    plt.close()

print("\n" + "=" * 80)
print("BEST MODEL MOTIF-SPECIFIC METRICS VISUALIZATIONS")
print("=" * 80)

# Create detailed visualizations for the best model only
best_trial_key = f"Trial_{best_trial.number}"
best_metrics = top5_motif_metrics[best_trial_key]

for split_name in splits:
    # Get motif types and their metrics for best model
    split_metrics = best_metrics[split_name]
    motif_types_best = sorted(split_metrics.keys())

    # Extract MSE and RMSE values for best model.
    # sqrt(mean MSE) is a root-mean-square error (RMSE), not a mean absolute error.
    # Its spread is approximated to first order (delta method) as
    # std_mse / (2 * sqrt(mean_mse)); sqrt(std_mse) is not a standard deviation of the RMSE.
    mse_values = [split_metrics[motif]['mean_mse'] for motif in motif_types_best]
    mse_stds = [split_metrics[motif]['std_mse'] for motif in motif_types_best]
    rmse_values = [np.sqrt(m) for m in mse_values]
    rmse_stds = [s / (2 * r) if r > 0 else 0.0 for s, r in zip(mse_stds, rmse_values)]

    # Create comprehensive figure for best model
    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

    # 1. MSE comparison bar chart
    ax1 = fig.add_subplot(gs[0, 0])
    x_pos = np.arange(len(motif_types_best))
    colors = plt.cm.Set3(np.linspace(0, 1, len(motif_types_best)))
    bars1 = ax1.bar(x_pos, mse_values, yerr=mse_stds, capsize=8, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
    ax1.set_xlabel('Motif Type', fontsize=12, fontweight='bold')
    ax1.set_ylabel('MSE', fontsize=12, fontweight='bold')
    ax1.set_title(f'MSE by Motif Type - Best Model ({split_name.upper()})', fontsize=13, fontweight='bold')
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(motif_types_best, rotation=45, ha='right', fontsize=11)
    ax1.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars
    for i, (bar, val) in enumerate(zip(bars1, mse_values)):
        ax1.text(bar.get_x() + bar.get_width()/2, val + mse_stds[i], f'{val:.4f}',
                ha='center', va='bottom', fontsize=9)

    # 2. RMSE comparison bar chart
    ax2 = fig.add_subplot(gs[0, 1])
    bars2 = ax2.bar(x_pos, rmse_values, yerr=rmse_stds, capsize=8, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
    ax2.set_xlabel('Motif Type', fontsize=12, fontweight='bold')
    ax2.set_ylabel('RMSE', fontsize=12, fontweight='bold')
    ax2.set_title(f'RMSE by Motif Type - Best Model ({split_name.upper()})', fontsize=13, fontweight='bold')
    ax2.set_xticks(x_pos)
    ax2.set_xticklabels(motif_types_best, rotation=45, ha='right', fontsize=11)
    ax2.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars
    for i, (bar, val) in enumerate(zip(bars2, rmse_values)):
        ax2.text(bar.get_x() + bar.get_width()/2, val + rmse_stds[i], f'{val:.4f}',
                ha='center', va='bottom', fontsize=9)

    # 3. MSE and RMSE side-by-side comparison
    ax3 = fig.add_subplot(gs[1, 0])
    width = 0.35
    x_pos_grouped = np.arange(len(motif_types_best))
    bars3a = ax3.bar(x_pos_grouped - width/2, mse_values, width, label='MSE', color='steelblue', alpha=0.8, edgecolor='black')
    bars3b = ax3.bar(x_pos_grouped + width/2, rmse_values, width, label='RMSE', color='coral', alpha=0.8, edgecolor='black')

    ax3.set_xlabel('Motif Type', fontsize=12, fontweight='bold')
    ax3.set_ylabel('Error Value', fontsize=12, fontweight='bold')
    ax3.set_title(f'MSE vs RMSE Comparison - Best Model ({split_name.upper()})', fontsize=13, fontweight='bold')
    ax3.set_xticks(x_pos_grouped)
    ax3.set_xticklabels(motif_types_best, rotation=45, ha='right', fontsize=11)
    ax3.legend(fontsize=11, loc='upper left')
    ax3.grid(True, alpha=0.3, axis='y')

    # 4. Error statistics table
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.axis('tight')
    ax4.axis('off')

    table_data = []
    table_data.append(['Motif Type', 'MSE', 'Std MSE', 'RMSE', 'Std RMSE (approx.)'])
    for i, motif in enumerate(motif_types_best):
        table_data.append([
            motif,
            f'{mse_values[i]:.6f}',
            f'{mse_stds[i]:.6f}',
            f'{rmse_values[i]:.6f}',
            f'{rmse_stds[i]:.6f}'
        ])

    table = ax4.table(cellText=table_data, cellLoc='center', loc='center',
                     colWidths=[0.25, 0.18, 0.18, 0.18, 0.18])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Style header row
    for i in range(len(table_data[0])):
        table[(0, i)].set_facecolor('#40466e')
        table[(0, i)].set_text_props(weight='bold', color='white')

    # Alternate row colors
    for i in range(1, len(table_data)):
        for j in range(len(table_data[0])):
            if i % 2 == 0:
                table[(i, j)].set_facecolor('#f0f0f0')
            else:
                table[(i, j)].set_facecolor('#ffffff')

    fig.suptitle(f'Best Model (Trial {best_trial.number}) - Motif-Specific Metrics\n{split_name.upper()} Split',
                fontsize=16, fontweight='bold', y=0.98)

    plt.savefig(output_path / f'best_model_motif_metrics_{split_name}.png', dpi=300, bbox_inches='tight')
    print(f"✓ Saved best model motif metrics visualization for {split_name}")
    plt.close()

# Save summary of best model (which is the top-ranked trial)
best_summary_path = output_path / "best_model_summary.json"
with open(best_summary_path, 'w') as f:
    json.dump({
        'best_trial_number': best_trial.number,
        'best_validation_loss': float(best_trial.value),
        'hyperparameters': best_trial.params,
        'motif_metrics': top5_motif_metrics[best_trial_key]
    }, f, indent=2)
print(f"✓ Saved best model summary to {best_summary_path}")

print("\n✓ All motif-specific metrics visualizations complete")

✓ Saved top 5 trials motif metrics to outputs2/hyperparameter_sweep/motif_metrics_top5_trials.json
✓ Saved train motif comparison to outputs2/hyperparameter_sweep/motif_comparison_train.png
✓ Saved val motif comparison to outputs2/hyperparameter_sweep/motif_comparison_val.png
✓ Saved test motif comparison to outputs2/hyperparameter_sweep/motif_comparison_test.png

BEST MODEL MOTIF-SPECIFIC METRICS VISUALIZATIONS
✓ Saved best model motif metrics visualization for train
✓ Saved best model motif metrics visualization for val
✓ Saved best model motif metrics visualization for test
✓ Saved best model summary to outputs2/hyperparameter_sweep/best_model_summary.json

✓ All motif-specific metrics visualizations complete
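Step 14 derives its second error measure as `np.sqrt(mean_mse)`. That quantity is a root-mean-square error (RMSE), not a mean absolute error (MAE): MAE averages |error| and cannot be recovered from the MSE alone. Likewise, taking `np.sqrt(std_mse)` does not give the spread of the RMSE. The check below uses made-up residuals, then applies a first-order (delta-method) approximation, std(√X) ≈ std(X) / (2·√mean(X)), to the CASCADE/TRAIN numbers printed for Trial 17 in Step 13. The approximation assumes the per-graph MSE varies little relative to its mean; `rmse_std_delta` is a hypothetical helper name.

```python
import math

# Made-up residuals for one graph
errors = [0.1, -0.3]

mae = sum(abs(e) for e in errors) / len(errors)
mse = sum(e * e for e in errors) / len(errors)
rmse = math.sqrt(mse)
print(f"MAE={mae:.4f}  MSE={mse:.4f}  sqrt(MSE)={rmse:.4f}")
# MAE=0.2000  MSE=0.0500  sqrt(MSE)=0.2236


def rmse_std_delta(mean_mse: float, std_mse: float) -> float:
    """First-order (delta-method) approximation of the std of sqrt(MSE) across graphs."""
    return std_mse / (2.0 * math.sqrt(mean_mse))


# CASCADE, TRAIN split, Trial 17 (values from Step 13)
print(f"delta-method std of RMSE: {rmse_std_delta(0.004206, 0.000689):.6f}")
print(f"sqrt(std_mse):            {math.sqrt(0.000689):.6f}")
```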


## Step 15: Download Results

In [35]:
print("\n" + "=" * 80)
print("RESULTS SUMMARY")
print("=" * 80)
print(f"\nResults location: {output_path}")
print(f"\nGenerated files:")
for file in sorted(output_path.iterdir()):
    print(f"  - {file.name}")

print("\n" + "=" * 80)
print("ALL TASKS COMPLETE!")
print("=" * 80)


RESULTS SUMMARY

Results location: outputs2/hyperparameter_sweep

Generated files:
  - activations
  - best_model.pt
  - best_model_motif_metrics_test.png
  - best_model_motif_metrics_train.png
  - best_model_motif_metrics_val.png
  - best_model_summary.json
  - best_params.json
  - loss_distribution.png
  - motif_comparison_test.png
  - motif_comparison_train.png
  - motif_comparison_val.png
  - motif_metrics_top5_trials.json
  - optuna_contour_plot.html
  - optuna_optimization_history.html
  - optuna_parallel_coordinates.html
  - optuna_param_importances.html
  - optuna_slice_plot.html
  - study_info.json
  - top_trials_heatmap.png
  - trials.csv

ALL TASKS COMPLETE!
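The cell above lists the generated files, which already live in Google Drive under `output_dir`. To download them to a local machine in one step, one option is to zip the folder with `shutil.make_archive` and pass the archive to `google.colab.files.download`, which only works inside Colab. The sketch writes the archive outside the results folder (to `/content` by default, an assumption about the Colab filesystem) so the zip does not end up inside itself; the `activations/` folder may dominate its size. Both function names are hypothetical.

```python
import shutil
from pathlib import Path


def zip_results(results_dir: Path, archive_base: Path) -> Path:
    """Zip the contents of results_dir into <archive_base>.zip and return its path."""
    archive = shutil.make_archive(str(archive_base), "zip", root_dir=str(results_dir))
    return Path(archive)


def download_results(results_dir: Path, archive_dir: Path = Path("/content")) -> Path:
    """Zip the results and trigger a browser download when running in Colab."""
    archive = zip_results(results_dir, archive_dir / results_dir.name)
    try:
        from google.colab import files  # available only inside Colab
        files.download(str(archive))
    except ImportError:
        print(f"Not running in Colab; archive written to {archive}")
    return archive


# In the notebook: download_results(output_path)
```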
