## FIT3162 Group MCS2 Model Training
#### Import PyTorch and related libraries

In [None]:
from torch.utils.data import DataLoader, random_split, Subset, ConcatDataset
from sklearn.model_selection import train_test_split
from torchvision import transforms, models
from sklearn.metrics import classification_report
from PIL import ImageFile
from math import modf
from CustomDataset import CustomDataset


import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import seaborn as sn
import pandas as pd
import numpy as np

import torch
import time
import random
import os

#### Select Device

In [None]:
# Train on the first GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

#### Seed everything to try and make everything reproducible

In [None]:
# Seed every random number generator the notebook touches with 0.
# NOTE: PYTHONHASHSEED only affects Python processes started after this point;
# the hash seed of the running interpreter is fixed at start-up.
os.environ['PYTHONHASHSEED'] = '0'
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True

#### Init parameters and transforms

In [None]:
# set training parameters
# Batch size: if training runs out of memory, decrease it; otherwise try increasing it.
# A bigger batch size may need fewer epochs to reach the desired accuracy /
# point of diminishing returns.
BATCH_SIZE = 15
PROCESSES = 8  # DataLoader worker processes; keep it <= the number of logical processors.
EPOCHS = 10

# Per-channel (R, G, B) statistics used by every Normalize below. Defined once so the
# training, test and un-normalize transforms cannot drift apart.
NORM_MEAN = [0.4704, 0.4565, 0.4425]
NORM_STD = [0.3045, 0.2898, 0.2999]

# Training transform: resize/crop/blur, then random augmentation, then normalization.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.GaussianBlur(3, 1),
    transforms.RandomGrayscale(0.1),
    transforms.RandomHorizontalFlip(0.25),
    transforms.RandomVerticalFlip(0.25),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.4, scale=(0.02, 0.25)),
    transforms.RandomApply(
        transforms=[transforms.RandomAffine(degrees=(-30, 30), translate=(0.1, 0.3),
                                            scale=(0.75, 0.95))],
        p=0.5,
    ),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Evaluation transform: same resize/crop/blur and normalization, no random augmentation.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.GaussianBlur(3, 1),  # Remove noise
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Inverse of the Normalize above (for display). Normalize computes (x - mean) / std;
# normalizing again with mean' = -mean / std and std' = 1 / std yields x * std + mean.
transform_unnormalize = transforms.Compose([
    transforms.Normalize(mean=[-m / s for m, s in zip(NORM_MEAN, NORM_STD)],
                         std=[1.0 / s for s in NORM_STD])
])

### Get Dataset
#### Load the original dataset -> derive the sub label of every image (workaround: reuse the dataset's built-in `make_dataset` helper to collect the labels, which line up with the dataset's sample order)

#### -> Split into train, validation and test sets using sklearn

In [None]:
# timing dataset preprocessing in case it takes too long in the future when dataset gets bloated
start = time.time()
print(f'Loading main dataset... ', end="")
labels_map = ["Non Sports", "Sports"]

dataset = CustomDataset(root="./dataset", transform=transform)
dataset_test = CustomDataset(root="./dataset", transform=transform_test)
# Get sub classes in each category
classes_sport = dataset.find_classes(dataset.root+"/sport")
classes_non_sport = dataset.find_classes(dataset.root+"/non sport")
# Combined classes
both_combined =  classes_non_sport[0] + classes_sport[0]
# Bodge to use built in function
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
# Create labels for built in func

non_sport_dict = {}
for i in range(len(classes_non_sport[0])):
    non_sport_dict[classes_non_sport[0][i]] = i

sport_dict = {}
for i in range(len(classes_non_sport[0]), len(classes_sport[0])+len(classes_non_sport[0])):
    sport_dict[classes_sport[0][i-len(classes_non_sport[0])]] = i

print("Done\n")
    
