### Import the relevant modules

In [None]:
import torch
from torch.optim import Adam
import optuna
from src.utils import(
    DEVICE,
    MNIST_Loader,
    train_loop
)

from src.models import (
    My_NN_Classifier
)



### Define the global variables

In [2]:
# Training configuration shared by every Optuna trial.
# batch_size is used for both the training and the validation loader;
# num_epochs is the number of epochs each trial trains for.
batch_size, num_epochs = 300, 15

# Loss function handed to train_loop for every trial.
criterion = torch.nn.CrossEntropyLoss()

### Define the objective function evaluated by the Optuna study

In [3]:

def objective(trial: optuna.trial.Trial) -> tuple[float, float]:
    """
    Objective function for hyperparameter optimization using Optuna.

    Builds a classifier from the hyperparameters suggested by ``trial``,
    trains it with ``train_loop`` and reports the two resulting metrics.

    Args:
        trial: Optuna trial object used to suggest hyperparameter values

    Returns:
        tuple: The two values returned by ``train_loop`` as Python floats,
        in the order (accuracy, f1). The order must match the study's
        ``directions`` and metric names.

    Hyperparameters optimized:
        - Learning rate (lr): Categorical choice of [0.01, 0.001, 0.0001]
        - Hidden layer size (middle_chan): Integer in range [50,150] with step=50
        - Network depth (depth): Integer in range [1,3] with step=1
    """
    # NOTE: ideally the learning rate would be searched on a continuous
    # log scale, i.e. trial.suggest_float("lr", 1e-5, 1e-2, log=True);
    # a small categorical grid is used here instead.
    lr = trial.suggest_categorical("lr", [0.01, 0.001, 0.0001])
    hidden_size = trial.suggest_int("middle_chan", 50, 150, step=50)
    depth = trial.suggest_int("depth", 1, 3, step=1)

    # Input size 28**2 corresponds to a flattened 28x28 image; 10 output classes.
    model = My_NN_Classifier(
        in_chan=28**2,
        middle_chan=hidden_size,
        num_classes=10,
        depth=depth,
    ).to(DEVICE)
    optimizer = Adam(model.parameters(), lr=lr)

    # The same module-level batch size is used for training and validation.
    mnist_loader = MNIST_Loader(
        root_dir="./data",
        batch_size_train=batch_size,
        batch_size_val=batch_size,
        num_workers=2,
    )

    best_accuracy, best_f1 = train_loop(
        model=model,
        optimizer=optimizer,
        loader=mnist_loader,
        num_epochs=num_epochs,
        criterion=criterion,
    )

    # Optuna stores objective values as floats; converting here keeps the
    # return type explicit whether train_loop yields Python numbers or
    # single-element tensors.
    return float(best_accuracy), float(best_f1)

### Initialize the search session

In [None]:
# Initialize the search session: create the multi-objective study, or
# reopen the existing one with the same name from the SQLite storage.
study = optuna.create_study(
    directions=["maximize", "maximize"],  # one direction per objective value
    storage="sqlite:///db.sqlite3",  # SQLite database will persist the study
    study_name="MNIST_study",
    load_if_exists=True,
    # No pruner is configured: Optuna does not support pruning for
    # multi-objective studies, and the objective never reports
    # intermediate values, so a pruner here would never take effect.
)
# Metric names follow the order of the values returned by `objective`.
study.set_metric_names(["Accuracy", "F1 Score"])  # F1 Score is useful for unbalanced datasets
# Because the study is reloaded if it exists, re-running this cell adds
# 20 more trials to the stored study.
study.optimize(objective, n_trials=20)



### Plot the optimization history

In [9]:

# The study is multi-objective, so the plotted value must be selected
# explicitly; index 0 of a trial's values is the metric named "Accuracy".
optuna.visualization.plot_optimization_history(
    study,
    target=lambda finished_trial: finished_trial.values[0],
    target_name="Accuracy",
)

### Parallel plot for visualizing the relationship between the objectives and the search parameters

In [8]:
# Relate each searched hyperparameter to the first objective value
# (index 0, the metric named "Accuracy").
optuna.visualization.plot_parallel_coordinate(
    study,
    target=lambda finished_trial: finished_trial.values[0],
    target_name="Accuracy",
)

In [7]:
# Relate each searched hyperparameter to the second objective value.
# The objective returns (accuracy, f1), so the F1 score is at index 1;
# index 0 would plot accuracy under the "F1 Score" label.
optuna.visualization.plot_parallel_coordinate(
    study,
    target=lambda t: t.values[1],
    target_name="F1 Score",
)

### Importance of the searched parameters in achieving the optimal result

In [6]:
# Plot how much each searched hyperparameter contributes to the result.
# No evaluator is passed, so Optuna's default importance evaluator is used
# (variance-based, per the original note).
optuna.visualization.plot_param_importances(study=study)


### Reload the study if the session has expired

In [5]:
# To load this study later (e.g. in a fresh kernel), reopen it by name
# from the same SQLite storage. No pruner is passed: pruning is not
# supported for multi-objective studies, so it would have no effect.
study = optuna.load_study(
    study_name="MNIST_study",
    storage="sqlite:///db.sqlite3",
)

