# Imports and setup

In [48]:
# We can now load the dependencies
%matplotlib inline 
import torch 
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt 
import torchvision
import torchsummary
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import sklearn.metrics as metrics
import torchmetrics
import os
import torchvision.transforms as transforms
from PIL import Image

We can start by setting a seed for reproducibility

In [49]:
torch.manual_seed(0)

<torch._C.Generator at 0x1b819c641d0>

# Pre-processing

In [50]:
class CustomDataset(Dataset):
    
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.images = os.listdir(root_dir)
        self.labels = torch.load(label_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.root_dir, img_name)
        image = Image.open(img_path)
        image = transforms.ToTensor()(image)
        label = self.labels[int(img_name[3:9])]
        return image, label

generated_data_root = "../../Data Generation/Pre Processed Data Generated"
train_gen_dataset = CustomDataset(generated_data_root + "/Square Images/Training", generated_data_root + "/Square Images/y_piece_generated.pt")
val_gen_dataset = CustomDataset(generated_data_root + "/Square Images/Validation", generated_data_root + "/Square Images/y_piece_generated.pt")
test_gen_dataset = CustomDataset(generated_data_root + "/Square Images/Testing", generated_data_root + "/Square Images/y_piece_generated.pt")

# Hyperparameter choices

We create a cell to hold the hyperparameters of the model

In [51]:
learning_rate = 0.001
batch_size = 50
num_epochs = 50
gamma_focal_loss_choices = {2, 3}
dropout_rate_choices = {0.2, 0.5}
n_loss = 100
n_eval_minibatches = 200 # Number of minibatches to use for validation every epoch

We can now create our dataloaders

In [52]:
gen_train_loader = DataLoader(train_gen_dataset, batch_size=batch_size, shuffle=True)

gen_val_loader = DataLoader(val_gen_dataset, batch_size=batch_size, shuffle=True)

gen_test_loader = DataLoader(test_gen_dataset, batch_size=batch_size, shuffle=True)

# real_train_loader = torch.load('../../Real life data/Pre-processed/partial_real_train_loader')

# real_val_loader = torch.load('../../Real life data/Pre-processed/partial_real_val_loader')

# real_test_loader = torch.load('../../Real life data/Pre-processed/partial_real_test_loader')

# Model implementation

We can start by loading a pre-trained VGG16 model without the classification layers towards the end (Only the feature extractor).

In [53]:
vgg16 = torchvision.models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1')

We can now visualize its layers:

In [54]:
vgg16

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

Because we are looking for a pre-trained feature extractor here, we decide to only use the features part and freeze its weights. We can then add a few subsequent layers to fine tune predictions. We can thus define the following model:

In [55]:
class BaseModel(nn.Module):

    def __init__(self, num_classes=13, dropout_rate=0.5):
        
        super(BaseModel, self).__init__()
        
        # Define the layers of the model
        self.features = torchvision.models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1').features
        self.classifier = nn.Sequential(
            nn.Linear(4608, 1024),
            nn.ReLU(True),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes)
        )

        # Set the features to not require gradients
        for param in self.features.parameters():
            param.requires_grad = False


    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


# Training

We can start by finding the device to use for training:

In [56]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE == "cuda":
    torch.cuda.empty_cache()

We can then go ahead and define the loss function we will be using. Because we will opt for a balanced focal loss instead of a regular cross entropy loss which gives more importance to the classes that are harder to classify. We thus implement the focal loss defined by the following formula:
<center><img src="focal loss.png"></center>

where gamma is a tunable hyperparameter. We also further add an alpha term to handle class imbalance, making our loss function a class-balanced focal loss, as shown in https://github.com/AdeelH/pytorch-multi-class-focal-loss.


Finally, we need an accuracy metric to tune the hyperparameters of the model. We will opt for a balanced accuracy score, which is just regular classification accuracy but adapted to weigh each class by its frequency:

In [57]:
f1_score = torchmetrics.F1Score(task="multiclass", num_classes=13, average="weighted").to(DEVICE)

Finally, because we are using balanced accuracy scores, we can use the class analytics gathered during pre-processing to define the following class distribution:

In [58]:
class_proportions = [0.3198, 0.1602, 0.0405, 0.0400, 0.0406, 0.0201, 0.0404, 0.1596, 0.0392, 0.0397, 0.0400, 0.0196, 0.0404]

We can now proceed to train our model:

In [73]:
# To store the best model
best_model = None
best_balanced_accuracy = 0

