# Data Augmentation
- Random Color Jitter
- Random Horizontal Flip
- Random Rotation
- Random Resize and Crop

In [1]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.chdir("../models")

import pickle
import torch
import torch.nn as nn
import torchvision

from common_utils import set_seed, EarlyStopper, train, get_mean_rgb, test_model
from datetime import datetime
from model import DepthPointWiseCNN
from sklearn.metrics import top_k_accuracy_score, f1_score
from torch.utils.data import DataLoader, default_collate
from torchvision import datasets, transforms
from torchvision.transforms.v2 import MixUp, CutMix, RandomChoice

### Set up variables, seed and pytorch device

In [2]:
# We might want to run the notebook with different parameters using a script
# Use environment varables to set the parameters if thats the case
defaults = {
    "model_name": "AugmentModel",
    "model_seed": 42,
    "batchnorm_moment": 0.05,
    "max_lr": 0.001,
    "min_lr": 0.0001
}
for k, v in defaults.items():
    globals()[k] = os.environ.get(k, defaults[k])
    # Use default value type to infer the variable type
    if isinstance(v, int):
        globals()[k] = int(globals()[k])
    elif isinstance(v, float):
        globals()[k] = float(globals()[k])

# set seed
set_seed(model_seed)

device_type = None
device = None
# determine device type
if torch.cuda.is_available(): # nvidia gpu
    device = torch.device("cuda")
    device_type = "cuda"
elif torch.backends.mps.is_available(): # apple gpu
    device = torch.device("mps")
    device_type = "mps"
else:
    device = torch.device("cpu")
    device_type = "mps"

[globals()[k] for k in defaults.keys()], device_type

(['AugmentModel', 42, 0.05, 0.001, 0.0001], 'cuda')

### Initialise model

In [3]:
model = DepthPointWiseCNN(batchnorm_moment=batchnorm_moment,
                          dropout_rate=0.05).to(device_type) # initialise model
data_dir_path = "../data"

# Make directory to save baseline model
# Don't overwrite exisiting models and model outputs
model_path = f"./saved_models/{model_name}/"
if not os.path.exists(model_path):
    os.makedirs(model_path, exist_ok=True)
else:
    raise Exception('''
        Directory already exists. Either choose a different 'model_name' or
        delete the exisiting directory.
    ''')

# Construct the full path
device_path = os.path.join(model_path, device_type)

# Create the directory if it doesn't exist
if not os.path.exists(device_path):
    os.mkdir(device_path)

### Initialise Dataset

In [4]:
def mixup_collate_fn(batch):
    mixup = MixUp(num_classes=102, alpha=mixup_alpha)
    cutmix = CutMix(num_classes=102, alpha=cutmix_alpha)

    return RandomChoice([mixup, cutmix])(*default_collate(batch))

data_path = "../data"

batch_size = 128
mixup_alpha = 0.2
cutmix_alpha = 1.0

train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.ColorJitter(brightness=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    # Random rotation triples the time
    # transforms.RandomRotation((-30, 30)),
    transforms.RandomResizedCrop((100, 100), scale=(0.8, 1.0), antialias=True),
])

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((100, 100), antialias=True),
])

# load data
# NOTE: Due to a bug with the Flowers102 dataset, the train and test splits are swapped
train_dataset = datasets.Flowers102(root=data_path,
                                    split='test',
                                    download=True,
                                    transform=train_transform
                                   ) 
val_dataset = datasets.Flowers102(root=data_path,
                                  split='val',
                                  download=True,
                                  transform=transform
                                 ) 
test_dataset = datasets.Flowers102(root=data_path,
                                   split='train',
                                   download=True,
                                   transform=transform
                                  )

# initialise dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=mixup_collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

### Specify hyperparameters

In [5]:
lr = min_lr # learning rate
optimiser = torch.optim.Adam(model.parameters(), lr=lr) # initialise optimiser
loss = torch.nn.CrossEntropyLoss() # initialise loss function
epochs = 30

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimiser,
                                                max_lr=max_lr,
                                                steps_per_epoch=len(train_dataloader),
                                                epochs=epochs)

