In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from colorama import Fore, Style
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, roc_curve, auc
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import KFold
from torch.optim import Adam
from torch.utils.data import ConcatDataset
from torchvision import datasets
from torchvision import models
from torchvision.models import MobileNet_V3_Small_Weights
from tqdm import tqdm

In [None]:
# colorama shortcuts: ANSI colour codes used to highlight printed output
red, green, blue, yellow, cyan = (
    Fore.RED,
    Fore.GREEN,
    Fore.BLUE,
    Fore.YELLOW,
    Fore.CYAN,
)

# Appended after coloured text to restore the terminal's default style
reset = Style.RESET_ALL

In [None]:
# Data root; every path below is built by appending to it (note the trailing slash)
d = ".../Rock/"

# Sub-folder of `d` that receives this experiment's artefacts (saved models)
fld = 'PyDL_C'

# Sub-categorized data: one ImageFolder-style directory per split
train_dir = f"{d}k_fold_data/train"
test_dir = f"{d}k_fold_data/test"

In [None]:
# Setting the seed for every RNG used by the notebook and forcing deterministic cuDNN
seed = 42
random.seed(seed)
torch.manual_seed(seed)
# torch.cuda.manual_seed only seeds the current GPU; manual_seed_all seeds every GPU
torch.cuda.manual_seed_all(seed)
# deterministic=True alone is not enough: with benchmark mode on, cuDNN may pick
# different (non-deterministic) convolution algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(f'{blue}Global seed set to : {yellow}{seed}{reset}\n')

In [None]:
# Input resolution given to transforms.Resize / RandomResizedCrop, and the DataLoader batch size
img_dimen, bs = (256, 256), 16

In [None]:
# preprocessing | get the per-channel mean and std of the training images for normalization.
# The std is the pixel-level std over the whole training set, sqrt(E[x^2] - E[x]^2).
# Averaging per-image stds instead would ignore the variation *between* images and
# underestimate the spread that Normalize is supposed to remove.
transform = transforms.Compose([
    transforms.Resize(img_dimen),
    transforms.ToTensor()
])

calc_ms = datasets.ImageFolder(root=train_dir, transform=transform)
loader_ms = torch.utils.data.DataLoader(dataset=calc_ms, batch_size=bs, shuffle=False)

# Running per-channel sums; float64 so summing many pixels does not lose precision
channel_sum = 0.0
channel_sq_sum = 0.0
pixel_count = 0
total_images = 0

for images, _ in tqdm(loader_ms):
    batch_samples = images.size(0)
    # (batch, channels, pixels): flatten the spatial dims, keep channels separate
    pixels = images.view(batch_samples, images.size(1), -1).double()
    channel_sum += pixels.sum(dim=(0, 2))
    channel_sq_sum += (pixels ** 2).sum(dim=(0, 2))
    pixel_count += batch_samples * pixels.size(2)
    total_images += batch_samples

if total_images == 0:
    raise RuntimeError(f'No images found in {train_dir}')

mean_calc = channel_sum / pixel_count
# clamp guards against tiny negative variances caused by floating-point cancellation
std_calc = (channel_sq_sum / pixel_count - mean_calc ** 2).clamp(min=0).sqrt()

# float32, matching the ToTensor() images that Normalize will be applied to
mean_calc = mean_calc.float()
std_calc = std_calc.float()

print(f'{blue}mean: {yellow}{mean_calc}')
print(f'{blue}std: {yellow}{std_calc}{reset}')

In [None]:
# Alternative: ImageNet normalization statistics (uncomment to use them instead of the dataset-computed values above)
#mean_calc = [0.485, 0.456, 0.406]
#std_calc = [0.229, 0.224, 0.225]

In [None]:
# Validation / test images: deterministic resize, tensor conversion, normalization
transform_test = transforms.Compose(
    [
        transforms.Resize(img_dimen),
        transforms.ToTensor(),
        transforms.Normalize(mean_calc, std_calc),
    ]
)

