In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import optuna
import os
import json
import copy
import joblib
import pandas as pd
import numpy as np
import time
from typing import Dict, Any, Optional, Union, List, Tuple, Type

from constants import DEVICE
from data_handling import DataHandler
from metrics import weighted_cross_entropy_loss

Using MPS device (Apple Silicon GPU)


In [2]:
def suggest_mlp_params(trial: optuna.trial.Trial, depth: int):
    """Suggests learning rate, weight decay, hidden layer sizes, and dropout rates for an Optuna trial for an MLP."""
    params = {}
    params['learning_rate'] = trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True)
    params['weight_decay'] = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    for i in range(1, depth + 1):
        params[f'n_hidden_{i}'] = trial.suggest_int(f'n_hidden_{i}', 8, 128, step=8)
        params[f'dropout_rate_{i}'] = trial.suggest_float(f'dropout_rate_{i}', 0.0, 0.5, log=False)
    return params

def build_network(input_dim: int, depth: int, hparams: Dict[str, Any]):
    """Returns an MLP (nn.Sequential) with the specified input dimension, depth, and hyperparameters."""
    layers = []
    current_dim = input_dim

    for i in range(1, depth + 1):
        n_hidden = hparams[f"n_hidden_{i}"]
        layers.append(nn.Linear(current_dim, n_hidden))
        layers.append(nn.ReLU()) 

        dropout_rate = hparams.get(f"dropout_rate_{i}", 0.0)
        if dropout_rate > 0.0:
            layers.append(nn.Dropout(dropout_rate))
        current_dim = n_hidden

    layers.append(nn.Linear(current_dim, 4))
    layers.append(nn.Softmax(dim=1))
    return nn.Sequential(*layers)

In [3]:
def train_and_validate_one_fold(
        model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        optimizer: optim.Optimizer,
        num_epochs: int
        ):
    """Trains and validates the model on one fold of the data."""
    # Training phase
    for _ in range(num_epochs):
        model.train()
        for features, targets, weights in train_loader:
            features, targets, weights = features.to(DEVICE), targets.to(DEVICE), weights.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(features)
            loss = weighted_cross_entropy_loss(outputs, targets, weights)
            loss.backward()
            optimizer.step()

    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for features, targets, weights in val_loader:
            features, targets, weights = features.to(DEVICE), targets.to(DEVICE), weights.to(DEVICE)
            outputs = model(features)
            loss = weighted_cross_entropy_loss(outputs, targets, weights)
            val_loss += loss.item()

    return val_loss/ len(val_loader)