# Use build in function to get us labels, however it is inefficient because
# it also recreates a whole new dataset variable...
sports_temp = dataset.make_dataset(dataset.root+"/sport", sport_dict, IMG_EXTENSIONS, None)
non_sports_temp = dataset.make_dataset(dataset.root+"/non sport", non_sport_dict, IMG_EXTENSIONS, None)
# Retrieve the sub label 
# Must be non_sports first! A-Z!
non_sports_label = [s[1] for s in non_sports_temp]
sports_label = [s[1] for s in sports_temp]
# Combine and add as attribute
combined = non_sports_label + sports_label
dataset.sub_labels = combined
dataset_test.sub_labels = combined

In [None]:
print("Splitting into Train, Validation and Test sets... ", end="")
# Stratified 80/20 split on the sub label, then the 20% is halved into test and
# validation (~80/10/10 overall), preserving sub-class proportions in each set.
train_idx, test_idx, train_labels, test_labels = train_test_split(
    np.arange(len(dataset.sub_labels)), dataset.sub_labels,
    test_size=0.2, random_state=42, shuffle=True, stratify=dataset.sub_labels)
test_idx, valid_idx, test_labels, val_label = train_test_split(
    test_idx, test_labels,
    test_size=0.5, random_state=42, shuffle=True, stratify=test_labels)

# Create the sets: training uses the augmenting transform, val/test the deterministic one.
train_data = Subset(dataset, train_idx)
test_data = Subset(dataset_test, test_idx)
val_data = Subset(dataset_test, valid_idx)

train_size = len(train_data)
val_size = len(val_data)
test_size = len(test_data)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, num_workers=PROCESSES, shuffle=True, pin_memory=True)
# Kept shuffled: batches drawn from test_loader for Captum Insights then vary.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=PROCESSES, shuffle=True, pin_memory=True)
# Validation metrics do not depend on order; a fixed order keeps the running
# validation losses comparable from epoch to epoch.
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, num_workers=PROCESSES, shuffle=False, pin_memory=True)

runtime = time.time() - start
# modf returns (fractional, integer) parts of the runtime in minutes.
seconds, minutes = modf(runtime / 60)
seconds *= 60

print(f'Done\n\nDataset images    : {len(dataset)}')
print(f'Train set         : {train_size} images, {train_size / len(dataset) * 100:.2f}%')
print(f'Validation set    : {val_size} images, {val_size / len(dataset) * 100:.2f}%')
print(f'Test set          : {test_size} images, {test_size / len(dataset) * 100:.2f}%')
print(f'Sum of split sets : {train_size + val_size + test_size} images\n')
print(f'Total time taken for dataset: {minutes:.0f} min {seconds:.2f} sec\n-----------------------------\n')

In [None]:
# Display the combined sub-class names (non-sport names first, then sport names).
both_combined

In [None]:
# List every sub-class name on its own line, followed by how many there are.
for sub_label_name in both_combined:
    print(sub_label_name)
print(f'\nTotal sub labels: {len(both_combined)}')

In [None]:
# Count each class in original
count_original = {classname: 0 for classname in range(len(both_combined))}
for i in range(len(dataset)):
    label = dataset.sub_labels[i]
    count_original[label] += 1
count_original

In [None]:
# Count each class in train
count_train = {classname: 0 for classname in range(len(both_combined))}
for each in train_idx:
    label = dataset.sub_labels[each]
    count_train[label] += 1
count_train

In [None]:
# Visualise image to double check we have correct sublabel
# We check validation set
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(val_size, size=(1,)).item()
    img, label = val_data[sample_idx]
    sublabel = val_label[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(both_combined[sublabel])
    plt.axis("off")
    img = transform_unnormalize(img)     # attempt to unnormalize...
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()

#### Visualize training images labelled by main category (sports / non sports)

In [None]:
# 9 random (augmented) training samples, each titled with its main label.
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(train_size, size=(1,)).item()
    img, label = train_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label["main_label"]])
    plt.axis("off")
    # Undo Normalize, then clamp small float excursions outside [0, 1] so imshow
    # does not warn about clipping RGB data.
    img = transform_unnormalize(img).clamp(0, 1)
    plt.imshow(img.permute(1, 2, 0).numpy())
plt.show()

#### Set up model

