Notebook to train a baseline model on the CIFAR100 dataset using a ResNet-like architecture. The full CIFAR100 training set is used, split 80/20 into training and validation subsets. The training subset is augmented with random horizontal flipping and random cropping (32x32 crops after 4 pixels of padding). Images are then converted to tensors and normalized per channel. Validation and test data are only normalized, not augmented.
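As a rough illustration of the two augmentations described above, the following dependency-free sketch applies a random horizontal flip followed by a padded random crop to a fake single-channel image stored as a list of rows. The helper names `random_horizontal_flip` and `random_crop` are hypothetical stand-ins written for this example. They are not the torchvision classes, which operate on PIL images or tensors.

```python
import random


def random_horizontal_flip(image, p=0.5, rng=random):
    """Mirror every row with probability p. `image` is a list of rows."""
    if rng.random() < p:
        return [row[::-1] for row in image]
    return image


def random_crop(image, size=32, padding=4, rng=random):
    """Zero-pad every side by `padding`, then cut a random size x size window."""
    width = len(image[0])
    blank_row = [0] * (width + 2 * padding)
    padded = (
        [blank_row[:] for _ in range(padding)]
        + [[0] * padding + list(row) + [0] * padding for row in image]
        + [blank_row[:] for _ in range(padding)]
    )
    top = rng.randint(0, len(padded) - size)       # randint bounds are inclusive
    left = rng.randint(0, len(padded[0]) - size)
    return [row[left:left + size] for row in padded[top:top + size]]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
image = [[row * 32 + col for col in range(32)] for row in range(32)]  # fake image
augmented = random_crop(random_horizontal_flip(image, rng=rng), rng=rng)
print(len(augmented), len(augmented[0]))
```

With a 32x32 input and 4 pixels of padding, the padded image is 40x40, so the crop window can start at 9 positions along each axis (81 offsets in total). The output keeps the original 32x32 size.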

In [1]:
import torch
import torchvision

from models.resnet.restnet18 import Resnet18
from datasets.cifar100_dataset import CIFAR100Dataset

from torchsummary import summary

from utils.dataset_utils import train_test_split
from utils.cifar100_utils import CIFAR100_LABELS

In [2]:
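# Per-channel mean and standard deviation used for normalization.
CIFAR100_MEAN = [0.507, 0.487, 0.441]
CIFAR100_STD = [0.267, 0.256, 0.276]

# Training pipeline: augment, convert to tensor, normalize.
transformations_training = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomCrop(size=32, padding=4),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# Validation/test pipeline: no augmentation, only tensor conversion and normalization.
transformations_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])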
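To make the normalization step concrete, the sketch below reproduces the arithmetic that `ToTensor` followed by `Normalize` performs on a single RGB pixel: scale the 8-bit values to [0, 1], then subtract the per-channel mean and divide by the per-channel standard deviation. The helpers `normalize_pixel` and `denormalize_pixel` are hypothetical names introduced for this example and are not part of the notebook or of torchvision.

```python
# Per-channel statistics copied from the transforms above.
MEAN = [0.507, 0.487, 0.441]
STD = [0.267, 0.256, 0.276]


def normalize_pixel(rgb):
    """Mimic ToTensor + Normalize for one 8-bit RGB pixel (illustrative helper)."""
    scaled = [value / 255.0 for value in rgb]  # ToTensor: [0, 255] -> [0, 1]
    return [(s - m) / d for s, m, d in zip(scaled, MEAN, STD)]  # (x - mean) / std


def denormalize_pixel(normalized):
    """Invert the normalization, e.g. before displaying an image."""
    return [round((n * d + m) * 255.0) for n, m, d in zip(normalized, MEAN, STD)]


pixel = (128, 64, 255)
normalized = normalize_pixel(pixel)
restored = denormalize_pixel(normalized)
print(normalized)
print(restored)
```

The inverse mapping is handy when plotting augmented samples, since normalized tensors otherwise render with distorted colors.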

In [3]:
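# Load the training images with augmentation, then hold out 20% for validation.
cifar_data = CIFAR100Dataset(root_dir='../../data/cifar100/training', transform=transformations_training)
training_data, validation_data = train_test_split(cifar_data, [0.8, 0.2])
# Validation must not be augmented. This only takes effect if the split object
# reads its own `transform` attribute rather than the parent dataset's.
validation_data.transform = transformations_test
test_data = CIFAR100Dataset(root_dir='../../data/cifar100/testing', transform=transformations_test)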
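`train_test_split` is a project utility whose return type is not shown here. If it returns index-based views that share one underlying dataset (as `torch.utils.data.random_split` does with its `Subset` objects), assigning `validation_data.transform` may not change the transform the underlying dataset applies, and validation samples could still be augmented. One possible way around that is a thin view that applies its own transform to raw samples. The classes `RawDataset` and `TransformedView` and the function `split_indices` below are hypothetical and written in plain Python so the sketch runs without PyTorch. They follow the same `__len__`/`__getitem__` protocol as a map-style dataset.

```python
import random


class RawDataset:
    """Hypothetical stand-in for a dataset that returns untransformed samples."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]


class TransformedView:
    """Index-based view over a dataset that applies its own transform."""

    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        image, label = self.dataset[self.indices[index]]
        return self.transform(image), label


def split_indices(n, train_fraction, seed=0):
    """Shuffle range(n) and cut it into a training part and a validation part."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    cut = int(round(train_fraction * n))
    return indices[:cut], indices[cut:]


raw = RawDataset([(i, i % 100) for i in range(1000)])  # (fake image, label) pairs
train_idx, val_idx = split_indices(len(raw), 0.8)

# Placeholder transforms that tag each sample so the applied pipeline is visible.
train_view = TransformedView(raw, train_idx, lambda x: ("augmented", x))
val_view = TransformedView(raw, val_idx, lambda x: ("plain", x))

print(len(train_view), len(val_view))
print(train_view[0][0][0], val_view[0][0][0])
```

Because each view carries its own transform, the training and validation pipelines cannot leak into each other, whatever the parent dataset was constructed with.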

In [4]:
model = Resnet18(classes=CIFAR100_LABELS, name='Cifar100 Baseline')

In [5]:
# Load previously trained weights. Uncomment the block below to retrain instead.
model.load_state_dict(torch.load('./trained_models/baseline_model.pth'))
# model.fit(
#     training_data=training_data,
#     validation_data=validation_data,
#     num_epochs=50,
#     batch_size=256,
#     learning_rate=0.01,
#     save_state_path='./trained_models/baseline_model.pth',
#     return_best_model=True
# )

<All keys matched successfully>

In [6]:
model.evaluate(test_data)

{'acc': 0.5787, 'name': 'Cifar100 Baseline'}
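The `evaluate` method belongs to the project's `Resnet18` class, and its internals are not shown here. Assuming `acc` is top-1 accuracy (the fraction of test images whose highest-scoring class equals the true label), it could be computed as in the sketch below. `top1_accuracy` is a hypothetical helper, and the scores are made-up toy values, not outputs of the trained model.

```python
def top1_accuracy(scores, labels):
    """Fraction of samples whose highest-scoring class index equals the label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    correct = 0
    for row, label in zip(scores, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(predicted == label)
    return correct / len(labels)


# Toy example with 3 classes and 4 samples.
scores = [
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.3, 0.3, 0.4],
    [0.2, 0.5, 0.3],
]
labels = [1, 0, 1, 1]
accuracy = top1_accuracy(scores, labels)
print(accuracy)
```

For context, uniform random guessing over 100 classes has an expected top-1 accuracy of 0.01.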