In [1]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import optuna
import os
import json
import copy
import joblib
import pandas as pd
import numpy as np
import time
from typing import Dict, Any, Optional, Union, List, Tuple, Type

from constants import DEVICE
from data_handling import DataHandler
from constants import RESULTS_DIR, MODELS_DIR, PREDS_DIR

Using MPS device (Apple Silicon GPU)


In [11]:
# Load the saved `mlp3` Optuna study from disk.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code —
# only load study files that you produced yourself.
study = joblib.load('mlp_studies/mlp3_optuna_study_20250511_234516.pkl')

In [12]:
# Collect every intermediate value the best trial reported, in recording order.
trial = study.best_trial
intermediate_vals = [*trial.intermediate_values.values()]
# argmin is a 0-based position in that list; adding one turns it into a 1-based count
# (assumes the values were reported once per epoch starting at step 0 — TODO confirm
# against the objective that produced this study).
best_epoch = 1 + int(np.asarray(intermediate_vals).argmin())
best_epoch

130

In [14]:
# Number of intermediate values recorded for the best trial.
len(intermediate_vals)

200

In [10]:
# NOTE(review): same expression as the previous cell, but executed earlier (In [10], before
# the study was loaded in In [11]); the output shown for it (128) therefore does not describe
# the currently loaded study. Consider deleting this stale cell.
len(intermediate_vals)

128

In [9]:
# NOTE(review): this dumps the whole list and was executed as In [9], i.e. before the study
# was loaded in In [11]; the output below is from that earlier state. Consider showing only
# a slice or removing the cell.
intermediate_vals