for gamma in gamma_focal_loss_choices:

    for dropout_rate in dropout_rate_choices:
        
        # Define the new loss function
        # focal_loss = torch.hub.load(
        #     'adeelh/pytorch-multi-class-focal-loss',
        #     model='FocalLoss',
        #     alpha=torch.tensor(class_proportions),
        #     gamma=gamma,
        #     reduction='mean',
        #     force_reload=False,
        #     verbose = False
        # ).to(DEVICE)
        focal_loss = nn.CrossEntropyLoss().to(DEVICE)

        # To store the metrics through epochs
        training_loss_through_iterations = np.array[])
        gen_validation_loss_through_epochs = np.array([])
        gen_validation_acc_through_epochs = np.array([])
        gen_validation_f1_through_epochs = np.array([])
        real_validation_loss_through_epochs = np.array([])
        real_validation_acc_through_epochs = np.array([])
        real_validation_f1_through_epochs = np.array([])

        basemodel = BaseModel(dropout_rate=dropout_rate).to(DEVICE)
        opt = optim.Adam(basemodel.parameters(), lr=learning_rate)

        for epoch in range(num_epochs):
            
            ##########################################################
            # Training the model until the full real dataset is used #
            ##########################################################
            for i, (X_train_gen, y_train_gen) in enumerate(gen_train_loader):
                
                # Set the model to training mode
                basemodel.train()

                # Move the data to the device
                X_train_gen = X_train_gen.to(DEVICE)
                y_train_gen = y_train_gen.to(DEVICE)

                # Forward pass
                y_train_pred_prob_gen = basemodel(X_train_gen)
                y_train_pred_gen = torch.argmax(y_train_pred_prob_gen, dim=1)

                # Compute the loss
                loss = focal_loss(y_train_pred_prob_gen, y_train_gen.long())

                # Backward pass
                opt.zero_grad()
                loss.backward()
                opt.step()

                # Print the loss & save it every n_loss iteration
                print(f'Epoch [{epoch + 1}], Iteration [{i + 1}], Training Loss: [{loss.item():.4f}]')
                training_loss_through_iterations = np.append(training_loss_through_iterations, (loss.item(), i+1))
           

            ########################################################
            # Evaluate the model every epoch on the validation set #
            ########################################################

            # Set the model to evaluation mode
            basemodel.eval()
            
            # Disable gradient calculation
            with torch.no_grad():

                # 1) Evaluate on the generated validation set
                # TODO: balanced_acc_sum = 0
                weighted_f1_sum = 0

                # Extract the iterator from the data loader
                gen_val_iter = iter(gen_val_loader)

                # Iterate for n_eval_minibatches
                for i in range(n_eval_minibatches):

                    # Get the next minibatch
                    X_val_gen, y_val_gen = next(gen_val_iter)
                        
                    # Move the data to the device
                    X_val_gen = X_val_gen.to(DEVICE)
                    y_val_gen = y_val_gen.to(DEVICE)

                    # Forward pass
                    y_val_pred_prob_gen = basemodel(X_val_gen)
                    y_val_pred_gen = torch.argmax(y_train_pred_prob_gen, dim=1)

                    # Compute the balanced accuracy score
                    # TODO: balanced_acc_sum += metrics.balanced_accuracy_score(y_val_gen.long().cpu(), y_val_pred_gen.long().cpu(), sample_weight=None)
                    weighted_f1_sum += f1_score(y_val_pred_gen, y_val_gen)

                # Compute the average balanced accuracy score & loss
                # TODO: balanced_acc_gen = balanced_acc_sum / n_eval_minibatches
                loss_gen = focal_loss(y_val_pred_prob_gen, y_val_gen)
                weighted_f1_gen = weighted_f1_sum / n_eval_minibatches

                # # 2) Evaluate on the real validation set
                # # TODO: balanced_acc_sum = 0
                # weighted_f1_sum = 0

                # # Iterate for n_eval_minibatches
                # for (X_val_real, y_val_real) in real_val_loader:

                #     # Move the data to the device
                #     X_val_real = X_val_real.to(DEVICE)
                #     y_val_real = y_val_real.to(DEVICE)

                #     # Forward pass
                #     y_val_pred_prob_real = basemodel(X_val_real)
                #     y_val_pred_real = torch.argmax(y_val_pred_prob_real, dim=1)

                #     # Compute the balanced accuracy score
                #     # TODO: balanced_acc_sum += metrics.balanced_accuracy_score(y_val_real.long().cpu(), y_val_pred_real.long().cpu(), sample_weight=None) # TODO: Add sample weights for all 13 classes
                #     weighted_f1_sum += f1_score(y_val_pred_real, y_val_real)

                # # Compute the average balanced accuracy score
                # # TODO: balanced_acc_real = balanced_acc_sum / n_eval_minibatches
                # loss_real = focal_loss(y_val_pred_prob_real, y_val_real)
                # weighted_f1_real = weighted_f1_sum / n_eval_minibatches

                # Store all 6 metrics
                # TODO: gen_validation_acc_through_epochs = np.append(gen_validation_acc_through_epochs, (balanced_acc_gen, epoch+1))
                # TODO: real_validation_acc_through_epochs = np.append(real_validation_acc_through_epochs, (balanced_acc_real, epoch+1))
                gen_validation_loss_through_epochs = np.append(gen_validation_loss_through_epochs, (loss_gen.item(), epoch+1))
                # real_validation_loss_through_epochs = np.append(real_validation_loss_through_epochs, (loss_real.item(), epoch+1))
                gen_validation_f1_through_epochs = np.append(gen_validation_f1_through_epochs, (weighted_f1_gen.item(), epoch+1))
                # real_validation_f1_through_epochs = np.append(real_validation_f1_through_epochs, (weighted_f1_real.item(), epoch+1))

                # TODO: Remove this
                print(f'VALIDATION => Epoch [{epoch + 1}], Balanced Accuracy (Gen): [TODO], Balanced Accuracy (Real): [TODO]')
                print(f'VALIDATION => Epoch [{epoch + 1}], F1 Score (Gen): [{weighted_f1_gen:.4f}], F1 Score (Real): [TODO]')


            ##############################################
            # Save the model every epoch as a checkpoint #
            ##############################################

            torch.save(basemodel.state_dict(), f'./checkpoints/basemodel_{epoch+1}.ckpt')


        # Compare to the best model and save this one if it is better
        if best_model is None or best_balanced_accuracy < balanced_acc_real:
            best_model = basemodel
            best_balanced_accuracy = balanced_acc_real
            torch.save(basemodel.state_dict(), f'./best_model_gamma_{gamma}_dropout_{dropout_rate}.ckpt')


        # Save the training & validation metrics
        np.save(f'./training_loss_gamma_{gamma}_dropout_{dropout_rate}.npy', training_loss_through_iterations, allow_pickle=True)
        np.save(f'./gen_validation_loss_gamma_{gamma}_dropout_{dropout_rate}.npy', gen_validation_loss_through_epochs, allow_pickle=True)
        np.save(f'./gen_validation_acc_gamma_{gamma}_dropout_{dropout_rate}.npy', gen_validation_acc_through_epochs, allow_pickle=True)
        np.save(f'./real_validation_loss_gamma_{gamma}_dropout_{dropout_rate}.npy', real_validation_loss_through_epochs, allow_pickle=True)
        np.save(f'./real_validation_acc_gamma_{gamma}_dropout_{dropout_rate}.npy', real_validation_acc_through_epochs, allow_pickle=True)
        np.save(f'./gen_validation_f1_gamma_{gamma}_dropout_{dropout_rate}.npy', gen_validation_f1_through_epochs, allow_pickle=True)
        np.save(f'./real_validation_f1_gamma_{gamma}_dropout_{dropout_rate}.npy', real_validation_f1_through_epochs, allow_pickle=True)