# Training images: a random resized crop (augmentation) replaces the plain resize;
# tensor conversion and normalization are the same as above
transform_all = transforms.Compose(
    [
        transforms.RandomResizedCrop(img_dimen),
        transforms.ToTensor(),
        transforms.Normalize(mean_calc, std_calc),
    ]
)

# Whole training directory; the k-fold loop later splits it by index
dataset = datasets.ImageFolder(root=train_dir, transform=transform_all)

In [None]:
# Held-out test set: deterministic transform, fixed (unshuffled) order
batch_size = bs
dataset_test = datasets.ImageFolder(root=test_dir, transform=transform_test)
test_loader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Number of target classes = number of class sub-folders ImageFolder found
num_classes = len(dataset.classes)

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'{blue}Device: {yellow}{device}{reset}')

In [None]:
# Hyperparameters
fold = 5               # number of cross-validation folds (KFold n_splits)
max_epoch = 30         # training epochs per fold
batch_size = bs        # reuse `bs` so training and test/statistics loaders cannot drift apart
learningRate = 0.0001  # Adam learning rate
WeightDecay = 1e-08    # Adam weight decay

# All Information
for label, value in (
    ('Fold', fold),
    ('Epochs', max_epoch),
    ('Batch size', batch_size),
    ('Learning rate', learningRate),
    ('Weight decay', WeightDecay),
):
    print(f'{blue}{label}: {yellow}{value}{reset}')

In [None]:
# Shuffled k-fold splitter; random_state makes the fold assignment reproducible
kf = KFold(n_splits=fold, random_state=seed, shuffle=True)

In [None]:
# K fold cross-validation

# Validation images must go through the deterministic test-time transform, not the
# random crops used for training. Otherwise validation accuracy, and therefore the
# choice of each fold's "best" checkpoint, is measured on augmented data.
# `dataset_eval` scans the same folder as `dataset`, so an index refers to the same
# image in both; the assert guards that assumption.
dataset_eval = datasets.ImageFolder(root=train_dir, transform=transform_test)
assert dataset_eval.samples == dataset.samples, "train/eval views of train_dir disagree"

# All fold checkpoints are written here; create it so torch.save cannot fail on a missing dir
model_dir = d + fld + '/models/'
os.makedirs(model_dir, exist_ok=True)