[0.5598997306823731,
 0.557727247873942,
 0.5549625269571941,
 0.552809648513794,
 0.5535721333821615,
 0.5533955510457357,
 0.5520859146118164,
 0.551773370107015,
 0.5509164428710938,
 0.5495435333251953,
 0.5500985590616861,
 0.5497001520792644,
 0.5489520359039307,
 0.5491504192352294,
 0.5480689811706543,
 0.5491851870218912,
 0.5484400240580241,
 0.5493580691019694,
 0.5482187493642171,
 0.5487093194325765,
 0.5479621601104736,
 0.5476616986592611,
 0.5483583990732829,
 0.5488706843058269,
 0.5472586949666342,
 0.5479297637939453,
 0.5474513085683187,
 0.5474869696299235,
 0.5476428731282552,
 0.5471930821736654,
 0.5474745146433513,
 0.5473065821329752,
 0.5475208473205566,
 0.5470022296905518,
 0.5481533145904541,
 0.5467185052235921,
 0.547602513631185,
 0.5469236914316813,
 0.547269385655721,
 0.5467228762308757,
 0.5465819676717122,
 0.5471183109283447,
 0.5472512022654216,
 0.5467758433024088,
 0.547442553838094,
 0.5467590522766114,
 0.5474474175771078,
 0.5468031883239746

In [None]:
def weighted_cross_entropy_loss(outputs: torch.Tensor,
                                targets: torch.Tensor,
                                weights: torch.Tensor):
    """
    Weighted soft-label cross-entropy, summed over folds.

    Works on ordinary batches (2D tensors) and on fold-batched inputs (3D tensors);
    a 2D input is handled as a single fold.

    For every sample:
    1. The predicted probabilities are multiplied by the sum of that sample's target
       probabilities (called P(18plus) in this project), because the targets are soft
       labels that cover only a subset of classes. The result is clamped to
       [1e-10, 1 - 1e-10] so the logarithm stays finite.
    2. The cross-entropy is  CE = - sum_k ( target_k * log(scaled_output_k) ).

    Inside each fold the sample losses are combined into a weighted average using
    `weights`. The per-fold averages are then ADDED together (not averaged), so a
    caller that wants a mean over folds must divide the result by the number of folds.

    Args:
        outputs (torch.Tensor): Model predictions (probabilities).
                                Shape: [batch_size, num_classes] or [num_folds, batch_size_per_fold, num_classes].
        targets (torch.Tensor): Ground truth probabilities, same shape as `outputs`.
        weights (torch.Tensor): Sample weights ('P(C)').
                                Shape: [batch_size, 1] or [num_folds, batch_size_per_fold, 1].

    Returns:
        torch.Tensor: Scalar tensor holding the sum over folds of the weighted-average loss.
    """
    # Promote a plain batch to a one-fold batch so a single code path serves both layouts.
    if outputs.ndim == 2:
        outputs, targets, weights = outputs[None], targets[None], weights[None]

    # Rescale predictions by the per-sample target mass, then keep them away from 0 and 1.
    target_mass = targets.sum(dim=2, keepdim=True)            # (K, B, 1)
    scaled_probs = (outputs * target_mass).clamp(1e-10, 1. - 1e-10)

    # Per-sample cross-entropy and matching weights, both (K, B, 1).
    per_sample_ce = -(targets * scaled_probs.log()).sum(dim=2, keepdim=True)
    sample_weights = weights.view_as(per_sample_ce)

    # Weighted average inside each fold, shape (K, 1, 1).
    fold_numerator = (per_sample_ce * sample_weights).sum(dim=1, keepdim=True)
    fold_denominator = sample_weights.sum(dim=1, keepdim=True)
    fold_losses = fold_numerator / fold_denominator

    return fold_losses.sum()

class BatchedLinear(nn.Module):
    """
    Batched version of `nn.Linear`; handles `K`-fold batched inputs. It performs `K` independent
    linear transformations (one per fold) using batched matrix multiplication (`torch.bmm`).

    Weights shape: `(K, out_features, in_features)`.
    Biases shape: `(K, out_features)`.
    """
    def __init__(self, K: int, in_features: int, out_features: int, bias: bool = True):
        """Create K independent linear maps that share one parameter tensor per kind.

        Args:
            K: Number of parallel (per-fold) linear layers.
            in_features: Size of the last input dimension.
            out_features: Size of the last output dimension.
            bias: When False, no bias is created and `self.bias` is None.
        """
        super().__init__()
        self.K = K
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty is the supported way to allocate uninitialised storage (the
        # torch.Tensor(...) constructor is legacy); real values come from reset_parameters().
        self.weight = nn.Parameter(torch.empty(K, out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(K, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize (fold-by-fold) the weights using Kaiming uniform and the biases within calculated bounds."""
        # Each fold's weight slice is (out_features, in_features), so its fan-in is simply
        # in_features; no need for the private init._calculate_fan_in_and_fan_out helper.
        fan_in = self.in_features
        bound = 1 / (fan_in**0.5) if fan_in > 0 else 0
        # Slices are filled in place; no_grad keeps these writes out of autograd.
        with torch.no_grad():
            for k in range(self.K):
                init.kaiming_uniform_(self.weight[k], a=0, mode='fan_in', nonlinearity='leaky_relu')
                if self.bias is not None:
                    init.uniform_(self.bias[k], -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply `K` parallel transformations to `K` batches.
        Input shape: `(K, batch_size, in_features)`.
        Output shape: `(K, batch_size, out_features)`.
        """
        # (K, B, in) @ (K, in, out) -> (K, B, out)
        output = torch.bmm(x, self.weight.transpose(1, 2))
        if self.bias is not None:
            # (K, out) -> (K, 1, out) so the bias broadcasts over the batch axis.
            output = output + self.bias.unsqueeze(1)
        return output

    def extra_repr(self) -> str:
        """Return string representation of layer parameters."""
        return f'K={self.K}, in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'

   
def build_network(input_dim: int,
                  depth: int,
                  hparams: Dict[str, Any],
                  K: int = 3,
                  num_classes: int = 4) -> nn.Sequential:
    """
    Assemble a fold-batched MLP as an `nn.Sequential`.

    Each hidden block is `BatchedLinear -> ReLU`, followed by `Dropout` only when the
    configured rate is greater than zero. The head is a `BatchedLinear` to `num_classes`
    followed by a softmax over dimension 2.

    Args:
        input_dim: Number of input features.
        depth: Number of hidden blocks; 0 yields just the linear head plus softmax.
        hparams: Must contain `n_hidden_{i}` for i in 1..depth; `dropout_rate_{i}` is
                 optional and defaults to 0.0.
        K: Number of parallel folds handled by every `BatchedLinear`.
        num_classes: Size of the output layer.

    Returns:
        The assembled `nn.Sequential` model.
    """
    modules = []
    width_in = input_dim

    for layer_idx in range(1, depth + 1):
        width_out = hparams[f"n_hidden_{layer_idx}"]
        modules += [BatchedLinear(K, width_in, width_out), nn.ReLU()]

        p_drop = hparams.get(f"dropout_rate_{layer_idx}", 0.0)
        if p_drop > 0.0:
            modules.append(nn.Dropout(p_drop))

        width_in = width_out

    # Classification head: per-fold logits turned into probabilities along dim 2.
    modules += [BatchedLinear(K, width_in, num_classes), nn.Softmax(dim=2)]

    return nn.Sequential(*modules)
    

def _stack_fold_batches(fold_batches):
    """Stack one (features, targets, weights) batch per fold along a new leading fold axis and move each stack to DEVICE."""
    features = torch.stack([fold_batch[0] for fold_batch in fold_batches], dim=0).to(DEVICE)
    targets = torch.stack([fold_batch[1] for fold_batch in fold_batches], dim=0).to(DEVICE)
    weights = torch.stack([fold_batch[2] for fold_batch in fold_batches], dim=0).to(DEVICE)
    return features, targets, weights


def objective_mlp(trial: optuna.trial.Trial, 
                  input_dim: int,
                  depth: int, 
                  dataloaders: List[Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]],
                  max_epochs: int = 100) -> float:
    """
    Optuna objective function for BatchedMLP hyperparameter optimization.

    All folds are trained simultaneously by one fold-batched network: batch i of every
    fold's loader is stacked into a single tensor with a leading fold axis. The number
    of folds is taken from `len(dataloaders)`.

    Args:
        trial: Current Optuna trial
        input_dim: Set this as `dh.input_dim`
        depth: Number of hidden layers
        dataloaders: List of `(train_loader, val_loader)` tuples, one per fold. Set this as
                     `dh.get_nn_data('cv', batch_size)`. Batches yielded at the same position
                     by different folds must have identical shapes, because they are stacked.
        max_epochs: Maximum training epochs

    Returns:
        Best (lowest) per-epoch validation loss achieved

    Raises:
        ValueError: If `dataloaders` is empty or the validation loaders yield no batches.
        optuna.TrialPruned: If the pruner decides to stop the trial.
    """
    num_folds = len(dataloaders)
    if num_folds == 0:
        raise ValueError("dataloaders must contain at least one (train_loader, val_loader) pair")

    params = {}
    params['learning_rate'] = trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True)
    params['weight_decay'] = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    for i in range(1, depth + 1):
        params[f'n_hidden_{i}'] = trial.suggest_int(f'n_hidden_{i}', 8, 128, step=8)
        params[f'dropout_rate_{i}'] = trial.suggest_float(f'dropout_rate_{i}', 0.0, 0.5, log=False)

    # One parallel sub-network per fold, so K has to match the number of loaders supplied.
    model = build_network(input_dim, depth, params, K=num_folds).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(),
                             lr=params['learning_rate'],
                             weight_decay=params['weight_decay'])

    train_loader_list = [dl_pair[0] for dl_pair in dataloaders]
    val_loader_list = [dl_pair[1] for dl_pair in dataloaders]

    best_epoch_loss = float('inf')
    for epoch in range(max_epochs):
        model.train()
        for batched_train_data in zip(*train_loader_list):
            stacked_features, stacked_targets, stacked_weights = _stack_fold_batches(batched_train_data)

            optimizer.zero_grad()
            outputs = model(stacked_features)
            loss = weighted_cross_entropy_loss(outputs, stacked_targets, stacked_weights)
            loss.backward()
            optimizer.step()

        model.eval()
        epoch_loss = 0.0
        # Count the batches actually consumed: zip() stops at the shortest loader, so the
        # length of the first loader is not a reliable denominator.
        num_val_batches = 0
        with torch.no_grad():
            for batched_val_data in zip(*val_loader_list):
                stacked_features, stacked_targets, stacked_weights = _stack_fold_batches(batched_val_data)

                outputs = model(stacked_features)
                loss = weighted_cross_entropy_loss(outputs, stacked_targets, stacked_weights)
                epoch_loss += loss.item()
                num_val_batches += 1

        if num_val_batches == 0:
            raise ValueError("validation loaders yielded no batches")

        # The loss is a sum over folds, so divide by folds * batches to obtain the
        # average per-fold, per-batch validation loss.
        epoch_loss = epoch_loss / (num_folds * num_val_batches)
        best_epoch_loss = min(best_epoch_loss, epoch_loss)

        # Report with a 0-based step so the pruner can compare trials epoch by epoch.
        trial.report(epoch_loss, step=epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return best_epoch_loss

In [16]:
# Project-local data handler (imported from data_handling); the cells below read
# `dh.input_dim` and call `dh.get_nn_data(...)` on it.
dh = DataHandler()

DataHandler initialized - Using 114 features - Test year: 2020


In [17]:
# Successive-halving (ASHA) pruner for the depth-3 search.
asha_pruner = optuna.pruners.SuccessiveHalvingPruner(
    min_early_stopping_rate=0,
    reduction_factor=2,   # reduction factor for successive halving
    min_resource=30,      # minimum number of steps before pruning
)

# In-memory study; lower objective values are better.
study_mlp3 = optuna.create_study(pruner=asha_pruner, direction="minimize")

[I 2025-05-11 13:48:19,306] A new study created in memory with name: no-name-d0ca8efc-daec-48ba-8e1c-c2a704a35f4a


In [18]:
def _mlp3_objective(trial):
    """Objective for the depth-3 search; fresh CV data loaders are requested for every trial."""
    return objective_mlp(
        trial,
        input_dim=dh.input_dim,
        depth=3,
        dataloaders=dh.get_nn_data('cv', batch_size=256),
        max_epochs=128,
    )


# Stop after 64 trials or one hour, whichever comes first; n_jobs=-1 uses all available cores.
study_mlp3.optimize(_mlp3_objective, n_trials=64, timeout=3600, n_jobs=-1)

[I 2025-05-11 13:50:24,282] Trial 6 pruned. 
[I 2025-05-11 13:50:29,396] Trial 2 pruned. 
[I 2025-05-11 13:50:38,079] Trial 1 pruned. 
[I 2025-05-11 13:51:24,416] Trial 7 pruned. 
[I 2025-05-11 13:52:41,795] Trial 9 pruned. 
[I 2025-05-11 13:52:54,122] Trial 10 pruned. 
[I 2025-05-11 13:53:15,842] Trial 3 pruned. 
[I 2025-05-11 13:55:03,462] Trial 0 pruned. 
[I 2025-05-11 13:55:09,995] Trial 13 pruned. 
[I 2025-05-11 13:57:06,376] Trial 12 pruned. 
[I 2025-05-11 13:57:15,351] Trial 5 finished with value: 0.551027660369873 and parameters: {'learning_rate': 0.025117968808891406, 'weight_decay': 0.004281667872952683, 'n_hidden_1': 64, 'dropout_rate_1': 0.47056183175606797, 'n_hidden_2': 40, 'dropout_rate_2': 0.35572650549353874, 'n_hidden_3': 16, 'dropout_rate_3': 0.1556395922113895}. Best is trial 5 with value: 0.551027660369873.
[I 2025-05-11 13:59:36,494] Trial 8 finished with value: 0.5464740371704102 and parameters: {'learning_rate': 0.0025450777009106197, 'weight_decay': 3.799302204

In [19]:
# Persist the whole study object; the timestamp keeps file names from different runs apart.
timestamp = time.strftime('%Y%m%d_%H%M%S')
study_filename = f"mlp3_optuna_study_{timestamp}.pkl"
joblib.dump(study_mlp3, study_filename)
print(f"Study saved to {study_filename}")

Study saved to mlp3_optuna_study_20250511_145318.pkl


In [None]:
# Successive-halving (ASHA) pruner for the depth-0 search.
asha_pruner = optuna.pruners.SuccessiveHalvingPruner(
    min_early_stopping_rate=0,
    reduction_factor=2,   # reduction factor for successive halving
    min_resource=20,      # minimum number of steps before pruning
)
study_mlp0 = optuna.create_study(pruner=asha_pruner, direction="minimize")


def _mlp0_objective(trial):
    """Objective for the depth-0 search; fresh CV data loaders are requested for every trial."""
    return objective_mlp(
        trial,
        input_dim=dh.input_dim,
        depth=0,
        dataloaders=dh.get_nn_data('cv', batch_size=256),
        max_epochs=80,
    )


# Stop after 32 trials or 15 minutes, whichever comes first; n_jobs=-1 uses all available cores.
study_mlp0.optimize(_mlp0_objective, n_trials=32, timeout=900, n_jobs=-1)

[I 2025-05-11 18:17:41,303] A new study created in memory with name: no-name-db62d198-214e-4eb1-8315-aba86bdad98b
[I 2025-05-11 18:18:34,057] Trial 2 pruned. 
[I 2025-05-11 18:18:35,780] Trial 0 pruned. 
[I 2025-05-11 18:18:36,162] Trial 7 pruned. 
[I 2025-05-11 18:18:37,505] Trial 1 pruned. 
[I 2025-05-11 18:18:39,308] Trial 6 pruned. 
[I 2025-05-11 18:18:41,788] Trial 3 pruned. 
[I 2025-05-11 18:19:25,967] Trial 8 pruned. 
[I 2025-05-11 18:19:31,493] Trial 10 pruned. 
[I 2025-05-11 18:19:33,528] Trial 11 pruned. 
[I 2025-05-11 18:19:35,354] Trial 12 pruned. 
[I 2025-05-11 18:19:40,972] Trial 13 pruned. 
[I 2025-05-11 18:20:20,288] Trial 14 pruned. 
[I 2025-05-11 18:20:25,454] Trial 9 pruned. 
[I 2025-05-11 18:21:03,168] Trial 4 finished with value: 0.5646484025319417 and parameters: {'learning_rate': 0.06595787221952261, 'weight_decay': 0.0017123065515180812}. Best is trial 4 with value: 0.5646484025319417.
[I 2025-05-11 18:21:04,709] Trial 5 finished with value: 0.5484834575653076 a

In [None]:
# Save the complete study to a file.
# The file name carries a timestamp so results from different runs are kept apart.
# (The previous f-string was garbled — the literal was pasted into itself — and did not parse.)
study_filename = f"mlp0_optuna_study_{time.strftime('%Y%m%d_%H%M%S')}.pkl"
joblib.dump(study_mlp0, study_filename)
print(f"Study saved to {study_filename}")