# Hyperparameter Fine-Tuning with Optuna

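Notebook inspired by [Hands-On Machine Learning with Scikit-Learn and PyTorch](https://www.oreilly.com/library/view/hands-on-machine-learning/9798341607972/). It uses Optuna to tune the learning rate and hidden-layer width of a two-hidden-layer MLP on Fashion-MNIST, retrains the best configuration on the combined training and validation data, evaluates it on the test set, and finally saves and reloads the trained model.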

## Image Classifier Code

In [None]:
import torch
import torchvision
import torchvision.transforms.v2 as T

# Choose the compute backend in order of preference:
# an NVIDIA CUDA GPU, then Apple's MPS backend, then the CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)


# Transform applied to every Fashion-MNIST image: convert to a tensor image,
# then cast to float32 with scale=True so integer pixel values are rescaled
# to [0, 1].
toTensor = T.Compose([T.ToImage(), T.ToDtype(torch.float32, scale = True)])

# Both splits share the same storage folder, download flag and transform.
fashion_mnist_kwargs = dict(root = 'datasets', download = True,
                            transform = toTensor)
train_and_valid_data = torchvision.datasets.FashionMNIST(train = True,
                                                         **fashion_mnist_kwargs)
test_data = torchvision.datasets.FashionMNIST(train = False,
                                              **fashion_mnist_kwargs)

# Seed the global RNG immediately before the split so the same 5_000 images
# are held out for validation on every run; the other 55_000 train the model.
torch.manual_seed(42)
train_data, valid_data = torch.utils.data.random_split(
    train_and_valid_data, [55_000, 5_000])

from torch.utils.data import DataLoader

# Only the training loader shuffles; validation and test batches are served
# in a fixed order.
train_loader, valid_loader, test_loader = (
    DataLoader(train_data, batch_size = 32, shuffle = True),
    DataLoader(valid_data, batch_size = 32),
    DataLoader(test_data, batch_size = 32),
)

from torch import nn
# custom classification MLP w/ 2 hidden layers
class ImageClassifier(nn.Module):
  """Multilayer perceptron with two ReLU hidden layers for image classification.

  Inputs are flattened and passed through
  Linear -> ReLU -> Linear -> ReLU -> Linear. The final layer has no
  activation, so forward() returns one raw logit per class.

  Args:
    n_inputs: number of values in one flattened input (e.g. 1 * 28 * 28).
    n_hidden1: width of the first hidden layer.
    n_hidden2: width of the second hidden layer.
    n_classes: number of output logits.
  """

  def __init__(self, n_inputs, n_hidden1, n_hidden2, n_classes):
    super().__init__()
    # The layer order fixes the Sequential indices (mlp.1, mlp.3, mlp.5) that
    # appear in state_dict keys, so saved checkpoints depend on it.
    layers = [nn.Flatten()]
    widths = [n_inputs, n_hidden1, n_hidden2]
    for fan_in, fan_out in zip(widths, widths[1:]):
      layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    layers.append(nn.Linear(n_hidden2, n_classes))
    self.mlp = nn.Sequential(*layers)

  def forward(self, X):
    """Return the class logits for the input batch X."""
    return self.mlp(X)

# train function: mini-batch gradient descent
def train_mbgd(model, optimizer, criterion, train_loader, n_epochs):
  """Train `model` with mini-batch gradient descent.

  Runs `n_epochs` passes over `train_loader`. Each batch is moved to the
  global `device`, scored with `criterion`, and followed by one optimizer
  step. The mean per-batch loss is printed on epochs 1, 11, 21, ...

  Args:
    model: nn.Module to train; switched to training mode before epoch 1.
    optimizer: optimizer over model's parameters.
    criterion: loss called as criterion(predictions, targets).
    train_loader: iterable of (inputs, targets) batches supporting len().
    n_epochs: number of full passes over train_loader.

  Returns:
    None.
  """
  model.train()
  for epoch in range(n_epochs):
    running_loss = 0
    for inputs, targets in train_loader:
      inputs = inputs.to(device)
      targets = targets.to(device)
      loss = criterion(model(inputs), targets)
      running_loss += loss.item()
      # backprop, update weights, then clear gradients for the next batch
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

    mean_loss = running_loss / len(train_loader)
    # report on the first epoch and every tenth one after it
    if epoch % 10 == 0:
      print(f'Epoch {epoch + 1}, Loss: {mean_loss}')

## create evaluation function
def evaluate(model, data_loader, metric, aggregate = torch.mean):
  """Compute `metric` on every batch of `data_loader` and aggregate the results.

  The model is put in eval mode and run under torch.no_grad(); each batch is
  moved to the global `device` before the forward pass.

  Args:
    model: nn.Module to evaluate.
    data_loader: iterable of (inputs, targets) batches.
    metric: callable(y_pred, y_true) returning a tensor for one batch.
    aggregate: reduction applied to the stacked per-batch values
      (default torch.mean).

  Returns:
    aggregate(torch.stack(per_batch_values)).

  Note:
    With the default torch.mean every batch counts equally, so a final batch
    smaller than the others is over-weighted relative to a per-sample
    average. For exact dataset-level accuracy, have `metric` return the
    per-batch count of correct predictions, pass aggregate=torch.sum, and
    divide by the dataset size.
  """
  model.eval()
  with torch.no_grad():
    batch_values = [
        metric(model(inputs.to(device)), targets.to(device))
        for inputs, targets in data_loader
    ]

  return aggregate(torch.stack(batch_values))

100%|██████████| 26.4M/26.4M [00:01<00:00, 16.2MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 275kB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 5.13MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 21.1MB/s]


## Setup

In [None]:
!pip install optuna

Collecting optuna
  Downloading optuna-4.6.0-py3-none-any.whl.metadata (17 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.10.1-py3-none-any.whl.metadata (11 kB)
Downloading optuna-4.6.0-py3-none-any.whl (404 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m404.7/404.7 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.10.1-py3-none-any.whl (11 kB)
Installing collected packages: colorlog, optuna
Successfully installed colorlog-6.10.1 optuna-4.6.0


In [None]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.2


In [None]:
import optuna
import torchmetrics

def objective(trial):
  """Train one candidate model and return its validation accuracy.

  Optuna calls this once per trial. It samples a learning rate and a
  hidden-layer width (shared by both hidden layers), trains an
  ImageClassifier with SGD for 100 epochs on `train_loader`, and scores it
  on `valid_loader`.

  Args:
    trial: Optuna trial object used to sample hyperparameter values.

  Returns:
    float: fraction of validation samples classified correctly.
  """
  # log=True samples uniformly in log space over [1e-5, 1e-1], so every
  # order of magnitude of the learning rate gets a comparable share of trials
  learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
  n_hidden = trial.suggest_int('n_hidden', 20, 300)

  model = ImageClassifier(n_inputs = 1 * 28 * 28, n_hidden1 = n_hidden,
                          n_hidden2 = n_hidden, n_classes = 10).to(device)

  optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

  criterion = nn.CrossEntropyLoss()

  train_mbgd(model, optimizer, criterion, train_loader, n_epochs = 100)

  # Sum correct predictions over all batches and divide by the number of
  # samples. Averaging per-batch accuracies would give a smaller final batch
  # the same weight as a full one.
  n_correct = evaluate(model, valid_loader,
                       lambda y_pred, y_true: (y_pred.argmax(dim=1)
                       == y_true).sum(),
                       aggregate = torch.sum)
  # hand Optuna a plain Python float rather than a (possibly GPU) tensor
  return (n_correct / len(valid_loader.dataset)).item()

## Hyperparameter Tuning

In [None]:
# Seed torch's global RNG so the model initialisation and shuffling done
# inside the trials repeat from run to run.
torch.manual_seed(42)

# TPE = Tree-structured Parzen Estimator, a sequential model-based
# optimisation algorithm; seeding it makes its suggestions repeatable.
sampler = optuna.samplers.TPESampler(seed=42)

# objective returns validation accuracy (higher is better), so maximize it
study = optuna.create_study(direction="maximize", sampler=sampler)
study.optimize(objective, n_trials=5)

In [None]:
# hyperparameters of the best trial found by the search
study.best_trial.params

{'learning_rate': 0.00031489116479568613, 'n_hidden': 287}

In [None]:
# validation accuracy achieved by the best trial
study.best_trial.value

0.8320063948631287

## Train Model on Full Training Data

In [None]:
# make data loader out of train_and_valid_data
train_loader = DataLoader(train_and_valid_data,
                          batch_size = 32,
                          shuffle = True)

In [None]:
new_model = ImageClassifier(n_inputs = 1 * 28 * 28, n_hidden1 = 287,
                          n_hidden2 = 287, n_classes = 10).to(device)

optimizer = torch.optim.SGD(new_model.parameters(), lr = 0.00031489116479568613)

criterion = nn.CrossEntropyLoss()

train_mbgd(new_model, optimizer, criterion, train_loader, n_epochs = 100)

Epoch 1, Loss: 2.271330679066976
Epoch 11, Loss: 0.9710857071558634
Epoch 21, Loss: 0.7233449332078298
Epoch 31, Loss: 0.6262222062905629
Epoch 41, Loss: 0.5675503201087316
Epoch 51, Loss: 0.5288081739107767
Epoch 61, Loss: 0.5019134725729625
Epoch 71, Loss: 0.48193759031295774
Epoch 81, Loss: 0.46643048729896547
Epoch 91, Loss: 0.4541492882847786


In [None]:
# Evaluate the retrained model on the test set.
# Count correct predictions per batch and divide by the dataset size. Averaging
# per-batch accuracies would over-weight a final batch that is smaller than
# batch_size. (evaluate() puts the model in eval mode itself.)
n_correct = evaluate(new_model, test_loader,
                     lambda y_pred, y_true: (y_pred.argmax(dim=1)
                     == y_true).sum(),
                     aggregate = torch.sum)
accuracy_test = (n_correct / len(test_loader.dataset)).item()

print(f'Accuracy on test set: {accuracy_test}')

Accuracy on test set: 0.8308705687522888


## Saving and Loading Models in PyTorch

In [None]:
## Save the weights together with the constructor arguments needed to rebuild
## the model. The arguments are read from the model's own Linear layers
## (mlp[1], mlp[3], mlp[5]), so they always match the saved weights instead of
## being copied by hand.
first_hidden, second_hidden, output_layer = (new_model.mlp[1],
                                             new_model.mlp[3],
                                             new_model.mlp[5])
model_data = {
    'model_state_dict': new_model.state_dict(),
    'model_hyperparameters': {
        'n_inputs': first_hidden.in_features,
        'n_hidden1': first_hidden.out_features,
        'n_hidden2': second_hidden.out_features,
        'n_classes': output_layer.out_features
    }
}

# save model
torch.save(model_data, 'model_fashion_mnist.pth')

In [None]:
# Load the checkpoint, rebuild the model from its saved hyperparameters and
# restore the trained weights.
# map_location lets a checkpoint saved on a GPU load on whatever `device` this
# session uses (including CPU-only machines). weights_only=True restricts
# unpickling to tensors and plain containers.
loaded_data = torch.load('model_fashion_mnist.pth', map_location = device,
                         weights_only = True)

# move to `device` so evaluate(), which sends batches to `device`, works on it
new_model = ImageClassifier(**loaded_data['model_hyperparameters']).to(device)
new_model.load_state_dict(loaded_data['model_state_dict'])

new_model.eval() # eval mode: ready for predictions

ImageClassifier(
  (mlp): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=287, bias=True)
    (2): ReLU()
    (3): Linear(in_features=287, out_features=287, bias=True)
    (4): ReLU()
    (5): Linear(in_features=287, out_features=10, bias=True)
  )
)