def _run_epoch(model, loader, criterion, optimizer=None):
    """Make one pass over `loader` and return (mean batch loss, accuracy in %).

    With an `optimizer` the model is trained (train mode, gradients enabled, one
    optimizer step per batch). Without one it is evaluated (eval mode, no gradients).
    Accuracy is counted against `len(loader.dataset)` samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()

    running_loss = 0.0
    num_correct = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                # Zero the parameter gradients
                optimizer.zero_grad()

            # Predictions | forward pass | OUTPUT
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                # Loss | backward pass | GRADIENT
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            num_correct += (predicted == labels).sum().item()

    return running_loss / len(loader), 100 * num_correct / len(loader.dataset)


# Per-fold curves: one list per fold, one value per epoch
train_loss_all = []
val_loss_all = []
train_acc_all = []
val_acc_all = []

# Best validation accuracy of each fold, for averaging
fold_val_acc = []

# The index is `fold_idx`, not `fold`: `fold` is the number-of-splits hyperparameter
# and must survive this loop (re-running the KFold cell would otherwise use the last index).
for fold_idx, (train_index, val_index) in enumerate(kf.split(dataset)):
    ff = fold_idx + 1
    print(f'{yellow}\n##############################################')
    print(f'{green}                   FOLD {ff}')
    print(f'{yellow}##############################################{reset}')

    # Same index split, different transforms: augmented for training, deterministic for validation
    train_dataset = torch.utils.data.Subset(dataset, train_index)
    val_dataset = torch.utils.data.Subset(dataset_eval, val_index)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # -----------------------------------------------------------------------------------
    # MobileNet V3 Small, ImageNet-pretrained, last classifier layer resized to num_classes
    model = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
    num_ftrs = model.classifier[3].in_features
    model.classifier[3] = nn.Linear(in_features=num_ftrs, out_features=num_classes)
    model.to(device)

    # Loss and optimizer (fresh for every fold)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=learningRate, weight_decay=WeightDecay)

    # -----------------------------------------------------------------------------------
    # TRAINING
    train_loss, val_loss = [], []
    train_acc, val_acc = [], []

    # Best validation accuracy seen so far in this fold
    max_curr_fold = 0

    for epoch in range(max_epoch):
        train_lss, train_accuracy = _run_epoch(model, train_loader, criterion, optimizer)
        val_lss, val_accuracy = _run_epoch(model, val_loader, criterion)

        train_loss.append(train_lss)
        train_acc.append(train_accuracy)
        val_loss.append(val_lss)
        val_acc.append(val_accuracy)

        print(f'{cyan}\nEPOCH {epoch + 1}{reset}')
        print(f"Loss: {red}{train_lss}{reset}, Validation Accuracy: {red}{val_accuracy}%{reset}, Training Accuracy: {red}{train_accuracy}%{reset}")

        # Save the best model of each fold. Saving unconditionally at epoch 0 guarantees
        # that a `_T_` checkpoint exists even if validation accuracy never exceeds 0 %.
        if epoch == 0 or val_accuracy > max_curr_fold:
            max_curr_fold = val_accuracy
            torch.save(model.state_dict(), model_dir + 'MobileNetV3_s_fold_T_' + str(ff) + '.pth')
            print(f'{green}Improvement! Model saved!{reset}')

    # save last model
    torch.save(model.state_dict(), model_dir + 'MobileNetV3_s_fold_F_' + str(ff) + '.pth')

    # ------------------------------------------------------------------------------
    # metrics for graph for current fold
    train_loss_all.append(train_loss)
    val_loss_all.append(val_loss)
    train_acc_all.append(train_acc)
    val_acc_all.append(val_acc)

    # the highest validation accuracy of each fold
    fold_val_acc.append(max_curr_fold)

print(f'{yellow}\nTraining finished!{reset}')

In [None]:
# Graph of training and validation: loss and accuracy | dual plots for each fold.
# One point is recorded per epoch, so the x-axis is epochs.
n_folds = len(train_loss_all)
# squeeze=False keeps `axis` 2-D even when there is only one fold
fig, axis = plt.subplots(n_folds, 2, figsize=(20, 8 * n_folds), squeeze=False)

for i in range(n_folds):
    # Loss plot
    axis[i, 0].set_title("Fold " + str(i+1) + ": Loss")
    axis[i, 0].plot(val_loss_all[i], color='red', label='Validation loss', linestyle='dashed')
    axis[i, 0].plot(train_loss_all[i], color='orange', label='Training loss')
    axis[i, 0].legend()
    axis[i, 0].set_xlabel("Epochs")
    axis[i, 0].set_ylabel("Loss")

    # Accuracy plot
    axis[i, 1].set_title("Fold " + str(i+1) + ": Accuracy")
    axis[i, 1].plot(val_acc_all[i], color='red', label='Validation accuracy', linestyle='dashed')
    axis[i, 1].plot(train_acc_all[i], color='orange', label='Training accuracy')
    axis[i, 1].legend()
    axis[i, 1].set_xlabel("Epochs")
    axis[i, 1].set_ylabel("Accuracy")

plt.show()

In [None]:
# Graph of training and validation: loss and accuracy | single plot for all folds.
# Each fold's per-epoch curve is summarised by its mean, with the std over epochs
# drawn as an error bar.
n_folds = len(val_loss_all)
fold_ids = list(range(1, n_folds + 1))

fig, axis = plt.subplots(1, 2, figsize=(20, 10))

# Validation summaries
acc_mean = [np.mean(curve) for curve in val_acc_all]
loss_mean = [np.mean(curve) for curve in val_loss_all]
acc_std = [np.std(curve) for curve in val_acc_all]
loss_std = [np.std(curve) for curve in val_loss_all]

# Training summaries: the "Training" series must come from the training curves,
# not from the validation means
train_acc_mean = [np.mean(curve) for curve in train_acc_all]
train_loss_mean = [np.mean(curve) for curve in train_loss_all]
train_acc_std = [np.std(curve) for curve in train_acc_all]
train_loss_std = [np.std(curve) for curve in train_loss_all]

# Loss plot
axis[0].set_title("Loss")
axis[0].errorbar(fold_ids, loss_mean, yerr=loss_std, color='red', label='Validation loss', linestyle='dashed')
axis[0].errorbar(fold_ids, train_loss_mean, yerr=train_loss_std, color='orange', label='Training loss')
axis[0].set_xticks(fold_ids)
axis[0].legend()
axis[0].set_xlabel("Folds")
axis[0].set_ylabel("Loss")

# Accuracy plot
axis[1].set_title("Accuracy")
axis[1].errorbar(fold_ids, acc_mean, yerr=acc_std, color='red', label='Validation accuracy', linestyle='dashed')
axis[1].errorbar(fold_ids, train_acc_mean, yerr=train_acc_std, color='orange', label='Training accuracy')
axis[1].set_xticks(fold_ids)
axis[1].legend()
axis[1].set_xlabel("Folds")
axis[1].set_ylabel("Accuracy")

plt.show()

In [None]:
# TESTING on BEST Model (fold wise)
#
# Every fold left two checkpoints: `_T_` (best validation accuracy) and `_F_` (last epoch).
# The model reported for a fold is the `_T_` one, because it was selected on validation
# data. Picking whichever checkpoint scores higher on the *test* set would use test labels
# for model selection and overstate test accuracy. The last-epoch test accuracy is
# therefore printed for reference only.

# Must match the directory the training cell saves checkpoints to
model_dir = d + fld + '/models/'


def _load_checkpoint(path):
    """Rebuild MobileNetV3-Small with a `num_classes` head, load `path`, return it in eval mode on `device`."""
    # No pretrained download needed: load_state_dict (strict) overwrites every parameter and buffer
    net = models.mobilenet_v3_small(weights=None)
    num_ftrs = net.classifier[3].in_features
    net.classifier[3] = nn.Linear(in_features=num_ftrs, out_features=num_classes)
    # map_location lets checkpoints saved from a GPU load on a CPU-only machine
    net.load_state_dict(torch.load(path, map_location=device))
    net.to(device)
    net.eval()
    return net


def _predict(net, loader):
    """Run `net` over `loader`; return (accuracy in %, predicted labels, true labels)."""
    y_pred, y_true = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            # Predictions | forward pass | OUTPUT
            outputs = net(inputs.to(device))
            _, predicted = torch.max(outputs, 1)
            y_pred.extend(predicted.tolist())
            y_true.extend(labels.tolist())
    if not y_true:
        raise RuntimeError('test_loader produced no samples')
    correct = sum(p == t for p, t in zip(y_pred, y_true))
    return 100 * correct / len(y_true), y_pred, y_true


# Per-fold test accuracy of the reported (validation-selected) model
acc = []
# Per-fold test accuracy of the last-epoch model (reference only)
acc_last = []
# Per-fold predictions / ground truth of the reported model, for the metric cells below
y_pred_ll = []
y_true_ll = []

for f in range(len(fold_val_acc)):
    ff = f + 1

    acc_val, y_pred_T, y_true_T = _predict(
        _load_checkpoint(model_dir + 'MobileNetV3_s_fold_T_' + str(ff) + '.pth'), test_loader)
    acc_final, _, _ = _predict(
        _load_checkpoint(model_dir + 'MobileNetV3_s_fold_F_' + str(ff) + '.pth'), test_loader)

    y_pred_ll.append(y_pred_T)
    y_true_ll.append(y_true_T)
    acc.append(acc_val)
    acc_last.append(acc_final)

    print(f"{green}\nFold {ff}:")
    print(f"{blue}Validation Accuracy: {red}{fold_val_acc[f]}%")
    print(f"{blue}Test Accuracy: {red}{acc_val}%")
    print(f"{blue}Test Accuracy (last-epoch model, reference only): {red}{acc_final}%")

print(f"{blue}\n\nAverage Validation Accuracy: {red}{sum(fold_val_acc) / len(fold_val_acc)}%")
print(f"{blue}Average Test Accuracy: {red}{sum(acc) / len(acc)}%")
print(f"{blue}Average Test Accuracy (last-epoch models): {red}{sum(acc_last) / len(acc_last)}%{reset}")

In [None]:
# Classification Report (per fold, reported model, test set)
class_ids = list(range(len(dataset_test.classes)))
for i in range(len(y_true_ll)):
    print(f"{green}\nFold {i+1}:")
    print(f"{blue}Classification Report:")
    # Passing `labels` keeps rows aligned with `target_names` even when a class never
    # appears in a fold's y_true or y_pred (otherwise sklearn raises or mislabels rows).
    print(classification_report(y_true_ll[i], y_pred_ll[i], labels=class_ids,
                                target_names=dataset_test.classes, zero_division=0), end='\n\n')

In [None]:
# Accuracy (fraction in [0, 1]) of each fold's reported model on the test set
acc_metric = []

for i in range(len(y_true_ll)):
    # `fold_acc`, not `acc`: `acc` already holds the per-fold test-accuracy list
    # built by the testing cell and must not be overwritten with a float.
    fold_acc = accuracy_score(y_true_ll[i], y_pred_ll[i])
    acc_metric.append(fold_acc)
    print(f"{green}\nFold {i+1}:")
    print(f"{blue}Accuracy: {red}{fold_acc}")

print(f"{blue}\n\nAverage Accuracy: {red}{sum(acc_metric) / len(acc_metric)}")

In [None]:
# F1 Score (macro-averaged over classes) of each fold's reported model on the test set
f_score_metric = []

for i in range(len(y_true_ll)):
    # zero_division=0 gives the same value sklearn's default uses, without the warning
    f_score = f1_score(y_true_ll[i], y_pred_ll[i], average='macro', zero_division=0)
    f_score_metric.append(f_score)
    print(f"{green}\nFold {i+1}:")
    print(f"{blue}F1 Score: {red}{f_score}")

print(f"{blue}\n\nAverage F1 Score: {red}{sum(f_score_metric) / len(f_score_metric)}")

In [None]:
# Precision (macro-averaged over classes) of each fold's reported model on the test set
precision_metric = []

for i in range(len(y_true_ll)):
    # zero_division=0: a class never predicted scores 0 (sklearn's default value) without a warning
    precision = precision_score(y_true_ll[i], y_pred_ll[i], average='macro', zero_division=0)
    precision_metric.append(precision)
    print(f"{green}\nFold {i+1}:")
    print(f"{blue}Precision: {red}{precision}")

print(f"{blue}\n\nAverage Precision: {red}{sum(precision_metric) / len(precision_metric)}")

In [None]:
# Recall | Sensitivity (macro-averaged over classes) of each fold's reported model on the test set
sen_metric = []

for i in range(len(y_true_ll)):
    # zero_division=0 gives the same value sklearn's default uses, without the warning
    sen = recall_score(y_true_ll[i], y_pred_ll[i], average='macro', zero_division=0)
    sen_metric.append(sen)
    print(f"{green}\nFold {i+1}:")
    print(f"{blue}Recall/Sensitivity: {red}{sen}")

print(f"{blue}\n\nAverage Recall/Sensitivity: {red}{sum(sen_metric) / len(sen_metric)}")

In [None]:
def tp_calc(y_true, y_pred, class_label):
    """Count true positives for `class_label`.

    A sample counts when both its true and its predicted label equal `class_label`.
    `y_true` and `y_pred` are indexed in parallel.
    """
    return sum(
        1
        for idx, truth in enumerate(y_true)
        if truth == class_label and y_pred[idx] == class_label
    )
    
def tn_calc(y_true, y_pred, class_label):
    """Count true negatives for `class_label`.

    A sample counts when neither its true nor its predicted label equals `class_label`.
    `y_true` and `y_pred` are indexed in parallel.
    """
    return sum(
        1
        for idx, truth in enumerate(y_true)
        if truth != class_label and y_pred[idx] != class_label
    )
    
def fp_calc(y_true, y_pred, class_label):
    """Count false positives for `class_label`.

    A sample counts when it is predicted as `class_label` but its true label differs.
    `y_true` and `y_pred` are indexed in parallel.
    """
    return sum(
        1
        for idx, truth in enumerate(y_true)
        if truth != class_label and y_pred[idx] == class_label
    )
    
def fn_calc(y_true, y_pred, class_label):
    """Count false negatives for `class_label`.

    A sample counts when its true label is `class_label` but the prediction differs.
    `y_true` and `y_pred` are indexed in parallel.
    """
    return sum(
        1
        for idx, truth in enumerate(y_true)
        if truth == class_label and y_pred[idx] != class_label
    )

In [None]:
import numpy as np

def calculate_specificity(y_true, y_pred, class_index):
    """One-vs-rest specificity TN / (TN + FP) of `class_index`.

    Args:
        y_true: true labels (list or array).
        y_pred: predicted labels, parallel to `y_true`.
        class_index: the label treated as the positive class.

    Returns:
        The specificity as a float, or nan when no sample has a true label other
        than `class_index` (TN + FP == 0, so specificity is undefined).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    negatives = y_true != class_index
    false_positive = np.sum(negatives & (y_pred == class_index))
    true_negative = np.sum(negatives & (y_pred != class_index))

    denominator = true_negative + false_positive
    if denominator == 0:
        # Explicit nan instead of a 0/0 RuntimeWarning
        return float('nan')
    return true_negative / denominator


