In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
# Switch the working directory BEFORE the two project imports below. Presumably
# model.py and common_utils.py live in ../models and are found via the current
# directory (TODO confirm). Every relative path used later in this notebook
# ('../data', './baseline_model') is therefore resolved against ../models.
os.chdir("../models")
from model import CustomCNN
from common_utils import set_seed, EarlyStopper, train_step, val_step, train, save_model, transform

# Fix the seed (42) through the project helper before the model is constructed.
# Which random number generators set_seed covers is defined in common_utils and
# is not visible from this notebook.
set_seed(42)

In [2]:
# Build the network first so it is created right after the seed was fixed.
model = CustomCNN()

# Flowers102 datasets, all rooted at ../data and sharing the project `transform`.
# NOTE: the 'test' and 'train' split names are swapped on purpose: the split
# called 'test' is used for training and the split called 'train' is held back
# for final testing. The original author attributes this to a bug in the
# Flowers102 split naming; torchvision documents the official 'test' split as
# much larger than 'train'/'val', which is presumably the motivation
# (TODO confirm against the torchvision docs).
train_dataset, val_dataset, test_dataset = (
    datasets.Flowers102(root='../data', split=split_name, download=True, transform=transform)
    for split_name in ('test', 'val', 'train')
)

# Mini-batch loaders (32 samples per batch). Only the training loader shuffles;
# validation and test keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader, test_dataloader = (
    DataLoader(held_out_dataset, batch_size=32, shuffle=False)
    for held_out_dataset in (val_dataset, test_dataset)
)

In [3]:
# --- Device selection ---------------------------------------------------------
# Prefer an NVIDIA GPU, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available(): # nvidia gpu
    device = torch.device("cuda")
elif torch.backends.mps.is_available(): # apple gpu
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Move the model to its device BEFORE constructing the optimiser. The torch.optim
# documentation advises this order so the optimiser is guaranteed to hold
# references to the parameters that actually live on the training device.
model.to(device)

# --- Optimisation set-up ------------------------------------------------------
lr = 0.001 # learning rate
optimiser = torch.optim.Adam(model.parameters(), lr=lr) # Adam over all model parameters
loss = torch.nn.CrossEntropyLoss() # loss function handed to train()

epochs = 100 # upper bound on the number of epochs
early_stopper = EarlyStopper(patience=10) # project-defined early stopper, patience of 10

# --- Output directory ---------------------------------------------------------
# The baseline model gets one sub-directory per device type ("cuda", "mps" or
# "cpu"). `device.type` yields exactly that string, so no manual mapping from
# the device object is needed.
baseline_model_path = "./baseline_model"
device_type = device.type
device_path = os.path.join(baseline_model_path, device_type)

# Create ./baseline_model and the device sub-directory in one call; exist_ok
# makes re-running this cell a no-op and avoids the check-then-create race.
os.makedirs(device_path, exist_ok=True)

In [4]:
# Launch training with the objects prepared in the cells above. Arguments are
# passed positionally, in the order the project's `train` helper is called with.
train(
    model,
    train_dataloader,
    val_dataloader,
    optimiser,
    loss,
    device,
    epochs,          # maximum number of epochs
    early_stopper,
    device_path,     # device-specific output directory created earlier
)

Epoch 1/100:   0%|          | 0/193 [00:00<?, ?it/s]

Epoch 1/100: 100%|██████████| 193/193 [01:18<00:00,  2.47it/s, Training loss=4.4661]


Epoch 1/100 took 84.62s | Train loss: 4.4661 | Val loss: 4.6956 | Val accuracy: 0.98% | EarlyStopper count: 0


Epoch 2/100:  32%|███▏      | 61/193 [00:25<00:55,  2.39it/s, Training loss=4.4232]


KeyboardInterrupt: 