In [4]:
def objective(trial: optuna.trial.Trial, 
              input_dim: int,  
              depth: int, 
              dataloaders: List[Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]]):
    # Get hyperparams and build models as before
    hparams = suggest_mlp_params(trial, depth)
    max_epochs = 100  # Use a larger number, ASHA will prune inefficient ones
    
    # Create models for each fold
    models = {}
    for i in range(3):
        model_i = copy.deepcopy(build_network(input_dim, depth, hparams))
        optimizer_i = optim.AdamW(model_i.parameters(),
                                 lr=hparams['learning_rate'],
                                 weight_decay=hparams['weight_decay'])
        models[i] = (model_i, optimizer_i)
    
    # Train for multiple epochs
    for epoch in range(max_epochs):
        epoch_fold_losses = []
        
        # Train one epoch on each fold
        for fold_idx, (train_loader, val_loader) in enumerate(dataloaders):
            model, optimizer = models[fold_idx]
            
            # Train for ONE epoch
            model.to(DEVICE)
            model.train()
            for features, targets, weights in train_loader:
                features = features.to(DEVICE)
                targets = targets.to(DEVICE)
                weights = weights.to(DEVICE)
                optimizer.zero_grad()
                outputs = model(features)
                loss = weighted_cross_entropy_loss(outputs, targets, weights)
                loss.backward()
                optimizer.step()
            
            # Validate
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for features, targets, weights in val_loader:
                    features = features.to(DEVICE)
                    targets = targets.to(DEVICE)
                    weights = weights.to(DEVICE)
                    outputs = model(features)
                    loss = weighted_cross_entropy_loss(outputs, targets, weights)
                    val_loss += loss.item()
                    
            epoch_fold_losses.append(val_loss/len(val_loader))
        
        # Update current mean loss
        epoch_loss = np.mean(epoch_fold_losses)
        
        # Report and prune
        trial.report(epoch_loss, step=epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
    
    # Return the last epoch loss
    return epoch_loss

In [5]:
dh = DataHandler()
cv_dataloaders = dh.get_nn_data('cv', batch_size=256)

DataHandler initialized - Using 114 features - Test year: 2020


In [6]:
import optuna
# First, set up the ASHA pruner
asha_pruner = optuna.pruners.SuccessiveHalvingPruner(
    min_resource=10,        # Minimum number of steps before pruning
    reduction_factor=3,    # Reduction factor for successive halving
    min_early_stopping_rate=0
)

study = optuna.create_study(direction="minimize", pruner=asha_pruner)

[I 2025-05-10 13:41:24,100] A new study created in memory with name: no-name-5081d646-8924-41df-98bf-3c0a5ff14755


In [7]:
study.optimize(
    lambda trial: objective(trial,
                            input_dim=dh.input_dim,
                            depth=3,
                            dataloaders=cv_dataloaders),
    n_trials=100,  # Number of trials to run
    timeout=1200,   # Timeout in seconds
    n_jobs=-1,     # Use all available cores
)

[I 2025-05-10 13:42:12,602] Trial 2 pruned. 
[I 2025-05-10 13:42:13,699] Trial 3 pruned. 
[I 2025-05-10 13:42:15,335] Trial 5 pruned. 
[I 2025-05-10 13:42:16,279] Trial 6 pruned. 
[I 2025-05-10 13:42:25,681] Trial 0 pruned. 
[I 2025-05-10 13:42:35,371] Trial 7 pruned. 
[I 2025-05-10 13:42:58,439] Trial 8 pruned. 
[I 2025-05-10 13:42:59,884] Trial 9 pruned. 
[I 2025-05-10 13:43:03,079] Trial 11 pruned. 
[I 2025-05-10 13:43:43,624] Trial 14 pruned. 
[I 2025-05-10 13:43:45,970] Trial 4 pruned. 
[I 2025-05-10 13:43:49,491] Trial 16 pruned. 
[I 2025-05-10 13:44:34,215] Trial 18 pruned. 
[I 2025-05-10 13:45:07,235] Trial 15 pruned. 
[I 2025-05-10 13:45:39,771] Trial 13 pruned. 
[I 2025-05-10 13:45:50,415] Trial 17 pruned. 
[I 2025-05-10 13:45:59,701] Trial 19 pruned. 
[I 2025-05-10 13:46:45,768] Trial 22 pruned. 
[I 2025-05-10 13:47:14,620] Trial 21 pruned. 
[I 2025-05-10 13:47:51,754] Trial 25 pruned. 
[I 2025-05-10 13:47:57,107] Trial 23 pruned. 
[I 2025-05-10 13:47:59,892] Trial 26 pruned

In [29]:
dir(study)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_directions',
 '_get_trials',
 '_is_multi_objective',
 '_log_completed_trial',
 '_pop_waiting_trial_id',
 '_should_skip_enqueue',
 '_stop_flag',
 '_storage',
 '_study_id',
 '_thread_local',
 'add_trial',
 'add_trials',
 'ask',
 'best_params',
 'best_trial',
 'best_trials',
 'best_value',
 'direction',
 'directions',
 'enqueue_trial',
 'get_trials',
 'metric_names',
 'optimize',
 'pruner',
 'sampler',
 'set_metric_names',
 'set_system_attr',
 'set_user_attr',
 'stop',
 'study_name',
 'system_attrs',
 'tell',
 'trials',
 'trials_dataframe',
 'user_attrs']

In [31]:
cv_results = study.trials_dataframe()