def calculate_multi_class_specificity(y_true, y_pred, labels=None):
    """Per-class one-vs-rest specificity and its unweighted mean.

    Args:
        y_true: true labels (list or array).
        y_pred: predicted labels, parallel to `y_true`.
        labels: class labels to evaluate, in output order. Defaults to the sorted
            union of the labels seen in `y_true` and `y_pred`. Using the actual labels
            instead of `range(number of unique y_true labels)` keeps every class
            scored under its own label even when a class is absent from `y_true`.

    Returns:
        (average_specificity, specificity_scores), with one score per entry of `labels`.
        The average is nan when there are no labels to evaluate.
    """
    if labels is None:
        labels = np.unique(np.concatenate([np.asarray(y_true).ravel(), np.asarray(y_pred).ravel()]))

    specificity_scores = [calculate_specificity(y_true, y_pred, label) for label in labels]

    if not specificity_scores:
        return float('nan'), specificity_scores

    # Calculate the average specificity across all classes
    average_specificity = np.mean(specificity_scores)
    return average_specificity, specificity_scores

# One-vs-rest specificity of every fold's reported model on the test set,
# matching the other metric cells, which cover all folds rather than a single one.
spec_metric = []

for i in range(len(y_true_ll)):
    average_specificity, specificity_scores = calculate_multi_class_specificity(y_true_ll[i], y_pred_ll[i])
    spec_metric.append(average_specificity)
    print(f"{green}\nFold {i+1}:")
    print(f'Average Specificity: {average_specificity}')
    print('Specificity for Each Class:', specificity_scores)

print(f"{blue}\n\nAverage Specificity: {red}{sum(spec_metric) / len(spec_metric)}")

In [None]:
# Confusion Matrix: one 10x10-inch figure per fold, blue colour map
class_ids = list(range(len(dataset_test.classes)))
for i in range(len(y_true_ll)):
    # `labels` keeps the matrix square and aligned with `display_labels` even when a
    # class never occurs in a fold's true or predicted labels
    cm = confusion_matrix(y_true_ll[i], y_pred_ll[i], labels=class_ids)
    fig, ax = plt.subplots(figsize=(10, 10))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=dataset_test.classes)
    disp.plot(cmap=plt.cm.Blues, xticks_rotation='vertical', ax=ax)
    ax.set_title("Fold " + str(i+1))
    plt.show()
    # Release the figure so looping over folds does not accumulate open figures
    plt.close(fig)