patience = 10
early_stopper = EarlyStopper(patience=patience) # initialise early stopper

### Train the model

In [None]:
out = train(model, train_dataloader, val_dataloader, optimiser,
            loss, device, epochs, early_stopper, device_path, scheduler
           )
train_loss_list, val_loss_list, val_acc_list, train_time_list, lr_list, early_stop = out

Epoch 1/30: 100%|██████████| 49/49 [01:01<00:00,  1.26s/it, Training loss=4.3893, Learning rate=0.00007]


Epoch 1/30 took 65.76s | Train loss: 4.3893 | Val loss: 4.6629 | Val accuracy: 0.98% | EarlyStopper count: 0


Epoch 2/30: 100%|██████████| 49/49 [00:57<00:00,  1.17s/it, Training loss=3.9543, Learning rate=0.00015]


Epoch 2/30 took 61.83s | Train loss: 3.9543 | Val loss: 4.1389 | Val accuracy: 8.14% | EarlyStopper count: 0


Epoch 3/30: 100%|██████████| 49/49 [00:58<00:00,  1.20s/it, Training loss=3.8608, Learning rate=0.00028]


Epoch 3/30 took 62.93s | Train loss: 3.8608 | Val loss: 3.6101 | Val accuracy: 14.22% | EarlyStopper count: 0


Epoch 4/30: 100%|██████████| 49/49 [00:56<00:00,  1.16s/it, Training loss=3.5217, Learning rate=0.00044]


Epoch 4/30 took 61.16s | Train loss: 3.5217 | Val loss: 3.2534 | Val accuracy: 21.57% | EarlyStopper count: 0


Epoch 5/30: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it, Training loss=3.1694, Learning rate=0.00061]


Epoch 5/30 took 62.53s | Train loss: 3.1694 | Val loss: 3.0334 | Val accuracy: 23.73% | EarlyStopper count: 0


Epoch 6/30: 100%|██████████| 49/49 [00:57<00:00,  1.18s/it, Training loss=3.0400, Learning rate=0.00076]


Epoch 6/30 took 62.24s | Train loss: 3.0400 | Val loss: 2.7862 | Val accuracy: 31.86% | EarlyStopper count: 0


Epoch 7/30: 100%|██████████| 49/49 [00:57<00:00,  1.17s/it, Training loss=3.0080, Learning rate=0.00089]


Epoch 7/30 took 61.49s | Train loss: 3.0080 | Val loss: 2.4019 | Val accuracy: 35.00% | EarlyStopper count: 0


Epoch 8/30: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it, Training loss=2.6229, Learning rate=0.00097]


Epoch 8/30 took 63.02s | Train loss: 2.6229 | Val loss: 2.0610 | Val accuracy: 46.96% | EarlyStopper count: 0


Epoch 9/30: 100%|██████████| 49/49 [00:59<00:00,  1.21s/it, Training loss=2.4706, Learning rate=0.00100]


Epoch 9/30 took 63.77s | Train loss: 2.4706 | Val loss: 2.1169 | Val accuracy: 43.63% | EarlyStopper count: 0


Epoch 10/30: 100%|██████████| 49/49 [00:56<00:00,  1.16s/it, Training loss=2.4694, Learning rate=0.00099]


Epoch 10/30 took 60.75s | Train loss: 2.4694 | Val loss: 1.8740 | Val accuracy: 48.82% | EarlyStopper count: 1


Epoch 11/30: 100%|██████████| 49/49 [00:56<00:00,  1.15s/it, Training loss=2.4075, Learning rate=0.00098]


Epoch 11/30 took 60.56s | Train loss: 2.4075 | Val loss: 1.6884 | Val accuracy: 57.94% | EarlyStopper count: 0


