# Model training with CREDO image dataset

In [None]:
%run ./notebook_init.py

import os
import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from collections import Counter
from torchsummary import summary
from torchvision import transforms
from sklearn import metrics

from core import DATA_FOLDER
from scripts.credo_training_utils import TRAINING_FOLDERPATH,\
    ModelTraining, ImageFolderWithPath, Seed,\
    resnet18_model, predict_model

In [None]:
# Root of the pre-split dataset; the cells below read its "train", "val" and
# "test" sub-folders.
processed_data_folder = os.path.join(DATA_FOLDER, "credo_processed_dataset")

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Project seeding helper from scripts.credo_training_utils. Its seed_worker
# method is passed to every DataLoader as worker_init_fn.
# NOTE(review): presumably constructing Seed() also seeds the global RNGs for
# reproducibility — confirm in credo_training_utils.
seed = Seed()

In [None]:
# Every split is resized to this (height, width) before conversion to a tensor.
img_size = (60, 60)

# Training-only augmentation: random flips plus a full-circle rotation whose
# uncovered corners are padded with black (fill=0).
train_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation((0, 360), fill=(0,)),
    transforms.ToTensor(),
    # NOTE: Normalize(mean=0, std=1) computes (x - 0) / 1, i.e. it leaves the
    # [0, 1] ToTensor output unchanged. Replace with real channel statistics
    # if normalization is actually wanted.
    transforms.Normalize(0, 1),
])

# Deterministic preprocessing for evaluation splits. Validation must not be
# randomly augmented: the "val" loader is handed to train_model together with
# best_model_filepath (presumably to pick the checkpoint to save), so random
# flips/rotations would make that choice depend on augmentation noise.
eval_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(0, 1),
])

data_transforms = {
    "train": train_transform,
    "val": eval_transform,
    "test": eval_transform,
}

In [None]:
# Dataset splits, in processing order; each is a sub-folder of
# processed_data_folder and a key of data_transforms.
folders_list = [
    "train",
    "val",
    "test",
]

In [None]:
# One dataset per split, each read from its own sub-folder with that split's
# transform pipeline.
image_datasets = {
    split: ImageFolderWithPath(
        os.path.join(processed_data_folder, split),
        data_transforms[split],
    )
    for split in folders_list
}

# Identical loader settings for every split. Shuffling is kept on val/test as
# well, so the first batch shown in the previews is a random sample.
# seed.seed_worker is the worker_init_fn (presumably re-seeds each worker).
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset=image_datasets[split],
        batch_size=64,
        shuffle=True,
        num_workers=2,
        worker_init_fn=seed.seed_worker,
    )
    for split in folders_list
}

In [None]:
# Number of images per split (passed to train_model below).
dataset_sizes = {split: len(image_datasets[split]) for split in folders_list}

# Class names come from the training split's `.classes`.
class_names = image_datasets["train"].classes
class_qty = len(class_names)

# Images per class in each split, keyed by the values in `.targets`
# (presumably ImageFolder-style class indices — confirm).
dataset_class_qty = {
    split: dict(Counter(image_datasets[split].targets))
    for split in folders_list
}

print(f"Class quantity: {class_qty}")
print(f"Class names: {class_names}")
for i in folders_list:
    print(f"{i}: {dataset_class_qty[i]}")

## Plotting a batch

In [None]:
def imshow(input_img, title=None):
    """Display a channels-first image (e.g. the output of ``make_grid``).

    Args:
        input_img: image laid out as (channels, height, width) — a torch
            tensor or anything ``np.asarray`` accepts.
        title: optional figure title; skipped when falsy.
    """
    # Detach and move to CPU so tensors that live on a GPU or require grad can
    # still be converted to NumPy.
    if isinstance(input_img, torch.Tensor):
        input_img = input_img.detach().cpu()
    # matplotlib expects (height, width, channels).
    img = np.asarray(input_img).transpose((1, 2, 0))
    # matplotlib ignores vmin/vmax for RGB data (the previous vmin=0, vmax=5
    # had no effect on make_grid output) and expects floats in [0, 1]; clip
    # explicitly instead of relying on its warn-and-clip fallback.
    if np.issubdtype(img.dtype, np.floating):
        img = np.clip(img, 0.0, 1.0)
    plt.imshow(img)
    if title:
        plt.title(title)
    # pause a bit to update plots
    plt.pause(0.001)

In [None]:
# Show the first (shuffled) batch of each split as an image grid titled with
# the batch's class names.
_batch_captions = {
    "train": "Batch of training data",
    "val": "Batch of validation data",
    "test": "Batch of test data",
}
for split, caption in _batch_captions.items():
    print(caption)
    # Each loader yields (images, labels, paths); the paths are not needed here.
    inputs, classes, _ = next(iter(dataloaders[split]))
    out = torchvision.utils.make_grid(inputs)
    imshow(out, title=[class_names[x] for x in classes])

## Instantiating the model

In [None]:
# Directory for the best checkpoint; exist_ok=True makes re-running the cell safe.
model_data_folder = os.path.join(TRAINING_FOLDERPATH, "best_model_weight")
os.makedirs(model_data_folder, exist_ok=True)

In [None]:
# Checkpoint path handed to train_model and later reloaded for test evaluation.
best_model_filepath = os.path.join(model_data_folder, "best_model_params.pt")

In [None]:
# Build the ResNet-18 classifier with class_qty outputs.
# NOTE(review): presumably resnet18_model moves the model to `device` — confirm.
resnet18 = resnet18_model(device, class_qty)
print(resnet18)