In [33]:
cv_results

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_dropout_rate_1,params_dropout_rate_2,params_dropout_rate_3,params_learning_rate,params_n_hidden_1,params_n_hidden_2,params_n_hidden_3,params_weight_decay,system_attrs_completed_rung_0,system_attrs_completed_rung_1,system_attrs_completed_rung_2,state
0,0,0.8551,2025-05-10 13:41:26.524868,2025-05-10 13:42:25.681876,0 days 00:00:59.157008,0.4596,0.3798,0.4893,0.0064,16,120,112,0.0012,0.8551,,,PRUNED
1,1,0.8458,2025-05-10 13:41:26.528398,2025-05-10 13:48:15.860357,0 days 00:06:49.331959,0.2697,0.4742,0.1764,0.0018,112,48,128,0.0001,0.8463,0.8466,0.8463,COMPLETE
2,2,0.8503,2025-05-10 13:41:26.530250,2025-05-10 13:42:12.602229,0 days 00:00:46.071979,0.0798,0.0642,0.1680,0.0002,112,96,120,0.0000,0.8503,,,PRUNED
3,3,0.8546,2025-05-10 13:41:26.532781,2025-05-10 13:42:13.699408,0 days 00:00:47.166627,0.4531,0.0280,0.3106,0.0175,120,72,8,0.0013,0.8546,,,PRUNED
4,4,0.8473,2025-05-10 13:41:26.539049,2025-05-10 13:43:45.970324,0 days 00:02:19.431275,0.2116,0.1461,0.0880,0.0104,128,32,72,0.0000,0.8458,0.8473,,PRUNED
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
72,72,0.8433,2025-05-10 13:57:26.809194,2025-05-10 14:03:08.476634,0 days 00:05:41.667440,0.0348,0.0608,0.1154,0.0017,32,128,128,0.0001,0.8448,0.8433,0.8434,COMPLETE
73,73,0.8443,2025-05-10 13:57:53.524313,2025-05-10 14:03:06.917164,0 days 00:05:13.392851,0.2082,0.1012,0.1049,0.0019,32,104,128,0.0002,0.8457,0.8437,0.8443,PRUNED
74,74,0.8473,2025-05-10 13:58:25.992187,2025-05-10 13:59:13.968227,0 days 00:00:47.976040,0.1617,0.3485,0.1025,0.0017,40,16,128,0.0001,0.8473,,,PRUNED
75,75,0.8448,2025-05-10 13:59:13.969216,2025-05-10 14:01:28.231832,0 days 00:02:14.262616,0.2102,0.1007,0.0541,0.0028,32,32,112,0.0000,0.8455,0.8448,,PRUNED


In [39]:
df# Plot parameter importance if you have matplotlib
import optuna.visualization as vis
import matplotlib.pyplot as plt
param_importances = vis.plot_param_importances(study)
optimization_history = vis.plot_optimization_history(study)
parallel_coordinate = vis.plot_parallel_coordinate(study)

In [41]:
# plot the figures
param_importances.show()
optimization_history.show()
parallel_coordinate.show()

In [42]:
import joblib

# Save the complete study to a file
study_filename = f"mlp_optuna_study_{time.strftime('%Y%m%d_%H%M%S')}.pkl"
joblib.dump(study, study_filename)
print(f"Study saved to {study_filename}")

Study saved to mlp_optuna_study_20250510_150441.pkl


In [43]:
# Later, you can load it back
loaded_study = joblib.load(study_filename)

In [44]:
import optuna
# First, set up the ASHA pruner
asha_pruner = optuna.pruners.SuccessiveHalvingPruner(
    min_resource=20,        # Minimum number of steps before pruning
    reduction_factor=2,    # Reduction factor for successive halving
    min_early_stopping_rate=0
)

In [45]:
study_mlp0 = optuna.create_study(direction="minimize", pruner=asha_pruner)

[I 2025-05-10 15:18:47,761] A new study created in memory with name: no-name-78db8b76-9125-4860-b0ff-bfecfd29d304


In [46]:
study_mlp0.optimize(
    lambda trial: objective(trial,
                            input_dim=dh.input_dim,
                            depth=0,
                            dataloaders=cv_dataloaders),
    n_trials=30,  # Number of trials to run
    timeout=600,   # Timeout in seconds
    n_jobs=-1,     # Use all available cores
)

