# Multi-Class PyTorch Model Trainer

In [None]:
import os
import random
import shutil
import numpy as np
import pandas as pd
from glob import glob
from PIL import Image, ImageFile
# Load truncated images regardless.
# The flag lives on PIL.ImageFile; assigning it on PIL.Image only creates an
# unused attribute and leaves truncated files raising during decode.
ImageFile.LOAD_TRUNCATED_IMAGES = True
from tqdm import tqdm
import torchvision
from matplotlib import *
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision import datasets, models, transforms
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.io import read_image, ImageReadMode
%matplotlib inline
import torch.optim.lr_scheduler as lr_scheduler

# Prefer the GPU when CUDA is usable, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Bare expression so the notebook echoes the selected device.
device

# Variables to edit

In [None]:
# Dataset layout: <main_directory>/train/<class>/... and <main_directory>/validation/<class>/...
main_directory = './data_directory'
train_path = f'{main_directory}/train/'
test_path = f'{main_directory}/validation/'

# Training schedule
EPOCHS = 50
btch_sz = 256

# Edge length (pixels) used by the inference helpers at the end of the notebook.
# NOTE: the training/validation transforms resize to a hard-coded 300x300 and do not read this value.
pic_size = 244


# Setting paths and assigning classes

In [None]:
#Get class names from directories
classes = sorted([f.name for f in os.scandir(test_path) if f.is_dir()])
num_imgs = []
train_folders = sorted(glob(train_path+'/*'))
for path in train_folders:
    num = len(glob(path+'/*'))
    num_imgs.append(num)
set_file_count = str(num_imgs[0])+"_"
# Declare vars for Confusion matrix
preds_var = []
actual_var = []

In [None]:
def calculate_accuracy(TN, FP, FN, TP):
    """Return binary-classification accuracy from the four confusion counts.

    Args:
        TN: number of true negatives.
        FP: number of false positives.
        FN: number of false negatives.
        TP: number of true positives.

    Returns:
        (TP + TN) / (TN + FP + FN + TP), or 0.0 when all four counts are zero
        (previously this case raised ZeroDivisionError).
    """
    total = TN + FP + FN + TP
    if total == 0:
        # No samples at all: report 0.0 instead of dividing by zero.
        return 0.0
    return (TP + TN) / total

# Display Classes

In [None]:
# Echo the sorted class names discovered from the validation directory.
classes

# Establish configuration settings

In [None]:
# Central run configuration, read by the data loaders, model head, scheduler and training loop.
CFG = {
    'batch_size': btch_sz,
    'learning_rate': 0.001,
    'epochs': EPOCHS,
    # Hidden width of the classification head: lin1_size is the output width of the
    # first Linear layer and lin2_size the input width of the second, so they must match.
    'lin1_size': 300,
    'lin2_size': 300,
    # 'activation' and 'model' are descriptive only in the cells shown here
    # (the head hard-codes nn.ReLU and models.resnet50).
    'activation': 'relu',
    'model': 'resnet50',
}

# Set random seed

In [None]:
# Random seeds
def set_seed(seed=0):
    """Seed NumPy, Python's ``random`` module and PyTorch with the same value.

    Args:
        seed: integer seed handed to all three generators (default 0).
    """
    # Same order as before: NumPy, then stdlib random, then torch.
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)


# Seed once at notebook start with the default value.
set_seed()

# Generate Pie chart visualizations

In [None]:
import os
import glob
import matplotlib.pyplot as plt
from natsort import natsorted

