# Project I - Image Classification

**Team**: Filip Kołodziejczyk, Jerzy Kraszewski

## Introduction

The goal of this project is to build a model that classifies images into 10 classes. We use the CINIC-10 dataset, which combines CIFAR-10 with downsampled ImageNet images. It contains 270,000 images split evenly across 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. Each image is 32x32 pixels in RGB format. The data comes pre-divided into training, validation, and test sets, with every class equally represented in each set.
More details about the dataset can be found [here](https://datashare.ed.ac.uk/handle/10283/3192) and [here](https://www.kaggle.com/datasets/mengcius/cinic10/data).

TODO: Add citation for this dataset

## Environment setup

We load the required libraries and select the fastest available PyTorch backend: CUDA if present, otherwise MPS, otherwise the CPU.

In [16]:
import os
import shutil
import time
from zipfile import ZipFile

import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision as tv
from IPython.display import display
from torch.utils.data import DataLoader, Subset, default_collate
from torchvision.transforms import v2 as T
from tqdm import tqdm

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

colab = "COLAB_GPU" in os.environ
if colab:
    from google.colab import drive

    drive.mount("/content/drive")

print(f"Using device: {device}")

Using device: mps


## Extracting and adjusting data split

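The original dataset ships with a predefined split into equally sized train, validation, and test sets. We adjust it here to 70% train, 15% test, and 15% validation by moving files between the split directories.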

In [18]:
# Dataset must be downloaded from the link provided in Introduction and put into `data` directory.
# It should be renamed to `cinic10.zip`.

archive_path = "data/cinic10.zip" if not colab else "/content/drive/MyDrive/cinic10.zip"
data_dir = "data/cinic10" if not colab else "/content/cinic10"
data_subdirs = ["train", "test", "valid"]
props = [0.7, 0.15, 0.15]  # Train, test, valid proportions

if abs(sum(props) - 1) > 1e-9:  # Tolerate floating-point rounding
    raise ValueError("Props must sum to 1")

# Extracting the data
with ZipFile(archive_path, "r") as zip_ref:
    zip_ref.extractall(data_dir)

classes = os.listdir(os.path.join(data_dir, "train"))
num_classes = len(classes)

# Changing the data split
for cls in classes:
    dirs = [os.path.join(data_dir, subdir, cls) for subdir in data_subdirs]
    sizes = [len(os.listdir(d)) for d in dirs]
    total = sum(sizes)
    target_sizes = [round(p * total) for p in props]  # round() avoids float truncation
    diffs = [target_sizes[i] - sizes[i] for i in range(len(sizes))]

    for i in range(len(diffs)):
        if diffs[i] < 0:
            for j in range(len(diffs)):
                if diffs[j] > 0:
                    count = min(abs(diffs[i]), diffs[j])
                    files = os.listdir(dirs[i])
                    files = files[:count]
                    for f in files:
                        shutil.move(os.path.join(dirs[i], f), os.path.join(dirs[j], f))
                    diffs[i] += count
                    diffs[j] -= count

# Checking the sizes
cls_sizes = {}
for cls in classes:
    cls_sizes[cls] = [
        len(os.listdir(os.path.join(data_dir, subdir, cls))) for subdir in data_subdirs
    ]
pd.DataFrame.from_dict(
    cls_sizes, orient="index", columns=[f"{set} size" for set in data_subdirs]
)

class,train size,test size,valid size
cat,18900,4050,4050
dog,18900,4050,4050
truck,18900,4050,4050
bird,18900,4050,4050
airplane,18900,4050,4050
ship,18900,4050,4050
frog,18900,4050,4050
horse,18900,4050,4050
deer,18900,4050,4050
automobile,18900,4050,4050
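
The rebalancing arithmetic can be checked without touching the file system. The sketch below is a simplified restatement of the loop above, not the cell itself: `plan_moves` is a hypothetical helper, the starting counts of 9,000 images per class and split are an assumption derived from the figures in the Introduction, and `round()` is used in place of `int()`.

```python
def plan_moves(sizes, props):
    """Plan how many files to move between splits so they match `props`.

    Returns the target sizes and a list of (source, destination, count) moves,
    where source and destination are indices into `sizes`.
    """
    total = sum(sizes)
    # round() guards against float error in products such as 0.7 * 27000
    targets = [round(p * total) for p in props]
    diffs = [t - s for t, s in zip(targets, sizes)]

    moves = []
    for i in range(len(diffs)):        # candidate splits with a surplus
        for j in range(len(diffs)):    # candidate splits with a deficit
            if diffs[i] < 0 and diffs[j] > 0:
                count = min(-diffs[i], diffs[j])
                moves.append((i, j, count))
                diffs[i] += count
                diffs[j] -= count
    return targets, moves


# Per-class counts in the order train, test, valid (assumed: 9,000 each).
targets, moves = plan_moves([9000, 9000, 9000], [0.7, 0.15, 0.15])
print(targets)  # [18900, 4050, 4050]
print(moves)    # [(1, 0, 4950), (2, 0, 4950)]
```

Under these assumptions, each class gives up 4,950 test images and 4,950 validation images to the train set, which matches the sizes reported in the table above.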


## Loading the data

In [3]:
transforms_no_aug = T.Compose(
    [
        T.PILToTensor(),
        T.Resize((224, 224)),  # Default input size for most models
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),  # Default values for ImageNet
    ]
)

train_path = os.path.join(data_dir, "train")
test_path = os.path.join(data_dir, "test")
valid_path = os.path.join(data_dir, "valid")


def get_data(
    transforms: T.Compose, collate_fn=None, batch_size: int = 32
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """
    Returns the data loaders of train, test and validation sets.

    Args:
    - transforms: Transformations to be applied to the train set (augmentations).
    - collate_fn: Optional batch-level collate function for the train loader (e.g. CutMix).
    - batch_size: Batch size of the train loader (the test and validation loaders use double this size).

    Returns:
    - train_loader: DataLoader of the train set.
    - test_loader: DataLoader of the test set.
    - valid_loader: DataLoader of the validation set.
    """

    train = tv.datasets.ImageFolder(train_path, transform=transforms)
    test = tv.datasets.ImageFolder(test_path, transform=transforms_no_aug)
    valid = tv.datasets.ImageFolder(valid_path, transform=transforms_no_aug)

    # TODO: Remove from final version. Used for rapid prototyping.
    train_size, test_size, valid_size = 4, 4, 4
    train = Subset(train, torch.randperm(len(train))[:train_size])
    test = Subset(test, torch.randperm(len(test))[:test_size])
    valid = Subset(valid, torch.randperm(len(valid))[:valid_size])

    train_loader = DataLoader(
        train, shuffle=True, batch_size=batch_size, collate_fn=collate_fn
    )
    test_loader = DataLoader(test, shuffle=False, batch_size=batch_size * 2)
    valid_loader = DataLoader(valid, shuffle=False, batch_size=batch_size * 2)
    return train_loader, test_loader, valid_loader
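
To make the preprocessing concrete, here is a small sketch of what the scaling and normalization steps do to a single pixel. It uses plain Python instead of torchvision; `normalize_pixel` is a hypothetical helper, and the constants are the ImageNet statistics used in the cell above.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Map one 8-bit RGB pixel to the normalized values the models receive."""
    # Scaling uint8 to float divides by 255, giving values in [0, 1]
    scaled = [channel / 255 for channel in rgb]
    # Normalization subtracts the channel mean and divides by the channel std
    return [
        (value - mean) / std
        for value, mean, std in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)
    ]


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

The two printed lists bound the range of values each channel can take after preprocessing.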

## Defining the models

In [4]:
alexnet_model = tv.models.alexnet(weights="DEFAULT")
# Adjusting the last layer to match the number of classes
alexnet_model.classifier[-1] = nn.Linear(
    alexnet_model.classifier[-1].in_features, num_classes
)

resnet50_nontrained_model = timm.create_model(
    "resnet50", pretrained=False, num_classes=num_classes
)
# Adding dropout to the last layer for later fine-tuning
resnet50_nontrained_model.fc = nn.Sequential(
    nn.Dropout(0.0), resnet50_nontrained_model.fc
)

resnet50_model = timm.create_model("resnet50", pretrained=True, num_classes=num_classes)
# Adding dropout to the last layer for later fine-tuning
resnet50_model.fc = nn.Sequential(nn.Dropout(0.0), resnet50_model.fc)

vit_model = timm.create_model(
    "vit_base_patch16_224.augreg2_in21k_ft_in1k",
    pretrained=True,
    num_classes=num_classes,
)

models = {
    "AlexNet (pretrained)": alexnet_model,
    "ResNet50": resnet50_nontrained_model,
    "ResNet50 (pretrained)": resnet50_model,
    "VIT (pretrained)": vit_model,
}

## Defining the training loop

In [5]:
def fit(
    model: nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    adam_lr: float = 0.001,
    weight_decay: float = 0.0001,
    epochs: int = 50,
    patience: int = 5,
    bar_postfix: dict[str, str] = {},
) -> tuple[float, list[float], list[float], list[float], list[float]]:
    """
    Fits the model to the data using Adam optimizer and CrossEntropyLoss.

    Parameters:
    - model: The model to be trained.
    - train_loader: The DataLoader for the training set.
    - valid_loader: The DataLoader for the validation set.
    - adam_lr: The learning rate for the Adam optimizer.
    - weight_decay: The weight decay for the Adam optimizer.
    - epochs: The number of epochs to train the model.
    - patience: The number of epochs to wait for the validation loss to improve before stopping the training.
    - bar_postfix: The postfix to be displayed in the progress bars.

    Returns:
    - duration: The duration of the training in seconds.
    - train_loss: The training loss for each epoch.
    - train_acc: The training accuracy for each epoch.
    - valid_loss: The validation loss for each epoch.
    - valid_acc: The validation accuracy for each epoch.
    """
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=adam_lr, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()
    multilabel_loss_fn = nn.BCEWithLogitsLoss()  # Required for cutmix augmentations
    epochs_train_loss = []
    epochs_train_acc = []
    epochs_valid_loss = []
    epochs_valid_acc = []

    best_valid_loss = float("inf")
    patience_counter = 0

    start_time = time.time()
    for epoch in range(epochs):
        running_loss = 0.0
        correct, total = 0, 0

        model.train()
        with tqdm(
            train_loader,
            desc=f"Epoch {epoch+1}/{epochs} [TRAIN]",
            leave=False,
            postfix=bar_postfix,
        ) as train_bar:
            for inputs, labels in train_bar:
                inputs, labels = inputs.to(device), labels.to(device)
                predictions = model(inputs)
                if labels.dim() > 1:
                    loss = multilabel_loss_fn(predictions, labels.float())
                else:
                    loss = loss_fn(predictions, labels)
                loss.backward()
                opt.step()
                opt.zero_grad()

                running_loss += loss.item()
                _, classifications = torch.max(predictions, 1)
                if labels.dim() > 1:
                    labels = (labels > 0.35).int()
                    mask = torch.zeros_like(labels).int()
                    mask.scatter_(1, classifications.unsqueeze(1), 1)
                    correct += torch.sum(mask & labels).item()
                else:
                    correct += (classifications == labels).sum().item()
                total += labels.size(0)

        epochs_train_loss.append(running_loss / total)
        epochs_train_acc.append(correct / total)

        model.eval()
        running_loss = 0.0
        correct, total = 0, 0

        with tqdm(
            valid_loader,
            desc=f"Epoch {epoch+1}/{epochs} [VALID]",
            leave=False,
            postfix=bar_postfix,
        ) as valid_bar:
            with torch.no_grad():
                for inputs, labels in valid_bar:
                    inputs, labels = inputs.to(device), labels.to(device)
                    predictions = model(inputs)
                    loss = loss_fn(predictions, labels)

                    running_loss += loss.item()
                    _, classifications = torch.max(predictions, 1)
                    correct += (classifications == labels).sum().item()
                    total += labels.size(0)

        epochs_valid_loss.append(running_loss / total)
        epochs_valid_acc.append(correct / total)

        if epochs_valid_loss[-1] < best_valid_loss:
            best_valid_loss = epochs_valid_loss[-1]
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            break

    end_time = time.time()
    duration = end_time - start_time

    return (
        duration,
        epochs_train_loss,
        epochs_train_acc,
        epochs_valid_loss,
        epochs_valid_acc,
    )
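
The early-stopping rule in `fit` can be isolated into a few lines. This is a minimal sketch of the same counter logic, with a hypothetical helper name (`epochs_until_stop`) and made-up validation losses:

```python
def epochs_until_stop(valid_losses, patience=5):
    """Return how many epochs run before early stopping triggers."""
    best = float("inf")
    waited = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best:
            best, waited = loss, 0   # improvement: remember it, reset the counter
        else:
            waited += 1              # no improvement this epoch
        if waited >= patience:
            return epoch
    return len(valid_losses)


# The loss improves twice, then stalls for five epochs in a row.
print(epochs_until_stop([1.0, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99, 0.5]))  # 7
# The loss never improves after the first epoch.
print(epochs_until_stop([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]))            # 6
```

With `patience=5`, the shortest possible run is therefore 6 epochs: one epoch that sets the baseline and five that fail to beat it.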

## Defining the evaluation loop

In [6]:
def test(
    model: nn.Module, test_loader: DataLoader, bar_postfix: dict[str, str] = {}
) -> tuple[float, float]:
    """
    Test the model on the test set.

    Parameters:
    - model: The model to be tested.
    - test_loader: The DataLoader for the test set.
    - bar_postfix: The postfix to be displayed in the progress bar.

    Returns:
    - loss: The loss on the test set.
    - accuracy: The accuracy on the test set.
    """
    model.to(device)
    model.eval()

    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0.0
    correct, total = 0, 0

    with tqdm(
        test_loader, desc="[TEST]", leave=False, postfix=bar_postfix
    ) as test_bar:
        with torch.no_grad():
            for inputs, labels in test_bar:
                inputs, labels = inputs.to(device), labels.to(device)
                predictions = model(inputs)
                loss = loss_fn(predictions, labels)

                running_loss += loss.item()
                _, classifications = torch.max(predictions, 1)
                correct += (classifications == labels).sum().item()
                total += labels.size(0)

    loss = running_loss / total
    accuracy = correct / total

    return loss, accuracy
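
One detail about the reported losses: `nn.CrossEntropyLoss` returns the mean loss of a batch by default, while `fit` and `test` sum these batch means and divide by the number of samples. The sketch below contrasts that normalization with a sample-weighted average. Both helper names and the batch values are hypothetical.

```python
def summed_means_over_samples(batches):
    """Normalization used in fit and test: sum of batch means / sample count."""
    total_samples = sum(n for _, n in batches)
    return sum(mean_loss for mean_loss, _ in batches) / total_samples


def per_sample_loss(batches):
    """Sample-weighted alternative: each batch mean counts once per sample."""
    total_samples = sum(n for _, n in batches)
    return sum(mean_loss * n for mean_loss, n in batches) / total_samples


# Two batches of four samples, given as (mean batch loss, batch size).
batches = [(2.0, 4), (1.0, 4)]
print(summed_means_over_samples(batches))  # 0.375
print(per_sample_loss(batches))            # 1.5
```

The first value shrinks as the batch size grows, so losses computed this way are comparable between runs only when the batch size is the same.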

## Simulation loop

In [10]:
def run(
    train_loader: DataLoader,
    test_loader: DataLoader,
    valid_loader: DataLoader,
    lr: float = 0.001,
    weight_decay: float = 0.0001,
    epochs: int = 50,
    patience: int = 5,
    simplify: bool = True,
    bar_extra: dict[str, str] = {},
) -> dict:
    """
    Runs the training and testing on all models.

    Parameters:
    - train_loader: The DataLoader for the training set.
    - test_loader: The DataLoader for the test set.
    - valid_loader: The DataLoader for the validation set.
    - lr: The learning rate for the Adam optimizer.
    - weight_decay: The weight decay for the Adam optimizer.
    - epochs: The number of epochs to train the model.
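    - patience: The number of epochs to wait for the validation loss to improve before stopping the training.
    - simplify: If True, keep only the last-epoch metrics and the epoch count for each model.
    - bar_extra: Extra entries to be displayed in the progress bars.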

    Returns:
    - results: A dictionary containing the results of the training and testing for each model.
    """

    results = {}

    for model_name, model in models.items():
        bar_postfix = {
            "model": model_name,
            "lr": f"{lr}",
            "weight_decay": f"{weight_decay}",
            "patience": f"{patience}",
            "batch_size": f"{train_loader.batch_size}",
            **bar_extra,
        }

        duration, train_loss, train_acc, valid_loss, valid_acc = fit(
            model,
            train_loader,
            valid_loader,
            lr,
            weight_decay,
            epochs,
            patience,
            bar_postfix,
        )
        test_loss, test_acc = test(model, test_loader, bar_postfix)
        results[model_name] = {
            "duration": duration,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "valid_loss": valid_loss,
            "valid_acc": valid_acc,
            "test_loss": test_loss,
            "test_acc": test_acc,
        }

    if simplify:
        for model_name, model_results in results.items():
            results[model_name] = {
                "duration": model_results["duration"],
                "epochs": len(model_results["train_loss"]),
                "train_loss": model_results["train_loss"][-1],
                "train_acc": model_results["train_acc"][-1],
                "valid_loss": model_results["valid_loss"][-1],
                "valid_acc": model_results["valid_acc"][-1],
                "test_loss": model_results["test_loss"],
                "test_acc": model_results["test_acc"],
            }

    return results


def simulate(
    train_loader: DataLoader,
    test_loader: DataLoader,
    valid_loader: DataLoader,
    times: int = 5,
    lr: float = 0.001,
    weight_decay: float = 0.0001,
    epochs: int = 50,
    patience: int = 5,
) -> pd.DataFrame:
    """
    Simulates the training and testing of the models multiple times.

    Parameters:
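    - train_loader: The DataLoader for the training set.
    - test_loader: The DataLoader for the test set.
    - valid_loader: The DataLoader for the validation set.
    - times: The number of times to simulate the training and testing.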
    - lr: The learning rate for the Adam optimizer.
    - weight_decay: The weight decay for the Adam optimizer.
    - epochs: The number of epochs to train the model.
    - patience: The number of epochs to wait for the validation loss to improve before stopping the training.

    Returns:
    - df: A DataFrame containing the results.
    """
    results = []
    for i in range(times):
        result = run(
            train_loader,
            test_loader,
            valid_loader,
            lr,
            weight_decay,
            epochs,
            patience,
            bar_extra={"simulation": f"{i+1}/{times}"},
        )
        results.append(result)

    results = [pd.DataFrame.from_dict(result, orient="index") for result in results]
    return pd.concat(results)
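
As written, `run` iterates over the module-level `models` dictionary, so every repetition in `simulate` continues training the same weights instead of starting from the initial ones. One possible way to make repetitions independent is to train a deep copy of each model. The sketch below illustrates the idea with a hypothetical `ToyModel` stand-in; `copy.deepcopy` also works on `nn.Module` instances.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for an nn.Module; holds a single weight."""

    def __init__(self):
        self.weight = 0.0

    def train_step(self):
        self.weight += 1.0  # pretend that training changes the weight


pristine = {"toy": ToyModel()}


def run_once(models):
    """Train fresh copies so that the pristine models are never modified."""
    trained = {name: copy.deepcopy(model) for name, model in models.items()}
    for model in trained.values():
        model.train_step()
    return {name: model.weight for name, model in trained.items()}


first = run_once(pristine)
second = run_once(pristine)
print(first, second)           # {'toy': 1.0} {'toy': 1.0}
print(pristine["toy"].weight)  # 0.0
```

Both repetitions start from the same state, which is what the min/mean/max/std aggregation over repeated runs implicitly assumes.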

## Basic training (no data augmentation)

In [8]:
train_loader, test_loader, valid_loader = get_data(transforms_no_aug)
results = simulate(train_loader, test_loader, valid_loader, times=1)
display(results.groupby(results.index).agg(["min", "mean", "max", "std"]))


model,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
AlexNet (pretrained),1.427249,1.427249,1.427249,,6,6.0,6,,0.269313,0.269313,...,0.0,,1.372709,1.372709,1.372709,,0.0,0.0,0.0,
ResNet50,2.650984,2.650984,2.650984,,6,6.0,6,,0.000893,0.000893,...,0.0,,0.749342,0.749342,0.749342,,0.0,0.0,0.0,
ResNet50 (pretrained),1.092215,1.092215,1.092215,,6,6.0,6,,0.187937,0.187937,...,0.0,,0.672348,0.672348,0.672348,,0.0,0.0,0.0,
VIT (pretrained),2.862304,2.862304,2.862304,,8,8.0,8,,0.738243,0.738243,...,0.0,,2.113808,2.113808,2.113808,,0.0,0.0,0.0,
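
The `std` columns above are empty because this experiment uses `times=1`, and the sample standard deviation (the pandas default) is undefined for a single value. A minimal sketch of the same four statistics in plain Python, with a hypothetical `summarize` helper:

```python
import statistics


def summarize(values):
    """min / mean / max / sample std; std is None with fewer than two values."""
    std = statistics.stdev(values) if len(values) > 1 else None
    return {
        "min": min(values),
        "mean": statistics.fmean(values),
        "max": max(values),
        "std": std,
    }


print(summarize([1.427249]))  # one run: std is None
print(summarize([6, 8, 10]))  # three runs: std is 2.0
```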


## Data augmentation

In [12]:
transforms_basic_aug = T.Compose(
    [
        T.PILToTensor(),
        T.Resize((224, 224)),  # Default input size for most models
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.RandomRotation(degrees=15),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),  # Default values for ImageNet
    ]
)

# CutMix works directly on the batch
cutmix = T.CutMix(num_classes=num_classes)
advanced_transforms = lambda batch: cutmix(*default_collate(batch))

transforms = [
    {
        "transforms": transforms_basic_aug,
        "collate_fn": advanced_transforms,
        "name": "Advanced Augmentations",
    },
    {
        "transforms": transforms_basic_aug,
        "collate_fn": None,
        "name": "Basic Augmentations",
    },
]

results["augmentation"] = "None"
for transform in transforms:
    train_loader, test_loader, valid_loader = get_data(
        transform["transforms"], transform["collate_fn"]
    )
    result = simulate(train_loader, test_loader, valid_loader)
    result["augmentation"] = transform["name"]
    results = pd.concat([results, result])

for model_name, model_results in results.groupby(results.index):
    print(model_name)
    display(model_results.groupby("augmentation").agg(["min", "mean", "max", "std"]))


AlexNet (pretrained)




augmentation,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
Advanced Augmentations,0.787131,0.891022,1.032274,0.08916,8,8.0,8,0.0,0.056222,0.060482,...,0.25,0.136931,0.50035,0.981632,1.45922,0.341478,0.0,0.05,0.25,0.111803
Basic Augmentations,0.641553,0.827034,1.303558,0.273149,6,8.2,13,2.774887,0.067451,0.486309,...,0.0,0.0,1.307963,7.022713,17.507271,6.567098,0.0,0.0,0.0,0.0
None,1.427249,1.427249,1.427249,,6,6.0,6,,0.269313,0.269313,...,0.0,,1.372709,1.372709,1.372709,,0.0,0.0,0.0,


ResNet50


augmentation,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
Advanced Augmentations,1.273619,1.875412,2.654637,0.579982,6,8.6,12,2.607681,0.024334,0.045578,...,0.5,0.111803,0.690729,1.448088,2.951771,1.010261,0.0,0.1,0.25,0.136931
Basic Augmentations,1.179647,2.226527,4.279972,1.424037,6,10.6,20,6.542171,3e-06,0.015983,...,0.0,0.0,1.253387,2.298413,3.952567,1.128747,0.0,0.0,0.0,0.0
None,2.650984,2.650984,2.650984,,6,6.0,6,,0.000893,0.000893,...,0.0,,0.749342,0.749342,0.749342,,0.0,0.0,0.0,


ResNet50 (pretrained)


augmentation,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
Advanced Augmentations,1.188902,1.299999,1.510219,0.131445,6,6.2,7,0.447214,0.043896,0.057233,...,0.5,0.209165,0.583717,0.710555,0.793285,0.083842,0.0,0.0,0.0,0.0
Basic Augmentations,1.296114,2.662553,7.336753,2.633657,6,12.4,34,12.198361,8.4e-05,0.161622,...,0.0,0.0,0.822262,2.803492,8.550341,3.236398,0.0,0.0,0.0,0.0
None,1.092215,1.092215,1.092215,,6,6.0,6,,0.187937,0.187937,...,0.0,,0.672348,0.672348,0.672348,,0.0,0.0,0.0,


VIT (pretrained)


augmentation,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
Advanced Augmentations,2.007723,3.280416,4.611955,1.026985,6,10.0,14,3.162278,0.056936,0.059967,...,0.5,0.209165,0.864874,1.025186,1.156014,0.111451,0.0,0.25,0.5,0.25
Basic Augmentations,2.530021,3.814754,5.232007,1.33511,8,11.8,16,4.024922,0.097978,0.129336,...,0.0,0.0,1.972348,2.708704,3.389903,0.693061,0.0,0.0,0.0,0.0
None,2.862304,2.862304,2.862304,,8,8.0,8,,0.738243,0.738243,...,0.0,,2.113808,2.113808,2.113808,,0.0,0.0,0.0,
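
CutMix batches carry soft labels: each image mixes two source images, and its label vector holds the share of each class. This is why `fit` switches loss functions and counts training accuracy differently when the labels have more than one dimension. The sketch below mirrors that accuracy rule in plain Python. The helper names, class indices, and mixing ratio are hypothetical; the 0.35 threshold is the one used in `fit`.

```python
def mix_labels(class_a, class_b, lam, num_classes=10):
    """Soft label for an image that is `lam` class_a and `1 - lam` class_b."""
    label = [0.0] * num_classes
    label[class_a] += lam
    label[class_b] += 1.0 - lam
    return label


def counts_as_correct(predicted_class, soft_label, threshold=0.35):
    """A prediction counts if its class carries more weight than the threshold."""
    return soft_label[predicted_class] > threshold


label = mix_labels(class_a=3, class_b=5, lam=0.6)
print(counts_as_correct(3, label))  # True: weight 0.6 exceeds 0.35
print(counts_as_correct(5, label))  # True: weight 0.4 exceeds 0.35
print(counts_as_correct(7, label))  # False: class 7 is not in the mix
```

With a balanced mix, either source class counts as a correct prediction, so training accuracy on CutMix batches is not directly comparable to accuracy on hard labels.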


## Batch size tuning

In [14]:
batch_sizes = [1, 4, 32, 128, 512]
results = []

for batch_size in batch_sizes:
    train_loader, test_loader, valid_loader = get_data(
        transforms_basic_aug, advanced_transforms, batch_size=batch_size
    )
    result = simulate(train_loader, test_loader, valid_loader)
    result["batch_size"] = batch_size
    results.append(result)

results = pd.concat(results)
for model_name, model_results in results.groupby(results.index):
    print(model_name)
    display(model_results.groupby("batch_size").agg(["min", "mean", "max", "std"]))


AlexNet (pretrained)




batch_size,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
1,12.292037,12.636668,13.327302,0.412702,50,50.0,50,0.0,290619100000000.0,2.080398e+18,...,0.25,0.111803,2261910000000000.0,4.439354e+18,1.947146e+19,8.465851e+18,0.0,0.0,0.0,0.0
4,0.748326,1.230835,2.087544,0.626569,6,8.8,13,3.420526,20635530000000.0,72471840000000.0,...,0.0,0.0,304944500000000.0,525039600000000.0,788911900000000.0,188694600000000.0,0.25,0.25,0.25,0.0
32,0.709512,0.966003,1.288667,0.21824,7,8.6,11,1.67332,8645100000000.0,38939500000000.0,...,0.75,0.223607,119692100000000.0,132568000000000.0,163121200000000.0,17753180000000.0,0.25,0.3,0.5,0.111803
128,0.671713,1.209926,1.56106,0.333156,7,11.0,14,2.54951,12368320000000.0,77654730000000.0,...,0.0,0.0,116275200000000.0,152183800000000.0,204930300000000.0,37520360000000.0,0.0,0.0,0.0,0.0
512,0.719574,1.07846,1.788963,0.447996,6,9.6,15,3.911521,11998510000000.0,34722900000000.0,...,0.0,0.0,182254200000000.0,218765700000000.0,271273600000000.0,36010170000000.0,0.0,0.0,0.0,0.0


ResNet50


batch_size,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
1,6.74614,11.601288,20.063563,5.613383,9,16.0,28,7.778175,10.437561,14.432922,...,0.25,0.136931,50.124301,85.643386,122.855584,27.587103,0.25,0.25,0.25,0.0
4,1.782119,3.917674,6.123293,1.633512,6,16.6,24,7.056912,2.558356,2.923995,...,0.0,0.0,19.085958,21.028572,24.549709,2.295296,0.0,0.0,0.0,0.0
32,1.352449,2.559994,3.96914,1.259577,6,11.4,18,5.683309,0.904717,1.014315,...,0.0,0.0,18.767651,19.779391,21.000748,0.991744,0.0,0.0,0.0,0.0
128,1.480926,3.320402,6.333324,1.978701,6,14.0,28,8.860023,2.381181,2.536147,...,0.0,0.0,17.703173,18.997,20.794411,1.294151,0.0,0.0,0.0,0.0
512,1.268526,1.576669,2.403,0.477218,6,7.0,11,2.236068,1.645583,1.763625,...,0.25,0.0,28.629992,29.770912,31.589544,1.257515,0.0,0.0,0.0,0.0


ResNet50 (pretrained)


batch_size,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
1,6.433843,15.402963,22.565833,6.399763,7,22.0,33,10.416333,6.708401,17.025098,...,0.25,0.111803,24.496798,15240.056943,46238.798893,18951.117206,0.0,0.0,0.0,0.0
4,1.54848,3.75221,7.784665,2.41968,6,16.2,34,10.825895,0.661659,1.466591,...,0.0,0.0,7.5918,7.997889,8.796241,0.506072,0.0,0.2,0.25,0.111803
32,1.301629,1.374076,1.504861,0.076938,6,6.0,6,0.0,1.098993,2.455892,...,0.0,0.0,32.471195,5121.256495,14520.199219,6710.346454,0.0,0.0,0.0,0.0
128,1.444505,2.640172,4.489537,1.293388,6,11.2,18,5.167204,1.013017,1.766395,...,0.0,0.0,7.727118,73.49639,262.760895,108.142199,0.0,0.0,0.0,0.0
512,1.457113,2.907155,6.858294,2.276537,6,11.8,25,7.854935,0.603645,0.995566,...,0.0,0.0,451.453003,2356.485144,4882.056152,1798.831746,0.0,0.0,0.0,0.0


VIT (pretrained)


batch_size,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
1,4.489616,4.922348,5.288567,0.334188,6,7.0,8,0.707107,0.190325,0.210309,...,0.0,0.0,4.220624,4.891161,5.349406,0.464846,0.0,0.0,0.0,0.0
4,2.715698,3.855442,5.086884,0.910756,8,10.2,12,2.04939,0.036177,0.040654,...,0.0,0.0,2.033588,2.337843,2.927811,0.359858,0.0,0.0,0.0,0.0
32,2.294016,2.629706,2.983756,0.323014,7,7.8,9,0.83666,0.047604,0.067851,...,0.5,0.273861,0.717976,1.202722,1.498418,0.332016,0.0,0.2,0.25,0.111803
128,2.384223,2.808514,3.552093,0.450974,7,8.4,11,1.516575,0.042456,0.045229,...,0.25,0.0,0.696551,1.021297,1.357837,0.239562,0.5,0.5,0.5,0.0
512,1.87362,3.225112,5.202865,1.242069,6,9.6,16,3.847077,0.057559,0.075982,...,0.0,0.0,0.792681,1.042288,1.203949,0.155015,0.0,0.25,0.5,0.25
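
When reading the batch-size results, it helps to know how many optimizer steps one epoch contains. The sketch below computes this for the 4-sample prototyping subset set in `get_data` and for the full train set of 189,000 images (18,900 per class). `batches_per_epoch` is a hypothetical helper that assumes the last, smaller batch is kept, which is the `DataLoader` default.

```python
import math


def batches_per_epoch(num_samples, batch_size):
    """Number of batches per epoch when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)


for batch_size in [1, 4, 32, 128, 512]:
    print(
        batch_size,
        batches_per_epoch(4, batch_size),        # prototyping subset
        batches_per_epoch(189_000, batch_size),  # full train set
    )
```

On the 4-sample subset, every batch size from 4 upwards yields a single batch per epoch, so the rows for 4, 32, 128, and 512 above differ only through random effects and continued training, not through the batch size itself.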


## Learning rate tuning

In [15]:
learning_rates = [1, 0.1, 0.01, 0.001, 0.0001, 0.00001]
results = []

for lr in learning_rates:
    train_loader, test_loader, valid_loader = get_data(
        transforms_basic_aug, advanced_transforms
    )
    result = simulate(train_loader, test_loader, valid_loader, lr=lr)
    result["lr"] = lr
    results.append(result)

results = pd.concat(results)
for model_name, model_results in results.groupby(results.index):
    print(model_name)
    display(model_results.groupby("lr").agg(["min", "mean", "max", "std"]))


AlexNet (pretrained)




lr,duration min,duration mean,duration max,duration std,epochs min,epochs mean,epochs max,epochs std,train_loss min,train_loss mean,...,valid_acc max,valid_acc std,test_loss min,test_loss mean,test_loss max,test_loss std,test_acc min,test_acc mean,test_acc max,test_acc std
1e-05,1.052239,2.800855,4.824537,1.634696,8,24.0,46,16.507574,298359.7,3242364.0,...,0.5,0.0,1087802.0,1599861.0,2523499.0,548479.5,0.0,0.0,0.0,0.0
0.0001,1.528894,2.534532,4.482288,1.175172,13,23.0,41,11.045361,119877.0,869727.0,...,0.25,0.111803,1602392.0,2106253.0,2947990.0,508386.8,0.25,0.25,0.25,0.0
0.001,0.527111,1.578518,2.990566,1.189257,6,14.8,27,10.756393,99508.02,1287505.0,...,0.25,0.136931,14130450.0,181399800.0,262715000.0,96307180.0,0.0,0.15,0.5,0.223607
0.01,0.799069,1.579888,2.812852,0.967166,7,12.4,25,7.569676,43741150.0,545968700.0,...,0.25,0.111803,899788000.0,23022620000.0,44896060000.0,18477280000.0,0.0,0.05,0.25,0.111803
0.1,0.758474,2.603566,5.986036,2.369433,7,22.4,50,20.549939,214.3866,23735020000000.0,...,0.25,0.111803,15888.94,4006297000000000.0,1.582129e+16,6758717000000000.0,0.0,0.0,0.0,0.0
1.0,0.690192,1.306656,2.845016,0.941092,6,11.6,23,7.95613,7461345.0,2.97048e+19,...,0.25,0.136931,82816180.0,6.131601e+20,2.491732e+21,1.077517e+21,0.0,0.2,0.25,0.111803


ResNet50


,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
lr,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
1e-05,1.299401,2.076139,2.922294,0.582376,6,8.8,11,1.923538,0.344466,0.418568,...,0.0,0.0,1.417987,1.444077,1.453691,0.01478521,0.25,0.25,0.25,0.0
0.0001,1.707944,4.533064,10.883865,3.662431,7,19.0,46,15.890249,0.187177,0.20785,...,0.0,0.0,1.496949,1.573579,1.68356,0.08100871,0.0,0.0,0.0,0.0
0.001,1.555684,3.843468,9.992786,3.512337,7,15.8,37,12.316655,0.102653,0.119812,...,0.0,0.0,2.705313,2.857009,2.996236,0.1304249,0.0,0.0,0.0,0.0
0.01,1.363997,2.158457,3.677788,0.953862,6,9.4,17,4.722288,0.066287,0.13235,...,0.25,0.111803,1.344295,2.227801,2.895855,0.6764574,0.0,0.05,0.25,0.111803
0.1,1.281435,2.86076,5.556809,1.816703,6,11.6,23,7.635444,0.212793,0.480728,...,0.0,0.0,3.486407,5.12423,6.897058,1.345879,0.0,0.05,0.25,0.111803
1.0,1.400752,1.771525,2.388937,0.376267,6,8.0,11,1.870829,2.53604,5.863146,...,0.25,0.111803,2712.392334,120559700000.0,592924800000.0,264094500000.0,0.0,0.1,0.25,0.136931


ResNet50 (pretrained)


,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
lr,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
1e-05,1.266349,2.105076,3.790782,1.093167,6,9.0,15,4.242641,0.57426,0.699833,...,0.0,0.0,48.435947,57.62659,64.76685,7.275234,0.25,0.25,0.25,0.0
0.0001,1.371224,3.89972,11.871962,4.485864,6,16.6,50,18.809572,0.160455,0.19792,...,0.25,0.136931,2.223423,3.537691,4.235633,0.7818875,0.0,0.15,0.25,0.136931
0.001,1.50359,3.092276,5.201946,1.34197,7,12.6,22,5.85662,0.068506,0.094426,...,0.0,0.0,6.015432,6.182313,6.35706,0.1357631,0.0,0.0,0.0,0.0
0.01,1.232629,2.110096,3.081578,0.707013,6,8.4,14,3.577709,0.064053,0.208916,...,0.25,0.111803,6.411244,7.135587,7.669309,0.5608415,0.0,0.0,0.0,0.0
0.1,1.606903,3.268805,5.791214,2.022106,7,12.8,23,8.01249,0.215592,0.65179,...,0.0,0.0,5.156852,12.74846,20.02557,7.036186,0.0,0.1,0.25,0.136931
1.0,1.240411,1.562162,2.045129,0.433775,6,7.2,10,1.788854,5.120643,9.431401,...,0.25,0.111803,25726.603516,122885500.0,457366000.0,198946700.0,0.0,0.2,0.25,0.111803


VIT (pretrained)


,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
lr,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
1e-05,16.577179,16.971045,17.544271,0.35584,50,50.0,50,0.0,2.499599,2.578522,...,0.0,0.0,1.552596,1.726622,1.943198,0.159758,0.0,0.2,0.25,0.111803
0.0001,1.986458,5.12439,16.70318,6.484972,6,15.4,50,19.385562,0.131382,0.156315,...,0.25,0.111803,1.490881,1.802362,2.063199,0.2296,0.0,0.15,0.25,0.136931
0.001,2.177841,3.157915,3.850548,0.693271,7,9.4,12,2.073644,0.057442,0.112165,...,0.25,0.111803,27.466904,27.755635,28.282778,0.329063,0.0,0.0,0.0,0.0
0.01,2.577254,3.273639,4.682485,0.842944,6,9.4,14,2.966479,0.296223,0.717761,...,0.25,0.136931,28.173367,33.021724,48.891415,8.910033,0.0,0.05,0.25,0.111803
0.1,2.95218,4.325411,7.691186,1.974575,9,12.0,21,5.09902,1.596185,4.368663,...,0.25,0.111803,20.589836,86.890168,174.119446,76.431469,0.0,0.1,0.25,0.136931
1.0,1.85405,2.543262,3.488052,0.728422,6,7.8,11,2.167948,18.002171,27.189544,...,0.0,0.0,157.780884,371.890338,940.694458,331.233479,0.0,0.1,0.25,0.136931
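
The best value per model can also be read off programmatically instead of by eye. Below is a minimal sketch, assuming a frame shaped like `results` above (indexed by model name, with the swept parameter and `valid_acc` as columns). The model names and numbers are made up for illustration, and `best_setting` is a hypothetical helper, not part of the notebook.

```python
import pandas as pd

# Hypothetical stand-in for the concatenated `results` frame:
# one row per run, indexed by model name.
results = pd.DataFrame(
    {
        "lr": [0.1, 0.1, 0.001, 0.001, 0.1, 0.1, 0.001, 0.001],
        "valid_acc": [0.20, 0.30, 0.60, 0.70, 0.50, 0.40, 0.35, 0.45],
    },
    index=["model_a"] * 4 + ["model_b"] * 4,
)


def best_setting(results, param, metric="valid_acc"):
    """Return {model: value of `param` with the highest mean `metric`}."""
    # Move the model names from the index into a regular column.
    df = results.rename_axis("model").reset_index()
    # Mean metric for every (model, param value) pair.
    means = df.groupby(["model", param], as_index=False)[metric].mean()
    # For each model, keep the row whose mean metric is largest.
    best_rows = means.loc[means.groupby("model")[metric].idxmax()]
    return dict(zip(best_rows["model"], best_rows[param]))


print(best_setting(results, "lr"))
```

The same helper should work for the other sweeps by passing `"weight_decay"` or `"dropout_prob"` as `param`. Selecting on validation accuracy rather than test accuracy keeps the test set out of the tuning loop.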


## Weight decay tuning

In [19]:
weight_decays = [0, 0.0001, 0.001, 0.01, 0.1]
results = []

for weight_decay in weight_decays:
    train_loader, test_loader, valid_loader = get_data(
        transforms_basic_aug, advanced_transforms
    )
    result = simulate(
        train_loader, test_loader, valid_loader, weight_decay=weight_decay
    )
    result["weight_decay"] = weight_decay
    results.append(result)

results = pd.concat(results)
for model_name, model_results in results.groupby(results.index):
    print(model_name)
    display(model_results.groupby("weight_decay").agg(["min", "mean", "max", "std"]))

AlexNet (pretrained)




,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
weight_decay,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
0.0,1.774094,3.448348,6.171636,2.010312,7,10.8,20,5.263079,22771.246094,1380568.0,...,0.25,0.136931,44107.7,1380911.0,3517412.0,1545989.0,0.0,0.2,0.5,0.273861
0.0001,0.896816,1.260683,1.750914,0.311365,7,8.4,12,2.073644,81742.242188,938163.9,...,0.5,0.223607,304721.9,8174804.0,15321811.0,6019227.0,0.0,0.15,0.5,0.223607
0.001,0.926267,2.61123,4.607891,1.636134,6,15.4,29,9.2358,18867.998047,294646.6,...,0.5,0.209165,7608078.0,120814300.0,274771776.0,109581700.0,0.0,0.1,0.25,0.136931
0.01,0.956922,1.766155,2.884906,0.820277,7,11.2,19,4.764452,162108.671875,3370483.0,...,0.25,0.111803,2463734.0,4202550.0,6293742.0,1367545.0,0.0,0.0,0.0,0.0
0.1,1.084617,2.889545,6.224247,2.058764,6,15.2,35,11.987493,588.257507,28427.11,...,0.25,0.111803,20127060.0,33020790.0,49281156.0,12507850.0,0.0,0.1,0.5,0.223607


ResNet50


,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
weight_decay,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
0.0,1.995515,5.88041,14.120814,4.809053,6,14.2,43,16.192591,0.067424,0.177675,...,0.5,0.111803,0.920808,1.045571,1.184497,0.120627,0.25,0.25,0.25,0.0
0.0001,1.833739,5.075301,15.32672,5.817791,6,16.6,50,18.968395,0.058804,0.083568,...,0.25,0.111803,0.61571,0.719901,0.892392,0.11092,0.0,0.05,0.25,0.111803
0.001,1.499354,1.673018,1.95492,0.17528,6,6.0,6,0.0,0.089148,0.145301,...,0.25,0.0,0.758825,0.785115,0.851845,0.039289,0.25,0.25,0.25,0.0
0.01,1.656434,3.47737,8.018673,2.617297,6,11.8,31,10.802777,0.05467,0.078012,...,0.25,0.136931,1.55355,1.604747,1.660944,0.042449,0.0,0.0,0.0,0.0
0.1,1.514124,4.632575,12.309208,4.367129,6,15.0,41,14.713939,0.060406,0.074233,...,0.0,0.0,1.155004,1.184742,1.239578,0.032159,0.0,0.0,0.0,0.0


ResNet50 (pretrained)


,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
weight_decay,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
0.0,1.821185,3.840255,9.338797,3.153536,6,8.4,11,2.50998,0.071481,0.220903,...,0.5,0.136931,1.207919,1.441394,1.803344,0.277332,0.0,0.15,0.25,0.136931
0.0001,1.378542,4.386812,15.558829,6.247993,6,14.0,46,17.888544,0.052517,0.059817,...,0.25,0.0,26.15097,73.136885,136.230911,45.938607,0.0,0.05,0.25,0.111803
0.001,1.601532,1.749302,1.953237,0.132648,6,6.0,6,0.0,0.066474,0.104599,...,0.5,0.111803,0.720148,0.981159,1.343951,0.26823,0.0,0.2,0.5,0.209165
0.01,1.461142,2.751713,5.950837,1.829503,6,10.6,23,7.127412,0.058986,0.091134,...,0.25,0.0,4.513608,5.010452,5.255964,0.329696,0.0,0.0,0.0,0.0
0.1,1.383346,2.709096,4.105395,1.136915,6,10.0,16,4.0,0.094313,0.11349,...,0.0,0.0,2.665132,3.207234,3.99181,0.638645,0.0,0.2,0.25,0.111803


VIT (pretrained)


,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
weight_decay,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
0.0,2.263692,3.95342,6.116941,1.461795,6,8.0,10,1.870829,0.104588,1.104021,...,0.5,0.209165,1.302825,1.519301,1.749753,0.160154,0.0,0.15,0.5,0.223607
0.0001,2.188706,3.410836,5.487451,1.276344,6,8.2,13,2.774887,0.056891,0.060046,...,0.25,0.136931,2.443638,3.357903,4.314363,0.748838,0.0,0.1,0.25,0.136931
0.001,2.063174,2.377521,3.272087,0.504633,6,6.6,9,1.341641,0.057759,0.171072,...,0.25,0.136931,1.058827,2.271562,3.063349,0.788862,0.0,0.2,0.25,0.111803
0.01,2.622007,4.960057,8.885305,2.37541,7,12.6,21,5.22494,0.308356,0.656897,...,0.25,0.136931,5.686426,8.387229,12.409633,2.453134,0.0,0.0,0.0,0.0
0.1,2.229029,3.267037,5.447147,1.292304,6,8.0,12,2.44949,0.058417,0.128177,...,0.25,0.136931,9.050358,10.017333,10.888554,0.812287,0.0,0.05,0.25,0.111803


## Dropout tuning

In [20]:
dropout_probs = [0.0, 0.2, 0.4, 0.6, 0.8]
results = []


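def set_dropout(model, dropout_prob):
    # Overwrite `p` on every nn.Dropout layer in place.
    # Models that contain no nn.Dropout modules are left unchanged.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = dropout_prob
    return model


for dropout_prob in dropout_probs:
    # set_dropout mutates each model in place, so no reassignment is needed.
    for model in models.values():
        set_dropout(model, dropout_prob)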
    train_loader, test_loader, valid_loader = get_data(
        transforms_basic_aug, advanced_transforms
    )
    result = simulate(train_loader, test_loader, valid_loader)
    result["dropout_prob"] = dropout_prob
    results.append(result)

results = pd.concat(results)
for model_name, model_results in results.groupby(results.index):
    print(model_name)
    display(model_results.groupby("dropout_prob").agg(["min", "mean", "max", "std"]))

AlexNet (pretrained)




,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
dropout_prob,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
0.0,1.847683,2.190815,2.487225,0.261237,6,9.8,16,3.898718,6169.606445,91059.650879,...,0.5,0.223607,253211.2,4423652.0,17301010.0,7228985.0,0.0,0.1,0.25,0.136931
0.2,1.215089,1.653162,2.691038,0.608277,6,9.0,14,3.162278,58559.792969,92610.7375,...,0.25,0.0,1345567.0,2810239.0,6731716.0,2287852.0,0.0,0.05,0.25,0.111803
0.4,1.05236,2.666036,6.403869,2.33808,6,13.0,37,13.435029,1066.366089,490040.704077,...,0.5,0.176777,128429.9,3266408.0,7706327.5,2759062.0,0.0,0.2,0.5,0.209165
0.6,0.936497,1.888572,2.651575,0.696359,6,9.2,15,3.701351,1560.950684,39654.376611,...,0.5,0.223607,4188860.0,20070110.0,38544744.0,13268210.0,0.0,0.05,0.25,0.111803
0.8,0.746808,1.234679,1.874783,0.430486,6,7.0,9,1.414214,12875.769531,292943.367188,...,0.25,0.136931,429386.1,3369396.0,8804047.0,3235333.0,0.0,0.0,0.0,0.0


ResNet50


,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
dropout_prob,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
0.0,1.575847,2.331755,2.946969,0.621161,6,6.6,8,0.894427,0.048523,0.061735,...,0.0,0.0,1.102464,1.475852,1.864501,0.317882,0.0,0.0,0.0,0.0
0.2,1.669358,3.895783,11.473446,4.246068,6,14.2,47,18.335757,0.125633,0.399831,...,0.5,0.0,0.964584,1.35099,1.747689,0.344747,0.25,0.25,0.25,0.0
0.4,1.714601,3.644963,9.277344,3.174061,6,12.8,32,10.94075,0.2469,0.536635,...,0.25,0.111803,3.070529,3.602354,3.823701,0.305719,0.25,0.25,0.25,0.0
0.6,2.114738,4.239307,8.538179,2.635409,6,13.8,32,10.756393,0.217012,0.293521,...,0.5,0.273861,2.313108,3.307989,5.612076,1.322143,0.0,0.05,0.25,0.111803
0.8,1.385406,4.593537,12.949914,4.781981,6,16.8,50,18.673511,0.08757,0.317618,...,0.0,0.0,2.519018,2.628816,2.897285,0.154714,0.0,0.05,0.25,0.111803


ResNet50 (pretrained)


,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
dropout_prob,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
0.0,1.690674,2.503572,3.343714,0.777773,6,7.2,9,1.30384,0.059126,0.115803,...,0.0,0.0,3.064599,18.082133,51.09478,19.18457,0.0,0.0,0.0,0.0
0.2,1.654225,2.677774,4.787808,1.223486,6,9.2,14,2.949576,0.103158,0.354092,...,0.5,0.176777,1.421577,3.665507,10.800774,4.00607,0.0,0.15,0.25,0.136931
0.4,1.652582,2.589965,5.77704,1.784637,6,9.0,21,6.708204,0.192881,0.604869,...,0.25,0.136931,11.682973,17.087617,19.239223,3.069946,0.25,0.25,0.25,0.0
0.6,2.084455,2.964064,3.654251,0.800882,6,9.0,12,2.54951,0.100769,0.495359,...,0.5,0.223607,6.773043,7.538954,8.267652,0.643904,0.0,0.05,0.25,0.111803
0.8,1.574079,2.144245,2.803425,0.564608,6,7.4,10,1.67332,0.277558,0.716328,...,0.0,0.0,8.215717,17.709922,34.455536,11.560786,0.25,0.25,0.25,0.0


VIT (pretrained)


,duration,duration,duration,duration,epochs,epochs,epochs,epochs,train_loss,train_loss,...,valid_acc,valid_acc,test_loss,test_loss,test_loss,test_loss,test_acc,test_acc,test_acc,test_acc
dropout_prob,min,mean,max,std,min,mean,max,std,min,mean,...,max,std,min,mean,max,std,min,mean,max,std
0.0,2.814214,3.605843,5.011815,0.875468,7,8.8,13,2.387467,0.239222,0.357662,...,0.0,0.0,3.264619,3.746631,4.423058,0.430267,0.0,0.0,0.0,0.0
0.2,2.880784,3.541979,4.544854,0.64868,6,7.8,10,1.643168,0.542068,2.587223,...,0.25,0.111803,4.742131,5.080053,5.37412,0.303635,0.0,0.0,0.0,0.0
0.4,2.252617,2.867646,4.457467,0.910553,6,6.8,9,1.30384,2.00287,3.631438,...,0.0,0.0,36.751366,43.878009,50.726936,5.495876,0.0,0.0,0.0,0.0
0.6,4.992434,13.099792,23.162135,9.011425,12,28.4,50,19.806565,1.180176,3.055103,...,0.5,0.223607,30.74514,48.235353,58.907913,11.614675,0.0,0.0,0.0,0.0
0.8,2.384203,3.081375,4.358997,0.84189,6,7.6,11,2.073644,3.311845,5.164478,...,0.0,0.0,55.791672,57.948814,60.135757,1.752116,0.0,0.05,0.25,0.111803


## Ensemble of models

In [None]:
# TODO: Add ensemble of models
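
As a possible starting point for this TODO, one common option is soft voting: convert each model's logits to class probabilities, average them across models, and predict the class with the highest averaged probability. The sketch below is an assumption about how the ensemble could work, not the notebook's implementation. `softmax` and `soft_vote` are hypothetical helpers, the logits are made up, and plain Python is used so the example runs without any trained models.

```python
import math


def softmax(logits):
    """Convert a list of logits to probabilities (numerically stable)."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_vote(per_model_logits):
    """Average class probabilities over models.

    `per_model_logits` holds one list of logits per model for a single image.
    Returns (predicted class index, averaged probabilities).
    """
    probs = [softmax(logits) for logits in per_model_logits]
    n_models = len(probs)
    # zip(*probs) groups the probabilities of each class across models.
    avg = [sum(class_probs) / n_models for class_probs in zip(*probs)]
    predicted = max(range(len(avg)), key=avg.__getitem__)
    return predicted, avg


# Made-up logits from three models for one image and three classes.
example_logits = [
    [2.0, 1.0, 0.0],
    [0.0, 3.0, 0.0],
    [0.0, 2.5, 0.5],
]
predicted, avg = soft_vote(example_logits)
print(predicted, [round(p, 3) for p in avg])
```

In the notebook the logits would come from each trained model evaluated on the same batch. Averaging probabilities rather than raw logits keeps one model with very large output magnitudes from dominating the vote.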