In [1]:
import os

import optuna
from optuna.trial import TrialState
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms

In [4]:
# Run configuration for the Optuna search.
# The original run used the second GPU (cuda:1). Prefer it when present, but fall
# back to the first GPU or to CPU so the notebook still runs on machines without
# two GPUs (a hard-coded "cuda:1" fails at `.to(DEVICE)` there).
if torch.cuda.device_count() > 1:
    DEVICE = torch.device("cuda:1")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
BATCHSIZE = 128
CLASSES = 10  # size of the model's output layer (FashionMNIST label count)
DIR = os.getcwd()  # where FashionMNIST is downloaded / cached
EPOCHS = 10
# Cap the samples seen per epoch so each trial stays fast.
N_TRAIN_EXAMPLES = BATCHSIZE * 100
N_VALID_EXAMPLES = BATCHSIZE * 10

In [5]:
def define_model(trial):
    """Build a trial-specific MLP classifier over flattened 28x28 inputs.

    Hyperparameters suggested on ``trial`` (in this order):
      * ``n_layers`` -- integer in [1, 3];
      * for each hidden layer i: ``n_units_l{i}`` (integer in [4, 128]) and
        ``dropout_l{i}`` (float in [0.2, 0.5]).

    Architecture: BatchNorm1d over the 784 input features, then per hidden
    layer Linear -> ReLU -> Dropout, then Linear(..., CLASSES) and
    LogSoftmax over dim 1 (log-probabilities, suitable for ``F.nll_loss``).

    Args:
        trial: Optuna trial used to draw the hyperparameters above.

    Returns:
        nn.Sequential: the assembled, untrained model.
    """
    depth = trial.suggest_int("n_layers", 1, 3)

    width = 28 * 28
    modules = [nn.BatchNorm1d(width)]

    for idx in range(depth):
        # Suggestion order (units, then dropout) is kept stable on purpose:
        # it determines the order in which the sampler draws parameters.
        hidden = trial.suggest_int(f"n_units_l{idx}", 4, 128)
        modules += [nn.Linear(width, hidden), nn.ReLU()]

        drop = trial.suggest_float(f"dropout_l{idx}", 0.2, 0.5)
        modules.append(nn.Dropout(drop))

        width = hidden

    modules += [nn.Linear(width, CLASSES), nn.LogSoftmax(dim=1)]
    return nn.Sequential(*modules)


In [6]:
def get_data():
    """Return ``(train_loader, valid_loader)`` for FashionMNIST.

    Both splits are stored under ``DIR``, converted with ``ToTensor`` and
    batched with ``BATCHSIZE``.

    Only the training loader is shuffled. ``objective`` scores just the first
    ``N_VALID_EXAMPLES`` validation samples, so a shuffled validation loader
    would evaluate a different random subset every epoch and every trial. That
    adds noise to the pruner's comparisons and to the reported best value. A
    fixed order keeps the validation subset identical across epochs and trials.

    Returns:
        tuple[DataLoader, DataLoader]: training and validation loaders.
    """
    to_tensor = transforms.ToTensor()

    train_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(DIR, train=True, download=True, transform=to_tensor),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    valid_loader = torch.utils.data.DataLoader(
        # download=True so this split is fetched even if the train split was not.
        datasets.FashionMNIST(DIR, train=False, download=True, transform=to_tensor),
        batch_size=BATCHSIZE,
        shuffle=False,
    )

    return train_loader, valid_loader

In [7]:
def _train_one_epoch(model, optimizer, train_loader):
    """Train ``model`` for one epoch capped at about ``N_TRAIN_EXAMPLES`` samples (NLL loss)."""
    model.train()
    seen = 0
    for data, target in train_loader:
        # Limiting training data for faster epochs (counted in samples, so the cap
        # does not depend on the loader's batch size matching BATCHSIZE).
        if seen >= N_TRAIN_EXAMPLES:
            break
        # Flatten each image into a vector for the MLP's first layer.
        data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        seen += target.size(0)


