#### This notebook trains our baseline model

Imports

In [None]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader, default_collate
from torchvision import datasets, transforms
from torchvision.transforms.v2 import MixUp, CutMix
os.chdir("../models")
from model import CustomCNN
from common_utils import set_seed, EarlyStopper, train, get_mean_rgb, CustomTransform, CustomRegularizeTransform

import matplotlib.pyplot as plt

# Fix the random seed before the model and dataloaders are created.
# NOTE(review): set_seed is imported from common_utils; which RNGs it seeds
# (Python / NumPy / torch) is not visible from this notebook — confirm there.
set_seed(42)

Initialise model and dataset

In [None]:
# Baseline network. fcn_depth / fcn_width are passed straight to CustomCNN
# (presumably they size the fully connected head — confirm in model.py).
model = CustomCNN(
    fcn_depth=3,
    fcn_width=[1024, 512],
)

# Training pipeline: tensor conversion followed by random augmentations.
# The steps run in list order; the final random crop fixes the output at 100x100.
train_steps = [
    transforms.ToTensor(),
    # transforms.Resize((100, 100), antialias=True),
    transforms.ColorJitter(brightness=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    # most of the pictures are rotation invariant
    # transforms.RandomRotation((-30, 30)),
    transforms.RandomResizedCrop((100, 100), scale=(0.8, 1.0), antialias=True),
]
transform1 = transforms.Compose(train_steps)

# Evaluation pipeline: no randomness, only tensor conversion and a fixed 100x100 resize.
eval_steps = [
    transforms.ToTensor(),
    transforms.Resize((100, 100), antialias=True),
]
transform2 = transforms.Compose(eval_steps)

# Number of flower categories in Flowers102; MixUp needs it to build soft labels.
NUM_CLASSES = 102

# Build the MixUp transform once instead of re-creating it for every batch.
_MIXUP = MixUp(num_classes=NUM_CLASSES)


def mixup_collate_fn(batch):
    """Collate a list of samples into a batch and apply MixUp to it.

    Args:
        batch: list of dataset samples as handed over by the DataLoader;
            ``default_collate`` turns it into the batched values that are
            unpacked as positional arguments to MixUp (here: images, labels).

    Returns:
        Whatever the MixUp transform returns for the collated batch, i.e. the
        mixed inputs together with the correspondingly mixed labels.
    """
    return _MIXUP(*default_collate(batch))

# Load the three Flowers102 splits from ../data (download=True is passed for each).
# NOTE: Due to a bug with the Flowers102 dataset, the train and test splits are swapped,
# so the 'test' split is used for training and the 'train' split for testing.
flowers_args = {"root": '../data', "download": True}
train_dataset = datasets.Flowers102(split='test', transform=transform1, **flowers_args)
val_dataset = datasets.Flowers102(split='val', transform=transform2, **flowers_args)
test_dataset = datasets.Flowers102(split='train', transform=transform2, **flowers_args)

batch_size = 128

# Only the training loader shuffles and routes batches through the MixUp collate
# function; validation and test batches keep dataset order and default collation.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=mixup_collate_fn,
)
val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)

Visualize the original images alongside their transformed versions

In [None]:
from itertools import product

# Same split as train_dataset, but with ToTensor only, so the un-augmented image can be shown.
train_dataset_ = datasets.Flowers102(root='../data', split='test', download=True, transform=transforms.ToTensor())

fig, axs = plt.subplots(4, 4, figsize=(20, 20))
offset = 1000  # Change this to get different images

# Each (i, j) pair fills two neighbouring panels in row i:
# the original image on the left and its transform1 version on the right.
for i, j in product(range(4), range(2)):
    index = i + j * offset
    original_ax = axs[i][2 * j]
    transformed_ax = axs[i][2 * j + 1]

    # Dataset images are channel-first tensors; imshow expects channel-last.
    original_ax.imshow(train_dataset_[index][0].permute(1, 2, 0))
    original_ax.set_title(f"Original (index {index})")

    transformed_ax.imshow(train_dataset[index][0].permute(1, 2, 0))
    transformed_ax.set_title(f"Transformed (index {index})")

plt.show()

Specify hyperparameters

In [None]:
# Pick the best available accelerator: NVIDIA GPU, then Apple GPU, then CPU.
if torch.cuda.is_available():  # nvidia gpu
    device = torch.device("cuda")
elif torch.backends.mps.is_available():  # apple gpu
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Move the model before building the optimiser so that the optimiser is
# constructed over parameters that already live on the target device.
model.to(device)

lr = 0.001  # learning rate
optimiser = torch.optim.Adam(model.parameters(), lr=lr)  # initialise optimiser
loss = torch.nn.CrossEntropyLoss()  # initialise loss function

epochs = 300  # number of epochs
early_stopper = EarlyStopper(patience=10)  # initialise early stopper


# Directory used to save the baseline model, with one sub-directory per device type.
baseline_model_path = "./saved_models/baseline_model/non_preprocessed"

# torch.device exposes its kind ("cuda", "mps" or "cpu") directly as .type,
# so no comparison chain against each possible device is needed.
device_type = device.type

# Construct the full path
device_path = os.path.join(baseline_model_path, device_type)

# A single makedirs call creates baseline_model_path and the device sub-directory;
# exist_ok=True makes re-running the cell safe without a separate exists() check.
os.makedirs(device_path, exist_ok=True)

# Last expression of the cell: display the model architecture.
model

Train the model

In [None]:
# Run training. The three returned histories are kept for the plots further down;
# device_path is the directory prepared for saving the model.
(train_loss_list, val_loss_list, val_acc_list) = train(
    model,
    train_dataloader,
    val_dataloader,
    optimiser,
    loss,
    device,
    epochs,
    early_stopper,
    device_path,
)

Plot train and validation loss, and validation accuracy

In [None]:
# Plot the training history (matplotlib is already imported as plt in the imports cell).

# Figure 1: training vs. validation loss per epoch.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(train_loss_list, label="train loss")
loss_ax.plot(val_loss_list, label="val loss")
loss_ax.set_title("Training and validation loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

# Figure 2: validation accuracy per epoch.
acc_fig, acc_ax = plt.subplots()
acc_ax.plot(val_acc_list, label="val accuracy")
acc_ax.set_title("Validation accuracy")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.legend()
plt.show()