In [None]:
class MultiOutputModel(nn.Module):
    """ResNet-50 feature extractor shared by two linear classification heads.

    ``forward`` returns a dict of raw logits: ``'main_label'`` (main category)
    and ``'sub_label'`` (sub category).

    Args:
        n_main_classes: number of main-category outputs.
        n_sub_classes: number of sub-category outputs.
    """

    def __init__(self, n_main_classes, n_sub_classes):
        super().__init__()
        backbone = models.resnet50(weights="ResNet50_Weights.DEFAULT")
        # Width of the backbone's final feature map (2048 for ResNet-50). Read it from
        # the original classifier instead of hard-coding it for each head.
        n_features = backbone.fc.in_features
        # take the model without its final avgpool and fc layers (replaced below)
        self.base_model = nn.Sequential(*list(backbone.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # create separate classifiers for our outputs, both fed by the pooled features
        self.main_label = nn.Linear(in_features=n_features, out_features=n_main_classes)
        self.sub_label = nn.Linear(in_features=n_features, out_features=n_sub_classes)

    def forward(self, x):
        """Return ``{'main_label': logits, 'sub_label': logits}`` for a batch of images."""
        x = self.base_model(x)
        x = self.pool(x)

        # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier
        x = torch.flatten(x, start_dim=1)

        return {
            'main_label': self.main_label(x),
            'sub_label': self.sub_label(x),
        }

    def get_loss(self, net_output, ground_truth):
        """Sum of the cross-entropy losses of both heads.

        Args:
            net_output: dict returned by ``forward``.
            ground_truth: dict with class-index targets under the same keys.

        Returns:
            ``(total_loss, {'main_loss': ..., 'sub_loss': ...})``.
        """
        main_loss = F.cross_entropy(net_output['main_label'], ground_truth['main_label'])
        sub_loss = F.cross_entropy(net_output['sub_label'], ground_truth['sub_label'])
        loss = main_loss + sub_loss
        return loss, {'main_loss': main_loss, 'sub_loss': sub_loss}

# Load Model: 2 main-category outputs and one output per sub class, placed on `device`.
model_ft = MultiOutputModel(2, len(both_combined)).to(device)

In [None]:
def to_device_dict(obj, device):
    """Return a new dict with every value of ``obj`` moved via ``.to(device)``.

    Keys and their order are preserved; ``obj`` itself is not modified.
    """
    return {key: value.to(device) for key, value in obj.items()}
# Training information: per-epoch history lists, filled in by the training loop.
TRAIN_STAT = {key: [] for key in ("train_loss", "train_acc", "val_loss", "val_acc")}
# SGD with momentum over all model parameters.
optimizer = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.8 every 2 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.8, step_size=2)

#### Train model

In [None]:
print("Training model...")
start = time.time()

for epoch in range(EPOCHS):  # loop over the training set multiple times

    print(f'\n-----------------------------\n\nEPOCH: {epoch+1}/{EPOCHS}')
    print('Current lr: {0}'.format(optimizer.param_groups[0]['lr']))
    model_ft.train()

    # train_loss / val_loss: running sums of per-batch mean losses, reset after each
    # progress print. total_*_loss: sums of per-SAMPLE losses (batch mean * batch size),
    # so dividing by the split size gives the epoch's true mean loss per image.
    train_loss = val_loss = 0.0
    total_train_loss = total_val_loss = 0.0
    train_correct = val_correct = 0.0
    print('\nTraining...')

    for i, data in enumerate(train_loader):
        # data is [images, labels dict]; move both to the training device
        inputs, labels = data[0].to(device), data[1]
        labels = to_device_dict(labels, device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward pass; get_loss sums the main- and sub-label cross-entropies
        outputs = model_ft(inputs)
        loss, losses_by_class = model_ft.get_loss(outputs, labels)

        # backward + optimize
        loss.backward()
        optimizer.step()

        # keep track of loss and correct statistics (accuracy uses the main label only)
        batch_loss = loss.item()
        train_loss += batch_loss
        total_train_loss += batch_loss * inputs.size(0)
        train_correct += (outputs["main_label"].argmax(1) == labels["main_label"]).float().sum().item()

        if i % 20 == 19:    # print every 20 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {train_loss / 20:.10f}')
            train_loss = 0.0

    print('\nValidating...')

    model_ft.eval()
    # turn off gradient tracking and computation
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            # make predictions and calculate validation loss
            inputs, labels = data[0].to(device), data[1]
            labels = to_device_dict(labels, device)
            outputs = model_ft(inputs)
            loss, losses_by_class = model_ft.get_loss(outputs, labels)

            # keep track of loss and correct statistics
            batch_loss = loss.item()
            val_loss += batch_loss
            total_val_loss += batch_loss * inputs.size(0)
            val_correct += (outputs["main_label"].argmax(1) == labels["main_label"]).float().sum().item()

            if i % 5 == 4:    # print every 5 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {val_loss / 5:.10f}')
                val_loss = 0.0

    # average per-image loss at current epoch
    avg_train_loss = total_train_loss / train_size
    avg_val_loss = total_val_loss / val_size

    # main-label accuracy at current epoch, in percent
    train_acc = (train_correct / train_size) * 100
    val_acc = (val_correct / val_size) * 100

    # store statistics
    TRAIN_STAT["train_loss"].append(avg_train_loss)
    TRAIN_STAT["train_acc"].append(train_acc)
    TRAIN_STAT["val_loss"].append(avg_val_loss)
    TRAIN_STAT["val_acc"].append(val_acc)

    # print statistics of current epoch
    print(f'\nAverage Train loss: {avg_train_loss:.5f}, Train accuracy: {train_acc:.2f}%')
    print(f'Average Validation loss: {avg_val_loss:.5f}, Validation accuracy: {val_acc:.2f}%')

    # scheduler step (once per epoch)
    scheduler.step()

# calculate total time taken for the training process
runtime = time.time() - start
# modf returns (fractional, integer) parts of the runtime in minutes.
seconds, minutes = modf(runtime / 60)
seconds *= 60
print('\n-----------------------------\nFinished Training\nTotal time taken for training: %d min %d sec' % (minutes, seconds))

#### Visualise training stats

In [None]:
plt.style.use("ggplot")

# One figure per metric: train and validation curves over epochs, saved as PNG.
for _metric, _title, _ylabel, _png in (
    ("acc", "Training Accuracy on Dataset", "Accuracy", 'accuracy.png'),
    ("loss", "Training Loss on Dataset", "Loss", 'loss.png'),
):
    plt.figure(figsize=(12, 5))
    for _split in ("train", "val"):
        _key = f"{_split}_{_metric}"
        plt.plot(TRAIN_STAT[_key], label=_key)
    plt.title(_title)
    plt.xlabel("Epoch #")
    plt.ylabel(_ylabel)
    plt.legend(loc="lower left")
    plt.savefig(_png)

#### Save Model

In [None]:
# Compile the trained network to TorchScript and write it to disk.
model_scripted = torch.jit.script(model_ft)
torch.jit.save(model_scripted, 'resnet50_sports_non_sports_multilabel.pt')

In [None]:
print("Testing model...\n-----------------------------")
with torch.no_grad():

    model_ft.eval()

    predictions = []
    actual = []

    for data in test_loader:
        # make predictions and add to list of predictions
        inputs, labels = data[0].to(device), data[1]
        outputs = model_ft(inputs)
        predictions.extend(outputs["main_label"].argmax(dim=1).cpu().numpy())
        actual.extend(labels["main_label"].cpu().numpy())

# Per-class precision/recall/F1 report. sklearn expects (y_true, y_pred); swapping them
# would exchange precision and recall. Explicit `labels` keeps the report aligned with
# target_names even if a class is absent from both lists.
print(classification_report(np.array(actual), np.array(predictions),
                            labels=list(range(len(labels_map))), target_names=labels_map))

In [None]:
print("Testing model on sub-category...\n-----------------------------")
with torch.no_grad():

    model_ft.eval()

    predictions = []
    actual = []

    for data in test_loader:
        # make sub-label predictions and add to list of predictions
        inputs, labels = data[0].to(device), data[1]
        outputs = model_ft(inputs)
        predictions.extend(outputs["sub_label"].argmax(dim=1).cpu().numpy())
        actual.extend(labels["sub_label"].cpu().numpy())

# Per-sub-class precision/recall/F1 report. sklearn expects (y_true, y_pred); swapping them
# would exchange precision and recall. Explicit `labels` keeps the report aligned with
# target_names even if a sub class is absent from both lists.
print(classification_report(np.array(actual), np.array(predictions),
                            labels=list(range(len(both_combined))), target_names=both_combined))

#### Visualize and interpret the model
#### Initialise values and load the saved model

In [None]:
from PIL import Image

# Reload the exported TorchScript model onto the best available device, in eval mode.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = torch.jit.load('resnet50_sports_non_sports_multilabel.pt').to(device)
model.eval();

In [None]:
# Same deterministic preprocessing as transform_test, split in two so the un-normalized
# tensor (resized_img) can be displayed next to the attribution heat maps.
transform_resize = transforms.Compose([
    # resize
    transforms.Resize(256),
    # center_crop
    transforms.CenterCrop(224),
    transforms.GaussianBlur(3, 1),  # Remove noise
    transforms.ToTensor(),
])
# Must match the Normalize statistics the saved model was trained with (transform /
# transform_test); different values would shift every input the model sees.
transform_normalize = transforms.Compose([transforms.Normalize(mean=[0.4704, 0.4565, 0.4425], std=[0.3045, 0.2898, 0.2999])])

#### Load image and get predicted label

In [None]:
# Load one test image as RGB and build a normalized single-image batch on `device`.
img_name = 'fake_injury_soccer.jpg'
img = Image.open(f'test_images/{img_name}').convert('RGB')
resized_img = transform_resize(img)          # un-normalized tensor, kept for display
transformed_img = transform_normalize(resized_img)
input = transformed_img.unsqueeze(0).to(device)  # add the batch dimension

In [None]:
# Inference only: no_grad avoids building an autograd graph that nothing here uses
# (the attribution cells run their own forward passes).
with torch.no_grad():
    output = model(input)
# Class probabilities for each head.
main_label_output = F.softmax(output['main_label'], dim=1)
sub_label_output = F.softmax(output['sub_label'], dim=1)

In [None]:
# Top-1 probability and class index for each head, then the matching class names.
prediction_score, pred_label_idx = main_label_output.topk(1)
prediction_score_sub, pred_label_idx_sub = sub_label_output.topk(1)
pred_label_main = labels_map[pred_label_idx.item()]
pred_label_sub = both_combined[pred_label_idx_sub.item()]

print(f'Testing image: "{img_name}"')
print(f'Main label: {pred_label_main}, Sub label: {pred_label_sub}')
print(f'Main label score: {prediction_score.item()*100:.2f}%, Sub label score: {prediction_score_sub.item()*100:.2f}%')

#### Import captum library, setup for model interpretation

In [None]:
from captum.attr import IntegratedGradients
from captum.attr import LayerIntegratedGradients
from captum.attr import GradientShap
from captum.attr import Occlusion
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
from matplotlib.colors import LinearSegmentedColormap

def wrapped_model(inp):
    """Run the loaded ``model`` on ``inp`` and return only its main-label logits.

    The model returns a dict, so this adapter gives Captum a callable whose
    output is a single tensor.
    """
    outputs = model(inp)
    return outputs["main_label"]

def wrapped_model_sub(inp):
    """Run the loaded ``model`` on ``inp`` and return only its sub-label logits."""
    outputs = model(inp)
    return outputs["sub_label"]

In [None]:
# Heat-map colour ramp: white at 0, reaching black by 0.25 and staying black up to 1
# ('custom blue' is only the colormap's name).
_cmap_stops = [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')]
default_cmap = LinearSegmentedColormap.from_list('custom blue', _cmap_stops, N=256)

#### Visualise interpretation based on predicted label (main label)

In [None]:
integrated_gradients = IntegratedGradients(wrapped_model)
noise_tunnel = NoiseTunnel(integrated_gradients)
attributions_ig_nt = noise_tunnel.attribute(input, nt_samples=10, nt_type='smoothgrad_sq', target=pred_label_idx, nt_samples_batch_size=1)
plt = viz.visualize_image_attr_multiple(np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(resized_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      cmap=default_cmap,
                                      show_colorbar=True)


In [None]:
plt[0].savefig("result_intepret.png")

#### Visualise interpretation based on predicted label (sub label)

In [None]:
# SmoothGrad-squared Integrated Gradients attributions for the predicted sub label.
integrated_gradients = IntegratedGradients(wrapped_model_sub)
noise_tunnel = NoiseTunnel(integrated_gradients)
attributions_ig_nt = noise_tunnel.attribute(input, nt_samples=10, nt_type='smoothgrad_sq', target=pred_label_idx_sub, nt_samples_batch_size=1)
# Bind the returned (figure, axes) to dedicated names rather than `plt`, which would
# shadow matplotlib.pyplot.
fig_sub, _ = viz.visualize_image_attr_multiple(
    np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1, 2, 0)),
    np.transpose(resized_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
    ["original_image", "heat_map"],
    ["all", "positive"],
    cmap=default_cmap,
    show_colorbar=True)

#### Visualisation on our test data through Captum insights

In [None]:
from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

In [None]:
def get_classes():
    """Return the 15 sub-class names as a new list.

    The order appears to mirror ``both_combined``: the non-sport folder names,
    then the sport folder names, each run sorted.
    """
    non_sport = ['Pedestrian', 'Queue', 'Reading', 'cello', 'driving',
                 'guitar', 'harp', 'using computer', 'violin']
    sport = ['Badminton', 'Basketball', 'Cycling', 'Football', 'Tennis', 'squash']
    return non_sport + sport

def get_classes_main():
    """Return the two main-category names in label order (0 = non sports, 1 = sports)."""
    return list(('Non Sports', 'Sports'))

def get_pretrained_model():
    """Load the saved TorchScript model on the CPU in eval mode.

    Returns:
        A function mapping an input batch to the model's ``"main_label"`` logits.
    """
    scripted = torch.jit.load('resnet50_sports_non_sports_multilabel.pt').cpu()
    scripted.eval()

    def main_label_logits(inp):
        return scripted(inp)["main_label"]

    return main_label_logits

In [None]:
def baseline_func(input):
    """All-zero attribution baseline with the same shape as ``input``."""
    return 0 * input


def formatted_data_iter():
    """Yield Captum Insights ``Batch`` objects (images + main labels) from ``test_loader``.

    A ``for`` loop lets the generator end cleanly when the loader is exhausted.
    Calling ``next()`` on the loader inside a generator would instead turn the
    loader's StopIteration into a RuntimeError (PEP 479).
    """
    for images, labels in test_loader:
        yield Batch(inputs=images, labels=labels["main_label"])

In [None]:
# Normalization the model was trained with (same statistics as transform_test).
normalize = transforms.Normalize([0.4704, 0.4565, 0.4425], [0.3045, 0.2898, 0.2999])
model = get_pretrained_model()
# test_loader yields images that transform_test has already normalized. Undo that so
# Captum Insights displays the real photo; `input_transforms` (normalize) then prepares
# it for the model exactly once instead of normalizing twice.
insights_batches = (
    Batch(inputs=transform_unnormalize(batch.inputs), labels=batch.labels)
    for batch in formatted_data_iter()
)
visualizer = AttributionVisualizer(
    models=[model],
    score_func=lambda o: torch.nn.functional.softmax(o, 1),
    # The wrapped model outputs the 2 main-label logits, so use the main class names.
    classes=get_classes_main(),
    features=[
        ImageFeature(
            "Photo",
            baseline_transforms=[baseline_func],
            input_transforms=[normalize],
        )
    ],
    dataset=insights_batches,
)

Run the cell below to open Captum Insights, after running all of the cells above
(they set everything up).

In [None]:
# Launch the Captum Insights UI for the visualizer configured in the previous cell.
visualizer.render()