In [1]:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import optuna
from tqdm import tqdm
from dataset import get_datasets
from model import get_model, print_model_size

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

  @register_model
  @register_model
  @register_model
  @register_model
  @register_model


In [2]:
image_dims = (224, 224)

transform = transforms.Compose([
    transforms.Resize(image_dims),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4772, 0.4597, 0.4612], std=[0.2997, 0.2808, 0.2837])
])

train_dataset, val_dataset, test_dataset = get_datasets(transform)


In [4]:
n_train_examples = 30 * 64
n_val_examples = 10 * 64

criterion = nn.CrossEntropyLoss()

In [5]:
def get_dataloaders(trial):
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64])

    # Load train dataset
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Load val dataset
    valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    return train_loader, valid_loader, batch_size


In [6]:
def objective(trial):
    model = get_model().to(device)

    # Generate the optimizers.
    lr_slow = trial.suggest_float("lr_slow", 1e-5, 1e-1, log=True) # log=True, will use log scale to interplolate between
    weight_decay_slow = trial.suggest_float("weight_decay_slow", 1e-5, 1e-2)
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    optimizer_slow = getattr(torch.optim, optimizer_name)([{'params':model.stages[3].parameters()},
                                                           {'params':model.norm.parameters()}], lr=lr_slow , weight_decay=weight_decay_slow)
    lr_fast = trial.suggest_float("lr_fast", 1e-4, 1, log=True) # log=True, will use log scale to interplolate between
    weight_decay_fast = trial.suggest_float("weight_decay_fast", 1e-5, 1e-2)
    optimizer_fast = getattr(torch.optim, optimizer_name)(model.head.parameters(), lr=lr_fast, weight_decay=weight_decay_fast)
    train_loader, valid_loader, batch_size = get_dataloaders(trial)
    
    epochs = 10
    # Training of the model.
    for epoch in tqdm(range(epochs)):
        model.train()
        for i, (images, labels) in enumerate(train_loader):
            # Limiting training data for faster epochs.
            if i * batch_size >= n_train_examples:
                break

            images = images.to(device)
            labels = labels.to(device)

            output = model(images)
            loss = criterion(output, labels)
            optimizer_slow.zero_grad()
            optimizer_fast.zero_grad()
            loss.backward()
            optimizer_slow.step()
            optimizer_fast.step()
            
        # Validation of the model.
        model.eval()
        correct = 0
        with torch.no_grad():
            for i, (images, labels) in enumerate(valid_loader):
                if i * batch_size >= n_train_examples:
                    break

                images = images.to(device)
                labels = labels.to(device)

                output = model(images)
                # Get the index of the max log-probability.
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(labels.view_as(pred)).sum().item()

        accuracy = correct / len(valid_loader.dataset)
        # report back to Optuna how far it is (epoch-wise) into the trial and how well it is doing (accuracy)
        trial.report(accuracy, epoch)
        # then, Optuna can decide if the trial should be pruned
        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
                
    return accuracy


In [7]:
# now we can run the experiment
sampler = optuna.samplers.TPESampler()
study = optuna.create_study(study_name="vehicle_damage", direction="maximize", sampler=sampler)
study.optimize(objective, n_trials=200, timeout=None)

pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print("Study statistics: ")
print(" Number of finished trials: ", len(study.trials))
print(" Number of pruned trials: ", len(pruned_trials))
print(" Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print(" Value: ", trial.value)

print(" Params: ")
for key, value in trial.params.items():
    print(" {}: {}".format(key, value))


[I 2024-08-03 20:01:29,179] A new study created in memory with name: vehicle_damage
100%|██████████| 10/10 [06:24<00:00, 38.42s/it]
[I 2024-08-03 20:07:53,902] Trial 0 finished with value: 0.8345035105315948 and parameters: {'lr_slow': 8.710956958575865e-05, 'weight_decay_slow': 0.00381441505452984, 'optimizer': 'SGD', 'lr_fast': 0.31590520301744557, 'weight_decay_fast': 0.002533576586905862, 'batch_size': 32}. Best is trial 0 with value: 0.8345035105315948.
100%|██████████| 10/10 [07:11<00:00, 43.14s/it]
[I 2024-08-03 20:15:05,891] Trial 1 finished with value: 0.8716148445336008 and parameters: {'lr_slow': 0.0008884901704341717, 'weight_decay_slow': 0.004813004245101346, 'optimizer': 'SGD', 'lr_fast': 0.5426775012659383, 'weight_decay_fast': 0.0015654556048191575, 'batch_size': 64}. Best is trial 1 with value: 0.8716148445336008.
100%|██████████| 10/10 [07:01<00:00, 42.14s/it]
[I 2024-08-03 20:22:07,829] Trial 2 finished with value: 0.8876629889669007 and parameters: {'lr_slow': 0.000

Study statistics: 
 Number of finished trials:  200
 Number of pruned trials:  143
 Number of complete trials:  57
Best trial:
 Value:  0.970912738214644
 Params: 
 lr_slow: 0.00013955902469821164
 weight_decay_slow: 0.0011343113514715552
 optimizer: Adam
 lr_fast: 0.009427355395105285
 weight_decay_fast: 0.001827833478313847
 batch_size: 32


In [8]:
optuna.visualization.plot_param_importances(study)

In [9]:
optuna.visualization.plot_contour(study, params=["batch_size", "lr_fast"])


In [10]:
optuna.visualization.plot_contour(study, params=["lr_slow", "lr_fast"])

In [11]:
optuna.visualization.plot_contour(study, params=["lr_slow", "weight_decay_slow"])

In [12]:
optuna.visualization.plot_contour(study, params=["lr_fast", "weight_decay_fast"])

In [13]:
optuna.visualization.plot_contour(study, params=["batch_size", "optimizer"])