[I 2025-05-10 15:19:44,594] Trial 0 pruned. 
[I 2025-05-10 15:19:46,948] Trial 6 pruned. 
[I 2025-05-10 15:19:56,699] Trial 4 pruned. 
[I 2025-05-10 15:20:20,745] Trial 3 pruned. 
[I 2025-05-10 15:20:31,092] Trial 8 pruned. 
[I 2025-05-10 15:20:40,195] Trial 7 pruned. 
[I 2025-05-10 15:20:56,048] Trial 10 pruned. 
[I 2025-05-10 15:21:31,698] Trial 13 pruned. 
[I 2025-05-10 15:21:54,212] Trial 14 pruned. 
[I 2025-05-10 15:22:27,438] Trial 2 finished with value: 0.8636572926472393 and parameters: {'learning_rate': 0.029889873696552514, 'weight_decay': 8.830553721022295e-06}. Best is trial 2 with value: 0.8636572926472393.
[I 2025-05-10 15:22:33,366] Trial 5 finished with value: 0.8487560978302588 and parameters: {'learning_rate': 0.004095416753618702, 'weight_decay': 1.325392561516055e-06}. Best is trial 5 with value: 0.8487560978302588.
[I 2025-05-10 15:23:05,379] Trial 11 pruned. 
[I 2025-05-10 15:23:11,838] Trial 17 pruned. 
[I 2025-05-10 15:23:26,516] Trial 1 pruned. 
[I 2025-05-10 1

In [47]:
study_mlp0.trials_dataframe()

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_learning_rate,params_weight_decay,system_attrs_completed_rung_0,system_attrs_completed_rung_1,system_attrs_completed_rung_2,state
0,0,0.877499,2025-05-10 15:18:57.321908,2025-05-10 15:19:44.593940,0 days 00:00:47.272032,0.001836,1.9e-05,0.877499,,,PRUNED
1,1,0.850591,2025-05-10 15:18:57.333704,2025-05-10 15:23:26.516318,0 days 00:04:29.182614,0.019597,0.000772,0.859103,0.850641,0.850591,PRUNED
2,2,0.863657,2025-05-10 15:18:57.335416,2025-05-10 15:22:27.438276,0 days 00:03:30.102860,0.02989,9e-06,0.85964,0.867064,0.857138,COMPLETE
3,3,0.992468,2025-05-10 15:18:57.336431,2025-05-10 15:20:20.745331,0 days 00:01:23.408900,0.000267,0.005868,0.992468,,,PRUNED
4,4,1.105934,2025-05-10 15:18:57.337485,2025-05-10 15:19:56.699140,0 days 00:00:59.361655,0.095726,0.000589,1.105934,,,PRUNED
5,5,0.848756,2025-05-10 15:18:57.338264,2025-05-10 15:22:33.366451,0 days 00:03:36.028187,0.004095,1e-06,0.851878,0.848884,0.848719,COMPLETE
6,6,0.879264,2025-05-10 15:18:57.344447,2025-05-10 15:19:46.948447,0 days 00:00:49.604000,0.001905,2.3e-05,0.879264,,,PRUNED
7,7,0.862094,2025-05-10 15:18:57.345479,2025-05-10 15:20:40.195732,0 days 00:01:42.850253,0.010804,2e-06,0.852488,0.862094,,PRUNED
8,8,1.049048,2025-05-10 15:19:44.596212,2025-05-10 15:20:31.092185,0 days 00:00:46.495973,0.000104,2e-06,1.049048,,,PRUNED
9,9,0.847627,2025-05-10 15:19:46.949785,2025-05-10 15:23:38.316152,0 days 00:03:51.366367,0.004467,0.004788,0.85071,0.84779,0.847061,COMPLETE


In [48]:
# Save the complete study to a file
study_filename = f"mlp0_optuna_study_{time.strftime('%Y%m%d_%H%M%S')}.pkl"
joblib.dump(study, study_filename)
print(f"Study saved to {study_filename}")

Study saved to mlp0_optuna_study_20250510_155610.pkl


In [None]:
# First, set up the ASHA pruner
asha_pruner = optuna.pruners.SuccessiveHalvingPruner(
    min_resource=20,        # Minimum number of steps before pruning
    reduction_factor=2,    # Reduction factor for successive halving
    min_early_stopping_rate=0
)

In [None]:
study_mlp3 = optuna.create_study(direction="minimize", pruner=asha_pruner)
study_mlp3.optimize(
    lambda trial: objective(trial,
                            input_dim=dh.input_dim,
                            depth=3,
                            dataloaders=dh.get_nn_data('cv', batch_size=512),
                            max_epochs=100),
    n_trials=30,  # Number of trials to run
    timeout=600,   # Timeout in seconds
    n_jobs=-1,     # Use all available cores
)