def _evaluate(model, valid_loader):
    """Return accuracy on the first ~``N_VALID_EXAMPLES`` validation samples.

    The denominator is the number of samples actually scored, so the result is
    correct when ``N_VALID_EXAMPLES`` is not a multiple of the batch size, when
    the last batch is partial, or when the dataset is smaller than the cap.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in valid_loader:
            # Limiting validation data.
            if seen >= N_VALID_EXAMPLES:
                break
            data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
            # Predicted class = index of the max log-probability.
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            seen += target.size(0)

    if seen == 0:
        raise ValueError("validation loader produced no samples")
    return correct / seen


def objective(trial):
    """Optuna objective: train a sampled MLP on FashionMNIST and return its accuracy.

    Samples the architecture (via ``define_model``), the optimizer ("Adam" or
    "SGD") and a log-uniform learning rate in [1e-5, 1e-1]. It then trains for
    ``EPOCHS`` capped epochs, reporting validation accuracy after each one so
    the pruner can stop unpromising trials.

    Args:
        trial: the Optuna trial being evaluated.

    Returns:
        float: validation accuracy after the final epoch.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial early.
        ValueError: if ``EPOCHS`` < 1 (there would be no accuracy to return).
    """
    if EPOCHS < 1:
        raise ValueError(f"EPOCHS must be >= 1, got {EPOCHS}")

    # Generate the model.
    model = define_model(trial).to(DEVICE)

    # Generate the optimizer.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

    # Get the FashionMNIST dataset.
    train_loader, valid_loader = get_data()

    accuracy = None
    for epoch in range(EPOCHS):
        _train_one_epoch(model, optimizer, train_loader)
        accuracy = _evaluate(model, valid_loader)

        trial.report(accuracy, epoch)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return accuracy

In [10]:
# Run the search, maximizing validation accuracy. optimize() stops after
# 100 trials or 20 minutes, whichever comes first.
# NOTE(review): the saved output below reports 151 finished trials, which a
# single optimize(n_trials=100) call on a fresh study cannot produce. The output
# is stale relative to this code; re-run the cell before relying on it.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100, timeout=60 * 20)

# Group trials by final state; deepcopy=False avoids copying the trial objects.
pruned_trials, complete_trials = (
    study.get_trials(deepcopy=False, states=[state])
    for state in (TrialState.PRUNED, TrialState.COMPLETE)
)

print("Study statistics: ")
for label, count in (
    ("  Number of finished trials: ", len(study.trials)),
    ("  Number of pruned trials: ", len(pruned_trials)),
    ("  Number of complete trials: ", len(complete_trials)),
):
    print(label, count)

print("Best trial:")
trial = study.best_trial
print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

[32m[I 2023-04-09 18:02:21,583][0m Trial 51 finished with value: 0.8640625 and parameters: {'n_layers': 1, 'n_units_l0': 118, 'dropout_l0': 0.37399922345646414, 'optimizer': 'Adam', 'lr': 0.002372997806370044}. Best is trial 31 with value: 0.88203125.[0m
[32m[I 2023-04-09 18:02:24,541][0m Trial 52 pruned. [0m
[32m[I 2023-04-09 18:02:47,406][0m Trial 53 finished with value: 0.84296875 and parameters: {'n_layers': 1, 'n_units_l0': 113, 'dropout_l0': 0.4276668579818117, 'optimizer': 'Adam', 'lr': 0.004334417949556404}. Best is trial 31 with value: 0.88203125.[0m
[32m[I 2023-04-09 18:02:54,459][0m Trial 54 pruned. [0m
[32m[I 2023-04-09 18:02:56,746][0m Trial 55 pruned. [0m
[32m[I 2023-04-09 18:03:18,035][0m Trial 56 finished with value: 0.8625 and parameters: {'n_layers': 1, 'n_units_l0': 114, 'dropout_l0': 0.3468592853210527, 'optimizer': 'Adam', 'lr': 0.0024914252707798784}. Best is trial 31 with value: 0.88203125.[0m
[32m[I 2023-04-09 18:03:21,038][0m Trial 57 pruned

[32m[I 2023-04-09 18:12:31,379][0m Trial 129 pruned. [0m
[32m[I 2023-04-09 18:12:52,780][0m Trial 130 finished with value: 0.8453125 and parameters: {'n_layers': 1, 'n_units_l0': 71, 'dropout_l0': 0.32989641895627525, 'optimizer': 'Adam', 'lr': 0.002038621478892737}. Best is trial 31 with value: 0.88203125.[0m
[32m[I 2023-04-09 18:13:02,692][0m Trial 131 pruned. [0m
[32m[I 2023-04-09 18:13:05,353][0m Trial 132 pruned. [0m
[32m[I 2023-04-09 18:13:10,552][0m Trial 133 pruned. [0m
[32m[I 2023-04-09 18:13:14,016][0m Trial 134 pruned. [0m
[32m[I 2023-04-09 18:13:36,867][0m Trial 135 finished with value: 0.8609375 and parameters: {'n_layers': 1, 'n_units_l0': 114, 'dropout_l0': 0.263111838140669, 'optimizer': 'Adam', 'lr': 0.0022805903203947423}. Best is trial 31 with value: 0.88203125.[0m
[32m[I 2023-04-09 18:13:38,978][0m Trial 136 pruned. [0m
[32m[I 2023-04-09 18:13:41,877][0m Trial 137 pruned. [0m
[32m[I 2023-04-09 18:14:03,494][0m Trial 138 finished with val

Study statistics: 
  Number of finished trials:  151
  Number of pruned trials:  105
  Number of complete trials:  46
Best trial:
  Value:  0.88203125
  Params: 
    n_layers: 1
    n_units_l0: 107
    dropout_l0: 0.45003266637297173
    optimizer: Adam
    lr: 0.0009297941580693064


In [11]:
# One row per trial (complete and pruned); hyperparameters appear as params_* columns.
trials_df = study.trials_dataframe()
trials_df

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_dropout_l0,params_dropout_l1,params_dropout_l2,params_lr,params_n_layers,params_n_units_l0,params_n_units_l1,params_n_units_l2,params_optimizer,state
0,0,0.832031,2023-04-09 17:34:57.750156,2023-04-09 17:35:26.840460,0 days 00:00:29.090304,0.219981,,,0.000447,1,36,,,Adam,COMPLETE
1,1,0.368750,2023-04-09 17:35:26.842764,2023-04-09 17:35:53.200163,0 days 00:00:26.357399,0.267891,0.499541,0.364719,0.002556,3,39,23.0,26.0,SGD,COMPLETE
2,2,0.846875,2023-04-09 17:35:53.203569,2023-04-09 17:36:15.504321,0 days 00:00:22.300752,0.468623,,,0.003582,1,53,,,Adam,COMPLETE
3,3,0.526563,2023-04-09 17:36:15.506507,2023-04-09 17:36:37.021371,0 days 00:00:21.514864,0.475179,0.471153,,0.001212,2,20,62.0,,SGD,COMPLETE
4,4,0.455469,2023-04-09 17:36:37.023656,2023-04-09 17:37:02.944402,0 days 00:00:25.920746,0.422846,,,0.000143,1,108,,,SGD,COMPLETE
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
146,146,0.511719,2023-04-09 18:15:11.401664,2023-04-09 18:15:13.857193,0 days 00:00:02.455529,0.265558,,,0.002197,1,120,,,SGD,PRUNED
147,147,0.826562,2023-04-09 18:15:13.859122,2023-04-09 18:15:23.725454,0 days 00:00:09.866332,0.270145,,,0.003728,1,116,,,Adam,PRUNED
148,148,0.828906,2023-04-09 18:15:23.727733,2023-04-09 18:15:29.050372,0 days 00:00:05.322639,0.242689,,,0.005053,1,123,,,Adam,PRUNED
149,149,0.878906,2023-04-09 18:15:29.052308,2023-04-09 18:15:53.309852,0 days 00:00:24.257544,0.255844,,,0.002631,1,125,,,Adam,COMPLETE