Epoch 12/30: 100%|██████████| 49/49 [00:57<00:00,  1.17s/it, Training loss=2.3535, Learning rate=0.00095]


Epoch 12/30 took 61.80s | Train loss: 2.3535 | Val loss: 1.7360 | Val accuracy: 52.35% | EarlyStopper count: 0


Epoch 13/30: 100%|██████████| 49/49 [00:57<00:00,  1.17s/it, Training loss=2.0897, Learning rate=0.00091]


Epoch 13/30 took 61.75s | Train loss: 2.0897 | Val loss: 2.0038 | Val accuracy: 51.08% | EarlyStopper count: 1


Epoch 14/30: 100%|██████████| 49/49 [00:57<00:00,  1.18s/it, Training loss=2.1514, Learning rate=0.00087]


Epoch 14/30 took 61.91s | Train loss: 2.1514 | Val loss: 1.6110 | Val accuracy: 59.31% | EarlyStopper count: 2


Epoch 15/30: 100%|██████████| 49/49 [00:54<00:00,  1.10s/it, Training loss=2.0111, Learning rate=0.00081]


Epoch 15/30 took 58.11s | Train loss: 2.0111 | Val loss: 1.4235 | Val accuracy: 62.55% | EarlyStopper count: 0


Epoch 16/30: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it, Training loss=2.0150, Learning rate=0.00075]


Epoch 16/30 took 62.65s | Train loss: 2.0150 | Val loss: 1.3097 | Val accuracy: 65.29% | EarlyStopper count: 0


Epoch 17/30: 100%|██████████| 49/49 [00:57<00:00,  1.18s/it, Training loss=1.8209, Learning rate=0.00068]


Epoch 17/30 took 62.28s | Train loss: 1.8209 | Val loss: 1.3557 | Val accuracy: 64.22% | EarlyStopper count: 0


Epoch 18/30: 100%|██████████| 49/49 [00:56<00:00,  1.15s/it, Training loss=1.8355, Learning rate=0.00061]


Epoch 18/30 took 60.62s | Train loss: 1.8355 | Val loss: 1.2518 | Val accuracy: 67.75% | EarlyStopper count: 1


Epoch 19/30: 100%|██████████| 49/49 [00:56<00:00,  1.16s/it, Training loss=1.9323, Learning rate=0.00054]


Epoch 19/30 took 60.67s | Train loss: 1.9323 | Val loss: 1.2860 | Val accuracy: 66.96% | EarlyStopper count: 0


Epoch 20/30: 100%|██████████| 49/49 [00:58<00:00,  1.20s/it, Training loss=1.3940, Learning rate=0.00046]


Epoch 20/30 took 62.94s | Train loss: 1.3940 | Val loss: 1.1315 | Val accuracy: 72.94% | EarlyStopper count: 1


Epoch 21/30: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it, Training loss=1.6251, Learning rate=0.00039]


Epoch 21/30 took 62.61s | Train loss: 1.6251 | Val loss: 1.0841 | Val accuracy: 73.04% | EarlyStopper count: 0


Epoch 22/30: 100%|██████████| 49/49 [00:55<00:00,  1.13s/it, Training loss=1.8203, Learning rate=0.00032]


Epoch 22/30 took 59.62s | Train loss: 1.8203 | Val loss: 1.0119 | Val accuracy: 75.29% | EarlyStopper count: 0


Epoch 23/30: 100%|██████████| 49/49 [00:56<00:00,  1.16s/it, Training loss=1.2230, Learning rate=0.00025]


Epoch 23/30 took 61.15s | Train loss: 1.2230 | Val loss: 1.0296 | Val accuracy: 75.39% | EarlyStopper count: 0


Epoch 24/30: 100%|██████████| 49/49 [00:57<00:00,  1.18s/it, Training loss=1.4943, Learning rate=0.00019]


Epoch 24/30 took 62.47s | Train loss: 1.4943 | Val loss: 1.0153 | Val accuracy: 75.20% | EarlyStopper count: 1