def count_images_in_subdirs(main_dir, extensions=('.jpg',)):
    """Count image files in each immediate sub-directory of ``main_dir``.

    Args:
        main_dir: directory whose sub-directories are treated as class folders.
        extensions: iterable of file suffixes to count. Matching is
            case-insensitive, so ``photo.JPG`` is counted under ``'.jpg'``,
            consistent with the ``.lower()`` matching used by the inference
            helpers in this notebook.

    Returns:
        ``(labels, counts)``: two parallel lists holding the sub-directory
        names, sorted alphabetically so chart order is reproducible, and the
        number of matching files in each.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    labels = []
    counts = []

    # os.listdir order is arbitrary, so sort for a stable slice order.
    for subdir in sorted(os.listdir(main_dir)):
        subdir_path = os.path.join(main_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        # Hidden files are skipped, as the previous "*.jpg" glob did.
        num_images = sum(
            1
            for name in os.listdir(subdir_path)
            if not name.startswith('.') and name.lower().endswith(suffixes)
        )
        labels.append(subdir)
        counts.append(num_images)

    return labels, counts

def generate_pie_chart(labels, counts, locf='Train'):
    """Draw, save and show a pie chart of image counts per class.

    Args:
        labels: class names, one per slice.
        counts: number of images per class, parallel to ``labels``.
        locf: dataset name used in the legend title, the chart title and the
            output file name ``torch_<locf>_dataset_pie.png``.
    """
    # Every slice is pulled out of the pie by the same offset.
    slice_offsets = [0.1 for _ in labels]

    # Legend entries show the raw count next to each class name.
    legend_entries = []
    for name, total in zip(labels, counts):
        legend_entries.append(f"{name} ({total})")

    fig, ax = plt.subplots(figsize=(12.8, 9.6))
    ax.pie(
        counts,
        labels=labels,
        autopct='%1.1f%%',
        colors=plt.cm.tab20.colors,
        explode=slice_offsets,
        shadow=True,
        startangle=90,
    )
    plt.legend(legend_entries, loc='upper right', title=locf + " Image Count")
    plt.title(f"{locf} - Image Distribution")
    plt.savefig(''.join(('torch_', locf, '_dataset_pie.png')))
    plt.show()
    



In [None]:
# Count the images in each training class folder and chart the distribution
# (the chart is also written to torch_Train_dataset_pie.png).
labels, counts = count_images_in_subdirs(train_path)
generate_pie_chart(labels, counts, locf='Train')

In [None]:
# Same chart for the validation folders (written to torch_Validation_dataset_pie.png).
labels, counts = count_images_in_subdirs(test_path)
generate_pie_chart(labels, counts, locf='Validation')

# Data operations

In [None]:
# Per-channel mean/std, presumably the ImageNet statistics used with torchvision's pretrained models.
# NOTE(review): `normalize` is defined here but is not part of either pipeline below —
# confirm whether that is intended before adding it (the inference helpers build their own pipelines).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Validation pipeline: deterministic resize to 300x300, then convert to a tensor.
test_transforms = transforms.Compose([
        transforms.Resize(size=(300,300)),
        transforms.ToTensor(),
    ])

# Training pipeline: same resize plus random affine jitter and horizontal flips.
train_transforms = transforms.Compose([
        transforms.Resize(size=(300,300)),
        transforms.RandomAffine(degrees=15, translate=(0.1,0.1), scale=(0.8,1.2), shear=5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

In [None]:
# One class per sub-directory; labels are the index of the directory name in sorted order.
train_dataset = datasets.ImageFolder(root=train_path, transform=train_transforms)
test_dataset = datasets.ImageFolder(root=test_path, transform=test_transforms)

In [None]:
# Shuffle only the training data; validation order stays fixed.
_loader_batch = CFG['batch_size']
train_loader = DataLoader(dataset=train_dataset, batch_size=_loader_batch, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=_loader_batch, shuffle=False)

# Display Random Training set images

In [None]:
# Draw ONE shuffled batch of (up to) nine images for the preview. The previous
# version created a new iterator and decoded a full CFG['batch_size'] batch for
# each of the nine tiles, only to display the first image of every batch.
plot_loader = DataLoader(train_dataset, batch_size=9, shuffle=True)
preview_images, preview_labels = next(iter(plot_loader))

# Visualise some examples
plt.figure(figsize=(15,15))
for i in range(len(preview_images)):
    ax = plt.subplot(3,3,i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # CHW tensor -> HWC array for imshow. ToTensor already yields values in
    # [0, 1], so no rescaling is needed (the old "/2" halved the brightness).
    image = np.transpose(preview_images[i].numpy(), (1, 2, 0))
    plt.imshow(image)
    plt.title(classes[int(preview_labels[i])])
plt.show()

# Load model and display state_dict

In [None]:
# Load ResNet50 with ImageNet weights. `weights=` replaces the deprecated
# `pretrained=True`; IMAGENET1K_V1 is the checkpoint that flag selected.
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device)

# Freeze the layers of the ResNet50 model
for param in model.parameters():
    param.requires_grad = False

# The head is Linear -> ReLU -> Linear, so the hidden width has to agree on both
# sides; fail here with a clear message instead of a shape error in the first forward pass.
if CFG['lin1_size'] != CFG['lin2_size']:
    raise ValueError(
        f"CFG['lin1_size'] ({CFG['lin1_size']}) and CFG['lin2_size'] ({CFG['lin2_size']}) must be equal"
    )

# Add a new classification head to the model; read the backbone's feature width
# from the layer being replaced rather than hard-coding it.
model.fc = nn.Sequential(nn.Linear(model.fc.in_features, CFG['lin1_size']),
                         nn.ReLU(),
                         nn.Linear(CFG['lin2_size'], len(classes))).to(device)

# Initialize optimizer (only used for the printout below; a later cell builds the training optimizer)
optimizer = optim.Adam(model.parameters(), lr=CFG['learning_rate'])

# Print model's state_dict (built once instead of once per layer)
model_state = model.state_dict()
print("Model's state_dict:")
for layer, tensor in model_state.items():
    print(layer, "\t", tensor.size())

# Print optimizer's state_dict
optimizer_state = optimizer.state_dict()
print("Optimizer's state_dict:")
for state_var, value in optimizer_state.items():
    print(state_var, "\t", value)

# Print model.state_dict()

In [None]:
import pprint
# Pretty-print the full state_dict (parameter/buffer names and their tensors) with a 4-space indent.
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(model.state_dict())

# Set extra model parameters

In [None]:
# Only the classification head (model.fc) is trainable; every other parameter stays frozen.
head_param_ids = {id(head_param) for head_param in model.fc.parameters()}
for param in model.parameters():
    param.requires_grad = id(param) in head_param_ids

# Set optimizer and learning rate

In [None]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.fc.parameters())

# Learning rate scheduler
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG['epochs'])

# Training Function

In [None]:
def train_one_epoch(train_loader, model, criterion, optimizer, scheduler):
    """Run one optimisation pass over ``train_loader`` and step the scheduler once.

    Args:
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        model: network to train; it is switched to train mode.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once at the end of the epoch.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged per sample so
        a shorter final batch is not over-weighted (the previous per-batch
        average was).

    Raises:
        ValueError: if ``train_loader`` yields no samples (previously an
            AttributeError / ZeroDivisionError).

    Side effects:
        Appends every predicted and actual class index to the module-level
        ``preds_var`` / ``actual_var`` lists.
    """
    global preds_var
    global actual_var
    print("Training...")
    # NOTE(review): train mode also applies to the frozen backbone's
    # normalisation/dropout layers — confirm this is intended.
    model.train()

    loss_sum = 0.0
    correct = 0
    seen = 0
    # Loop over minibatches
    for inputs, labels in tqdm(train_loader):
        # Send to device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients BEFORE the backward pass so leftovers from any
        # earlier backward call cannot leak into this step.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = inputs.shape[0]
        # Assumes `criterion` returns the batch mean (as nn.CrossEntropyLoss()
        # constructed without arguments in this notebook does) — TODO confirm for other losses.
        loss_sum += loss.item() * batch_size
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        seen += batch_size

        preds_var += preds.tolist()
        actual_var += labels.tolist()

    if seen == 0:
        raise ValueError("train_loader produced no samples")

    # Update learning rate
    scheduler.step()

    return loss_sum / seen, correct / seen

# Eval function

In [None]:
def evaluate_one_epoch(test_loader, model, criterion):
    """Compute loss and accuracy over ``test_loader`` without updating weights.

    Args:
        test_loader: iterable yielding ``(inputs, labels)`` batches.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss called as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged per sample so
        a shorter final batch is not over-weighted (the previous per-batch
        average was).

    Raises:
        ValueError: if ``test_loader`` yields no samples (previously an
            AttributeError / ZeroDivisionError).
    """
    print("Evaluation...")
    model.eval()

    loss_sum = 0.0
    correct = 0
    seen = 0

    # Don't track gradients during evaluation
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = inputs.shape[0]
            # Assumes `criterion` returns the batch mean — TODO confirm for non-default losses.
            loss_sum += loss.item() * batch_size
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError("test_loader produced no samples")

    return loss_sum / seen, correct / seen

# Show Performance plot

In [None]:
# Plot history
def plot_hist(train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist):
    """Plot loss and accuracy curves side by side, save them and show the figure.

    Args:
        train_loss_hist: per-epoch training loss values.
        test_loss_hist: per-epoch validation loss values.
        train_acc_hist: per-epoch training accuracy values.
        test_acc_hist: per-epoch validation accuracy values.

    The figure is written to ``torch_performance.png`` before being displayed.
    """
    # (panel title, (series, legend label), (series, legend label))
    panels = (
        ('Cross Entropy Loss',
         (train_loss_hist, 'Train_Loss'),
         (test_loss_hist, 'Validation_loss')),
        ('Accuracy',
         (train_acc_hist, 'Train_Accuracy'),
         (test_acc_hist, 'Validation_Accuracy')),
    )

    plt.figure(figsize=(16,6))
    for position, (panel_title, *curves) in enumerate(panels, start=1):
        plt.subplot(1,2,position)
        for series, legend_name in curves:
            plt.plot(series, label=legend_name)
        plt.title(panel_title)
        plt.legend()
    plt.savefig("torch_performance.png")
    plt.show()
    

# Training function to call train and Eval

In [None]:
from tqdm import tqdm

def train_model(model, criterion, optimizer, scheduler, train_loader, test_loader, verbose=True):
    """Train for ``CFG['epochs']`` epochs, evaluating and checkpointing after each one.

    Args:
        model: network to train.
        criterion: loss function passed to the per-epoch helpers.
        optimizer: optimizer passed to ``train_one_epoch``.
        scheduler: learning-rate scheduler passed to ``train_one_epoch``.
        train_loader: loader for the training batches.
        test_loader: loader for the validation batches.
        verbose: when true, print one metrics line per epoch.

    Returns:
        ``(train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist)``,
        four lists with one entry per epoch.

    Side effects:
        Overwrites ``model/chkpt_torch_model.pth`` (state_dict) and
        ``model/chkpt_torch_model-full.pth`` (whole model) after every epoch.
    """
    # Create the checkpoint directory up front; otherwise a missing folder is
    # only discovered after the first full epoch has already been computed.
    os.makedirs("model", exist_ok=True)

    # Initialise outputs
    train_loss_hist = []
    test_loss_hist = []
    train_acc_hist = []
    test_acc_hist = []

    # Loop over epochs
    for epoch in range(CFG['epochs']):
        # Train
        train_loss, train_accuracy = train_one_epoch(train_loader, model, criterion, optimizer, scheduler)

        # Evaluate
        test_loss, test_accuracy = evaluate_one_epoch(test_loader, model, criterion)

        # Track metrics
        train_loss_hist.append(train_loss)
        test_loss_hist.append(test_loss)
        train_acc_hist.append(train_accuracy)
        test_acc_hist.append(test_accuracy)

        # Save weights only
        torch.save(model.state_dict(), "model/chkpt_torch_model.pth")
        # Save the whole pickled model
        torch.save(model, "model/chkpt_torch_model-full.pth")
        print("Checkpoints saved")

        # Print metrics every epoch (the old "(epoch+1)%1==0" test was always true)
        if verbose:
            print(f'Epoch {epoch+1}/{CFG["epochs"]}, loss {train_loss:.5f}, test_loss {test_loss:.5f}, accuracy {train_accuracy:.5f}, test_accuracy {test_accuracy:.5f}')

    return train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist

# Start Training

In [None]:
print("Training for ",train_path)
# Run the full training loop and keep the per-epoch histories for plotting.
(train_loss_hist,
 test_loss_hist,
 train_acc_hist,
 test_acc_hist) = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    test_loader=test_loader,
    verbose=True,
)

## Save the model

In [None]:
# Save model
torch.save(model.state_dict(), "model/torch_model.pth")
# Save model
torch.save(model, "model/torch_model-full.pth")
# Remove Checkpoints
os.remove("model/chkpt_torch_model-full.pth")
os.remove("model/chkpt_torch_model.pth")


## Show model Features

In [None]:
# Collect the model's direct child modules, in registration order, and echo them.
features = list(model._modules.values())
features

# Show performance

In [None]:
# Plot the per-epoch loss and accuracy curves collected during training
# (the figure is also saved as torch_performance.png).
plot_hist(train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist)

# Show Confusion Matrix

In [None]:
import numpy as np
import torch
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from tqdm import tqdm

def get_predictions(model, data_loader):
    """Run ``model`` over ``data_loader`` and collect labels and predictions.

    Args:
        model: network to evaluate; it is switched to eval mode.
        data_loader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        ``(all_labels, all_preds)``: two flat lists of class indices, in the
        order the loader produced them. Note the order: labels first.
    """
    model.eval()
    truth, guesses = [], []
    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(data_loader):
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            # Index of the largest logit per row is the predicted class.
            batch_preds = torch.max(logits, 1)[1]
            guesses.extend(batch_preds.cpu().numpy())
            truth.extend(batch_labels.cpu().numpy())
    return truth, guesses

def calculate_scores(cm):
    """Return the per-class recall (diagonal / row total) of a confusion matrix.

    Args:
        cm: square confusion matrix with actual classes on the rows.

    Returns:
        A list with one float per class. A class with no samples (row total 0)
        scores 0.0 instead of producing NaN and a NumPy divide warning.
    """
    cm = np.asarray(cm)
    correct_scores = []
    for i, row in enumerate(cm):
        row_total = row.sum()
        correct_scores.append(float(cm[i, i] / row_total) if row_total else 0.0)
    return correct_scores

def plot_confusion_matrix(actual, preds, class_labels):
    """Plot, save and show a confusion matrix with a per-class score legend.

    Args:
        actual: true class indices (0 .. len(class_labels) - 1).
        preds: predicted class indices, parallel to ``actual``.
        class_labels: class names; index ``i`` names class ``i``.

    The figure is written to ``torch_confusion_matrix.png`` and then displayed.
    """
    num_classes = len(class_labels)

    # Pin the label set so the matrix is always num_classes x num_classes.
    # Without it, a class that appears in neither `actual` nor `preds` shrinks
    # the matrix and the annotation loop below indexes out of bounds.
    cm = confusion_matrix(actual, preds, labels=list(range(num_classes)))

    # Per-class score: correct predictions / samples of that class
    scores = calculate_scores(cm)

    fig, ax = plt.subplots(figsize=(12, 12))

    # Red -> yellow -> green colour map; cells are coloured by their COUNT
    cmap = plt.cm.RdYlGn
    ax.matshow(cm, cmap=cmap)

    # Annotate every cell with its count; diagonal (correct) cells use white text
    for i in range(num_classes):
        for j in range(num_classes):
            color = 'white' if i == j else 'black'
            ax.text(j, i, str(cm[i, j]), va='center', ha='center', color=color)

    # Set labels for the axes
    ax.set_xticks(np.arange(num_classes))
    ax.set_yticks(np.arange(num_classes))
    ax.set_xticklabels(class_labels)
    ax.set_yticklabels(class_labels)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')

    # Overall accuracy = correct (diagonal) / all samples; 0 when there are no samples
    total = np.sum(cm)
    accuracy = np.trace(cm) / total if total else 0.0
    str_title = f"Confusion Matrix\n{accuracy * 100:.2f}% accuracy."
    plt.title(str_title)

    # Legend entries ordered best score first. Sort on the numeric score itself
    # rather than re-parsing the label text, which broke for class names containing ': '.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    legend_handles = [
        plt.Line2D([0], [0], color=cmap(score), lw=4, label=f'{class_labels[i]}: {score:.2f}')
        for i, score in ranked
    ]

    # Add a legend outside the axes
    legend = plt.legend(handles=legend_handles, loc='upper left', bbox_to_anchor=(1, 1), title="Legend")
    plt.setp(legend.get_texts(), color='black')  # Set legend text color to black

    # Save CM without cropping the legend
    plt.savefig("torch_confusion_matrix.png", bbox_inches='tight')

    # Display the plot
    plt.show()

# Uses `model`, `test_loader`, and `classes` defined in earlier cells.
# Get actual and predicted labels for the validation set (get_predictions returns labels first).
actual_labels, predicted_labels = get_predictions(model, test_loader)

# Plot the confusion matrix (also saved as torch_confusion_matrix.png)
plot_confusion_matrix(actual_labels, predicted_labels, classes)


# Confusion Matrix Explanation

A confusion matrix is a summary of prediction results on a classification problem. The numbers of correct and incorrect predictions are summarized as counts and broken down by class. This per-class breakdown is the key to the confusion matrix.

## Structure of the Confusion Matrix

The confusion matrix shows the ways in which your classification model is confused when it makes predictions. It gives you insight not only into how many errors your classifier makes but, more importantly, into the types of errors being made.

### Components of the Confusion Matrix

For a binary classification problem, the confusion matrix looks like this:

|                    | Predicted Negative | Predicted Positive |
|--------------------|--------------------|--------------------|
| **Actual Negative**| True Negative (TN) | False Positive (FP)|
| **Actual Positive**| False Negative (FN)| True Positive (TP) |

- **True Positive (TP):** The model correctly predicted the positive class.
- **True Negative (TN):** The model correctly predicted the negative class.
- **False Positive (FP):** The model incorrectly predicted the positive class (Type I error).
- **False Negative (FN):** The model incorrectly predicted the negative class (Type II error).

For a multi-class classification problem, the matrix expands to include rows and columns for each class.

### Accuracy Calculation

Accuracy is one metric for evaluating classification models. It is the ratio of correctly predicted instances to the total instances:

\[ \text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN} \]

In a multi-class scenario, accuracy is calculated as the trace of the confusion matrix divided by the sum of all elements in the matrix:

\[ \text{Accuracy} = \frac{\sum_{i} \text{cm}[i,i]}{\sum_{i,j} \text{cm}[i,j]} \]

Where \(\text{cm}[i,j]\) is the element of the confusion matrix at row \(i\) and column \(j\).

### Visualizing the Confusion Matrix

A visual representation of the confusion matrix can provide more insight into the model’s performance:

1. **Matrix Plot:** Each cell's color corresponds to the number of instances in that cell. With the red–yellow–green colormap used above, low counts are shown in red and high counts in green.
2. **Axes Labels:** The x-axis represents the predicted classes, and the y-axis represents the actual classes.
3. **Annotations:** Each cell shows the count of instances for that actual-predicted class pair.


- The diagonal cells represent the number of times the model correctly predicted each class.
- Off-diagonal cells represent the number of times the model confused one class with another.

### Interpretation

- **High Diagonal Values:** Indicates that the model performs well for those classes.
- **High Off-Diagonal Values:** Indicates areas where the model is confused and may need improvement.

### Usage

By analyzing the confusion matrix, you can:

- Identify which classes are being predicted correctly and which are not.
- Understand the types of errors your model is making.
- Make informed decisions to improve your model, such as collecting more data for classes with high error rates or adjusting the model’s complexity.

Understanding the confusion matrix is crucial for improving the performance of your classification models and ensuring they are reliable and accurate.


## Test the model on randomly sampled validation images

In [None]:
def get_pred(img_size=pic_size):
    """Classify one random ``.jpg`` from the validation folders, show it and print the result.

    Args:
        img_size: edge length in pixels of the square preview passed to
            ``plt.imshow``. It does not affect the network input, which is
            resized to 300x300 like the validation pipeline.

    Raises:
        FileNotFoundError: if no ``.jpg`` file exists in any sub-directory of
            ``test_path`` (previously an IndexError from ``random.choice``).
    """
    from PIL import Image
    import torch
    from torchvision import transforms

    # Same preprocessing as the validation loader
    test_transforms = transforms.Compose([
        transforms.Resize((300,300)),
        transforms.ToTensor(),
    ])

    # Gather every .jpg (case-insensitive) one level below test_path
    candidates = []
    for entry in sorted(os.listdir(test_path)):
        subdir = os.path.join(test_path, entry)
        if not os.path.isdir(subdir):
            continue
        for f in os.listdir(subdir):
            file_path = os.path.join(subdir, f)
            if os.path.isfile(file_path) and f.lower().endswith('.jpg'):
                candidates.append(file_path)
    if not candidates:
        raise FileNotFoundError(f"No .jpg images found under {test_path}")

    image_path = random.choice(candidates)

    # Force 3 channels: grayscale, RGBA or palette files would otherwise give
    # the network an input with the wrong channel count.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # NOTE: this unpickles a whole model object — only load files you created yourself.
    # NOTE(review): recent PyTorch releases may default torch.load to weights-only
    # loading, which rejects fully pickled models — confirm against the installed version.
    # map_location places the model on whichever device is available instead of assuming CUDA.
    pred_model = torch.load("model/torch_model-full.pth", map_location=device)
    # Load for inference
    pred_model.eval()

    input_batch = test_transforms(image).unsqueeze(0).to(device)
    # No gradients are needed for a single forward pass
    with torch.no_grad():
        output_data = pred_model(input_batch)

    percentage = torch.nn.functional.softmax(output_data, dim=1)[0] * 100
    best = int(torch.max(output_data, 1)[1][0])
    message  = str(round(percentage[best].item(),2))+"% confident, this is "+str(classes[best])

    # Honour the img_size argument for the preview (it was previously ignored in favour of the global)
    img = image.resize((img_size,img_size))
    plt.imshow(img)
    print(message)


In [None]:
# Classify and display one randomly chosen validation image.
get_pred()

In [None]:
# Repeat with another random validation image.
get_pred()

In [None]:
# Repeat with another random validation image.
get_pred()

## Load and return only the classification from random image selection

In [None]:
import os
import random

def only_class(image_file, classes,img_size=pic_size):
    """Classify a single image file with the saved model.

    Args:
        image_file: path of the image to classify.
        classes: list of class names; the predicted index is looked up here.
        img_size: edge length the image is resized to before inference.
            NOTE(review): the training/validation pipelines resize to 300x300,
            while the default here is ``pic_size`` — confirm the mismatch is intended.

    Returns:
        ``(image, class_name)``: the opened PIL image and the predicted class
        name as a string.
    """
    from PIL import Image
    import torch
    from torchvision import transforms

    test_transforms = transforms.Compose([
        transforms.Resize((img_size,img_size)),
        transforms.ToTensor(),
    ])

    # Load the image
    image = Image.open(image_file)
    # Force 3 channels for the network; grayscale/RGBA/palette files would
    # otherwise produce a tensor with the wrong channel count.
    custom_image_transformed = test_transforms(image.convert("RGB"))

    # NOTE: this unpickles a whole model object — only load files you created yourself.
    # map_location places the model on whichever device is available instead of assuming CUDA.
    pred_model = torch.load("model/torch_model-full.pth", map_location=device)
    # Load for inference
    pred_model.eval()

    input_data = custom_image_transformed.unsqueeze(0).to(device)
    # No gradients are needed for a single forward pass
    with torch.no_grad():
        output_data = pred_model(input_data)
    _, index = torch.max(output_data, 1)
    return image,str(classes[int(index[0])])

def find_random_image(root_dir, extensions=('.jpg',)):
    """Return the path of a random image found anywhere below ``root_dir``.

    Args:
        root_dir: directory searched recursively.
        extensions: iterable of file suffixes to accept (list or tuple).
            Matching is case-insensitive on both the file name and the suffix.
            The default is a tuple to avoid a shared mutable default argument.

    Returns:
        One matching path chosen with ``random.choice``, or ``None`` when no
        file matches.
    """
    suffixes = tuple(ext.lower() for ext in extensions)

    # Walk the whole tree and keep every file with an accepted suffix
    image_paths = [
        os.path.join(subdir, file)
        for subdir, _, files in os.walk(root_dir)
        for file in files
        if file.lower().endswith(suffixes)
    ]

    # If there are no images, return None
    if not image_paths:
        return None

    return random.choice(image_paths)

In [None]:
# Example usage: pick a random .jpg anywhere under the data directory and classify it.
random_image_path = find_random_image(main_directory)
print(random_image_path)
# find_random_image returns None when nothing matches; stop with a clear
# message instead of passing None on to Image.open.
if random_image_path is None:
    raise FileNotFoundError(f"No .jpg images found under {main_directory}")
# Kept as the last top-level expression so the notebook echoes (image, class name).
only_class(random_image_path, classes,img_size=pic_size)