In [None]:
# Layer-by-layer summary for the real training input: 3 channels at img_size,
# the Resize target of every transform pipeline (previously a hard-coded
# 64x64, which reported shapes the model never sees).
summary(resnet18, (3, *img_size))

## Training

In [None]:
# Number of training epochs; also the x-range of the loss/accuracy plots below.
num_epochs = 150

# Train on the train/val loaders; the best weights are written to
# best_model_filepath. The returned model is not used again in the visible cells.
model_training = ModelTraining(resnet18)
model_ft_randstart = model_training.train_model(device, dataloaders,
                                                dataset_sizes,
                                                num_epochs,
                                                best_model_filepath)

In [None]:
# Training history recorded by ModelTraining (plotted per epoch below).
# NOTE(review): if these are GPU tensors rather than floats, plt.plot will fail — confirm.
acc_train, acc_val, loss_train, loss_val = model_training.get_acc_loss()

In [None]:
# Train vs. validation loss per epoch.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(loss_train, label="Train")
loss_ax.plot(loss_val, label="Validation")
loss_ax.set_ylabel("Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.legend()
loss_ax.set_xlim(0, num_epochs)
loss_ax.grid()
plt.show()

In [None]:
# Train vs. validation accuracy per epoch.
acc_fig, acc_ax = plt.subplots()
acc_ax.plot(acc_train, label="Train")
acc_ax.plot(acc_val, label="Validation")
acc_ax.set_ylabel("Accuracy")
acc_ax.set_xlabel("Epoch")
acc_ax.legend()
acc_ax.set_xlim(0, num_epochs)
acc_ax.grid()
plt.show()

## Analysing the metrics

In [None]:
# Output directory for metrics.txt and the evaluation figures; created if missing.
metrics_folder = os.path.join(TRAINING_FOLDERPATH, "metrics")
os.makedirs(metrics_folder, exist_ok=True)

In [None]:
# Rebuild the architecture and restore the best checkpoint written by training.
saved_model = resnet18_model(device, class_qty)
# Inference mode (BatchNorm running stats, no Dropout); harmless if
# predict_model also switches to eval mode.
saved_model.eval()
# map_location remaps stored tensors onto the current device, so a checkpoint
# saved from a GPU session can also be restored in a CPU-only session.
saved_model.load_state_dict(torch.load(best_model_filepath, map_location=device))

In [None]:
# Run the restored model over the test loader; returns (predicted, true) labels.
# NOTE(review): predict_model receives class_names, so these may be class-name
# strings rather than indices — confirm before relying on label types.
predicted_label, true_label = predict_model(device, saved_model, class_names, dataloaders["test"])

In [None]:
# Scalar test-set metrics. Macro averaging weights each class equally.
# zero_division=0 makes sklearn's default (a class with no predictions or no
# samples scores 0) explicit and silences UndefinedMetricWarning.
test_accuracy = metrics.accuracy_score(true_label, predicted_label)
test_precision = metrics.precision_score(true_label, predicted_label,
                                         average="macro", zero_division=0)
test_recall = metrics.recall_score(true_label, predicted_label,
                                   average="macro", zero_division=0)
test_bal_accuracy = metrics.balanced_accuracy_score(true_label, predicted_label)
test_f1 = metrics.f1_score(true_label, predicted_label,
                           average="macro", zero_division=0)

# Single source of truth for both the console output and metrics.txt.
test_metrics = {
    "Test Accuracy": test_accuracy,
    "Test Precision": test_precision,
    "Test Recall": test_recall,
    "Test Balanced Accuracy": test_bal_accuracy,
    "Test F1-Score": test_f1,
}

for metric_name, metric_value in test_metrics.items():
    print(f"{metric_name}: {metric_value:.4f}")

print("\nConfusion Matrix - Test data")
# NOTE(review): rows/columns follow the sorted labels present in
# true_label/predicted_label; they align with class_names only if every class
# occurs and the labels sort in class_names order — confirm.
confusion_mtx = metrics.confusion_matrix(true_label, predicted_label)
print(confusion_mtx)

# One "<name>\t <value>" line per metric, every line newline-terminated (the
# F1 line previously had a stray colon and no trailing newline).
with open(os.path.join(metrics_folder, "metrics.txt"), "w", encoding="utf-8") as metrics_txt:
    metrics_txt.writelines(f"{name}\t {value:.4f}\n" for name, value in test_metrics.items())

In [None]:
# Confusion-matrix heat map with class names on both axes.
metrics.ConfusionMatrixDisplay(confusion_mtx, display_labels=class_names).plot()
plt.title("Confusion Matrix - Test data")
# Suppress any style-imposed grid lines drawn over the matrix cells.
plt.grid(False)
# bbox_inches="tight" (as for the classification-report figure) keeps long
# tick labels and the colorbar from being clipped in the saved PNG.
plt.savefig(os.path.join(metrics_folder, "confusion_mtx.png"), bbox_inches="tight")
plt.show()

In [None]:
# Per-class precision/recall/F1 plus accuracy and macro/weighted averages.
# zero_division=0 matches sklearn's default for precision_score/recall_score/
# f1_score, so a class that is never predicted gets precision 0 rather than an
# optimistic 1 (which would disagree with the macro scalars in metrics.txt).
classification_report = metrics.classification_report(true_label, predicted_label,
                                                      zero_division=0, output_dict=True,
                                                      target_names=class_names)
# DataFrame columns are the report keys; drop the last row ("support", a count
# on a different scale) and transpose so each class/average is a row.
sns.heatmap(pd.DataFrame(classification_report).iloc[:-1, :].T, annot=True)
plt.title("Classification report - Test data")
plt.savefig(os.path.join(metrics_folder, "classification_report.png"), bbox_inches="tight")
plt.show()