Epoch [1], Iteration [1], Training Loss: [2.5568]
Epoch [1], Iteration [2], Training Loss: [2.2777]
Epoch [1], Iteration [3], Training Loss: [1.8181]
Epoch [1], Iteration [4], Training Loss: [1.3611]
Epoch [1], Iteration [5], Training Loss: [1.9429]
Epoch [1], Iteration [6], Training Loss: [1.5171]
Epoch [1], Iteration [7], Training Loss: [1.2792]
Epoch [1], Iteration [8], Training Loss: [0.8377]
Epoch [1], Iteration [9], Training Loss: [1.0715]
Epoch [1], Iteration [10], Training Loss: [0.8796]
Epoch [1], Iteration [11], Training Loss: [0.8335]
Epoch [1], Iteration [12], Training Loss: [0.8278]
Epoch [1], Iteration [13], Training Loss: [0.7963]
Epoch [1], Iteration [14], Training Loss: [0.9164]
Epoch [1], Iteration [15], Training Loss: [0.3773]
Epoch [1], Iteration [16], Training Loss: [0.8018]
Epoch [1], Iteration [17], Training Loss: [0.6292]
Epoch [1], Iteration [18], Training Loss: [0.6726]
Epoch [1], Iteration [19], Training Loss: [0.5733]
Epoch [1], Iteration [20], Training Loss

KeyboardInterrupt: 

Now that the best model has been found, we can use it to compute the testing accuracy & f1 score.