Epoch 25/30: 100%|██████████| 49/49 [00:59<00:00,  1.21s/it, Training loss=1.5045, Learning rate=0.00013]


Epoch 25/30 took 63.60s | Train loss: 1.5045 | Val loss: 0.9787 | Val accuracy: 76.27% | EarlyStopper count: 2


Epoch 26/30: 100%|██████████| 49/49 [00:57<00:00,  1.18s/it, Training loss=1.4693, Learning rate=0.00009]


Epoch 26/30 took 62.17s | Train loss: 1.4693 | Val loss: 0.9752 | Val accuracy: 76.08% | EarlyStopper count: 0


Epoch 27/30: 100%|██████████| 49/49 [00:59<00:00,  1.21s/it, Training loss=1.7226, Learning rate=0.00005]


Epoch 27/30 took 63.55s | Train loss: 1.7226 | Val loss: 0.9626 | Val accuracy: 76.37% | EarlyStopper count: 0


Epoch 28/30: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it, Training loss=1.6507, Learning rate=0.00002]


Epoch 28/30 took 62.93s | Train loss: 1.6507 | Val loss: 0.9650 | Val accuracy: 76.67% | EarlyStopper count: 0


Epoch 29/30: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it, Training loss=1.5006, Learning rate=0.00001]


Epoch 29/30 took 56.45s | Train loss: 1.5006 | Val loss: 0.9826 | Val accuracy: 75.78% | EarlyStopper count: 1


Epoch 30/30:  82%|████████▏ | 40/49 [00:49<00:11,  1.25s/it, Training loss=1.3784, Learning rate=0.00000]

### Plot train, test loss and test accuracy
Note that this is only to visualize how the training was. We will create nicer plots for the report.

In [None]:
import matplotlib.pyplot as plt

plt.plot(train_loss_list, label="train loss")
plt.plot(val_loss_list, label="val loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

plt.plot(val_acc_list, label="val accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()

plt.show()

### Model Evaluation with test data

In [None]:
checkpoint = torch.load(os.path.join(device_path, 'best_model.pt'))

best_model = DepthPointWiseCNN().to(device_type)
best_model.load_state_dict(checkpoint)

true_labels = torch.stack([label for _, label in test_dataloader])
pred_softmax_labels = test_model(best_model, test_dataloader, loss, device).cpu()
pred_labels = torch.argmax(pred_softmax_labels, dim=1)

top_1_accuracy = top_k_accuracy_score(true_labels, pred_softmax_labels, k=1)
top_5_accuracy = top_k_accuracy_score(true_labels, pred_softmax_labels, k=5)
f1 = f1_score(true_labels, pred_labels, average='micro')

top_1_accuracy, top_5_accuracy, f1

### Save all relevant data/parameters to be used for plots, etc
Note that different models may have the same parameters. There's no guarantee that they can be accessible across different computers. There may also be some redundant parameters added just in case we might need them.

In [None]:
data = {
    "time_trained": datetime.now().strftime("%D,%H:%M:%S"),
    "model_name": model_name,
    "model_seed": model_seed,
    "device_type": device_type,
    
    "batch_size": batch_size,
    "train_transform": train_transform,
    "transform": transform,
    
    "lr": lr,
    "epochs": epochs,
    "patience": patience,
    
    "train_loss_list": train_loss_list,
    "val_loss_list": val_loss_list,
    "val_acc_list": val_acc_list,
    "train_time_list": train_time_list,
    "lr_list": lr_list,
    "early_stop": early_stop,   # Boolean for if early stopping happened

    "true_labels": true_labels,
    "pred_softmax_labels": pred_softmax_labels,
    "pred_labels": pred_labels,
    "top_1_accuracy": top_1_accuracy,
    "top_5_accuracy": top_5_accuracy,
    "f1": f1,

    "batchnorm_moment": batchnorm_moment
}

with open(os.path.join(model_path, "data.pickle"), "wb") as f: 
    pickle.dump(data, f)