# Importing libraries

In [None]:
import copy
import os
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18, resnet34

# Set up directories
* This code was used to train a ResNet model on Video Plankton Recorder derived images, called Regions of Interest (ROI). 
* To use this code, all you need to do is to have a training and validation dataset split beforehand. 
#### Please input the base directory where your "Train" and "Validation" folders are located.

In [None]:
# Root of the dataset (edit this path to point at your local copy)
base_dir = r"./data/Combined_dataset"

# The two split folders expected inside the base directory
train_folder = os.path.join(base_dir, "Train")
validation_folder = os.path.join(base_dir, "Validation")

# Fail early with a clear message when either split folder is missing
for name, folder in (("Training", train_folder), ("Validation", validation_folder)):
    if os.path.exists(folder):
        continue
    raise FileNotFoundError(f"{name} folder '{folder}' not found. Please check your dataset path.")

def count_tif_images(folder, extensions=(".tif",)):
    """Count image files below ``folder``, including all sub-folders.

    Args:
        folder: Directory to walk recursively.
        extensions: File-name suffix (or iterable of suffixes) to count,
            matched case-insensitively. The default counts only ``.tif``
            files; pass ``(".tif", ".tiff")`` to also count ``.tiff`` files.

    Returns:
        The number of files whose name ends with one of ``extensions``.
        A folder that does not exist yields 0 because ``os.walk`` produces
        nothing for it.
    """
    # A bare string would otherwise be iterated character by character
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # str.endswith accepts a tuple, so one call covers every suffix
    return sum(
        file_name.lower().endswith(suffixes)
        for _, _, files in os.walk(folder)
        for file_name in files
    )

# Report how many .tif images each split contains
n_train_images = count_tif_images(train_folder)
n_validation_images = count_tif_images(validation_folder)
print(f"Training images: {n_train_images}")
print(f"Validation images: {n_validation_images}")

# Calculate normalization values
* Based on the training dataset
* This may take some minutes

In [None]:
# Prefer CUDA when a GPU is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Resolution every image is resized to before it reaches the ResNet
image_size = (224, 224)

# Resize + tensor conversion only: normalization values are derived from this pass
transform = transforms.Compose(
    [transforms.Resize(image_size), transforms.ToTensor()]
)

# Un-normalized training set and a loader over it for the statistics pass
train_dataset = datasets.ImageFolder(train_folder, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, num_workers=4)

# Function to calculate the per-channel mean and std of a whole dataset
def calculate_mean_std(loader, device):
    """Compute the exact per-channel mean and standard deviation of a loader.

    Sums and sums of squares are accumulated over every pixel, so the result
    is the true dataset statistic regardless of how the data is batched
    (averaging per-batch standard deviations would underestimate it).

    Args:
        loader: Iterable yielding ``(images, target)`` pairs where ``images``
            is a ``(batch, channels, ...)`` tensor; the target is ignored.
            Images are assumed to be floating point (e.g. ``ToTensor`` output).
        device: Device on which each batch is reduced.

    Returns:
        ``(mean, std)`` as float32 CPU tensors with one entry per channel.
        ``std`` is the population standard deviation.

    Raises:
        ValueError: If the loader yields no pixels at all.
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0

    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            # Collapse all spatial dimensions: (batch, channels, pixels)
            flat = images.reshape(images.size(0), images.size(1), -1)
            batch_pixels = flat.size(0) * flat.size(2)
            if batch_pixels == 0:
                continue

            # Reduce on the device, then accumulate in float64 on the CPU so
            # precision is not lost when summing over many batches
            batch_sum = flat.sum(dim=(0, 2)).double().cpu()
            batch_sq_sum = flat.pow(2).sum(dim=(0, 2)).double().cpu()
            if channel_sum is None:
                channel_sum = torch.zeros_like(batch_sum)
                channel_sq_sum = torch.zeros_like(batch_sq_sum)
            channel_sum += batch_sum
            channel_sq_sum += batch_sq_sum
            pixel_count += batch_pixels

    if pixel_count == 0:
        raise ValueError("Cannot compute mean/std: the loader yielded no images.")

    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative rounding errors
    variance = (channel_sq_sum / pixel_count - mean.pow(2)).clamp_(min=0.0)
    return mean.float(), variance.sqrt().float()

# Run the statistics pass and keep the results as plain lists for Normalize
mean, std = calculate_mean_std(train_loader, device)
calculated_mean, calculated_std = mean.tolist(), std.tolist()
print(f"Calculated Mean: {calculated_mean}")
print(f"Calculated Std: {calculated_std}")

# Data augmentation
* Applies random horizontal and vertical flips to the training data to improve the generalization of the model

In [None]:
# Normalization step shared by both pipelines (statistics come from the training set)
normalize = transforms.Normalize(mean=calculated_mean, std=calculated_std)

# Training pipeline: resize, random flips for augmentation, then tensor + normalize
transform_train = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)

# Validation pipeline: deterministic, no augmentation
transform_val = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        normalize,
    ]
)

# Re-create the datasets so that each split uses its final pipeline
train_dataset = datasets.ImageFolder(train_folder, transform=transform_train)
validation_dataset = datasets.ImageFolder(validation_folder, transform=transform_val)

# Set model and hyperparameters
* Also prints a table of all hyperparameters used.
* Inputs: ResNet model (number of layers) and hyperparameters
* Batch size, learning rate, epochs, weighted loss, and LRScheduler
* Current parameters are what was used for a 3-class model classifying VPR images of Marine snow, Fecal pellets, and Others.

In [None]:
# Choose ResNet model
resnet_model = resnet34  # resnetXX, where XX is the number of layers desired (must be imported above).
# Initiate model with no pretrained weights, training from scratch
model = resnet_model(weights=None, num_classes=len(train_dataset.classes)).to(device)

#### Hyperparameters
batch_size = 32
learning_rate = 0.001  # initial learning rate
num_epochs = 30  # number of epochs
loss = "weighted"  # change this if not using weighted loss
LRScheduler = "ROP"  # "ROP" = ReduceLROnPlateau; any other value selects StepLR

#### LOSS FUNCTION
## Weighted loss assigns higher penalties to underrepresented classes
if loss == "weighted":
    max_weight_threshold = 5.0  # Prevent extreme rare class weighting
    min_weight_threshold = 0.5  # Ensure majority classes are not ignored
    ## Get class distribution from dataset
    class_counts = np.bincount(train_dataset.targets)  # Count samples per class
    num_classes = len(class_counts)
    ## "balanced" weights are inversely proportional to class frequency
    class_weights = compute_class_weight(
        class_weight="balanced",
        classes=np.arange(num_classes),
        y=np.asarray(train_dataset.targets),
    )
    class_weights = np.clip(class_weights, min_weight_threshold, max_weight_threshold)  # cap the weights
    ## Create the weight tensor directly on the training device
    class_weights = torch.tensor(class_weights, dtype=torch.float, device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
else:
    criterion = nn.CrossEntropyLoss()  # unweighted loss

# Optimizer uses calculated loss to adjust the weights, reducing loss.
# Adam optimizer automatically adjusts learning rate for each parameter
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Create a learning rate scheduler, either StepLR or ReduceLROnPlateau
if LRScheduler == "ROP":
    # mode='max' because the scheduler is stepped with an F1 score (higher is better).
    # The deprecated `verbose` argument is intentionally not passed: newer PyTorch
    # releases reject it, and the training loop reports learning-rate changes itself.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=4)  # remember to alter the factor
else:
    LRScheduler = "StepLR"
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# Create a pandas DataFrame for a proper table
hyperparams = pd.DataFrame([
    # __name__ gives e.g. "resnet34" instead of a function repr with a memory address
    ["Model", resnet_model.__name__],
    ["Batch Size", batch_size],
    ["Learning Rate", learning_rate],
    ["Epochs", num_epochs],
    ["Loss Function", loss],
    ["Optimizer", str(optimizer)],
    ["Scheduler", LRScheduler],
], columns=["Parameter", "Value"])

hyperparams

# Load data loaders with transformations
* `num_workers` controls parallel loading of data. Higher values lead to faster data fetching, but require a more powerful CPU.

In [None]:
num_workers = 8

# Settings shared by both loaders.
# - pin_memory only helps host-to-GPU copies, so it is enabled for CUDA only.
# - persistent_workers raises a ValueError when num_workers is 0, so it is
#   switched off automatically in that case.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=device.type == "cuda",
    persistent_workers=num_workers > 0,
)

# DataLoader for train & validation datasets (only the training data is shuffled)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
validation_loader = torch.utils.data.DataLoader(validation_dataset, shuffle=False, **loader_kwargs)

# Lists to store metrics
* Remember to reset these lists between each training!

In [None]:
# Ground-truth and predicted labels collected during validation
true_labels_list, predicted_labels_list = [], []

# Loss and accuracy histories
train_losses, train_accuracies = [], []
validation_losses, validation_accuracies = [], []

# Class-wise precision, recall and F1 histories
classwise_precision, classwise_recall, classwise_f1 = [], [], []

# All metrics for each class across each epoch
class_metrics = {
    class_name: dict(train=[], val=[], precision=[], recall=[], f1=[])
    for class_name in train_dataset.classes
}
# FP / FN / TP / TN for each class; this dictionary contains values accumulated over all epochs
class_metrics2 = {
    class_name: dict(FP=[], FN=[], TP=[], TN=[])
    for class_name in train_dataset.classes
}

# Per-class counters, one slot per class
num_classes = len(train_dataset.classes)
class_correct = [0 for _ in range(num_classes)]
class_total = [0 for _ in range(num_classes)]
class_train_correct = [0 for _ in range(num_classes)]
class_train_total = [0 for _ in range(num_classes)]
class_val_correct = [0 for _ in range(num_classes)]
class_val_total = [0 for _ in range(num_classes)]

# F1 score histories and the epoch numbers they belong to
macro_f1_scores, weighted_f1_scores, epochs_list = [], [], []

# Train and validate the model
* This code block trains the model and records performance data for every epoch.
* This code also saves the best weights (determined by highest F1 score).
* Inputs: Patience (stops training if there is no improvement in X epochs).
* After each epoch, the code will report time taken for training and validation. Then, it will report F1 scores. 

In [None]:
# Suppress all UndefinedMetricWarning warnings where precision and recall = 0.
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

# Best F1 scores and best-model bookkeeping
best_weighted_f1 = 0.0  # Best weighted F1-score seen so far
best_macro_f1 = 0.0  # Best macro F1-score seen so far
best_model_wts = None  # Independent copy of the best weights, held in memory
best_epoch = None  # Epoch at which best_model_wts was captured

# Early stopping parameters
epochs_without_improvement = 0
patience = 8  # Stop if no improvement in x epochs

# Label indices, in the order used to index the per-class score arrays below
all_class_indices = list(range(num_classes))

#################### Training loop ####################################################################
print("Initiating training")

for epoch in range(1, num_epochs + 1):
    start_time = time.time()  # Start the timer

    # Reset class_train_correct and class_train_total for each epoch
    class_train_correct = [0] * num_classes
    class_train_total = [0] * num_classes

    model.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0
    true_labels_train = []
    predicted_labels_train = []
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)  # Move data to the training device
        optimizer.zero_grad()
        outputs = model(images)
        # Named batch_loss so the `loss` hyperparameter string is not overwritten
        batch_loss = criterion(outputs, labels)
        batch_loss.backward()
        optimizer.step()
        train_loss += batch_loss.item()

        # Move labels and predictions to the CPU once per batch and do all
        # bookkeeping there, instead of synchronising with the device per sample
        labels_cpu = labels.detach().cpu()
        predicted_cpu = outputs.detach().argmax(dim=1).cpu()
        hits = predicted_cpu == labels_cpu

        total_train += labels_cpu.size(0)
        correct_train += int(hits.sum())
        true_labels_train.extend(labels_cpu.tolist())
        predicted_labels_train.extend(predicted_cpu.tolist())

        # Update class-wise counts with plain Python ints
        batch_total = torch.bincount(labels_cpu, minlength=num_classes).tolist()
        batch_correct = torch.bincount(labels_cpu[hits], minlength=num_classes).tolist()
        for class_index in all_class_indices:
            class_correct[class_index] += batch_correct[class_index]
            class_total[class_index] += batch_total[class_index]
            class_train_correct[class_index] += batch_correct[class_index]
            class_train_total[class_index] += batch_total[class_index]

    # End the timer and calculate elapsed time
    end_time = time.time()
    epoch_time = end_time - start_time
    print(f"Time taken for training epoch {epoch}: {epoch_time:.0f} seconds")

    train_losses.append(train_loss / len(train_loader))
    train_accuracies.append(100 * correct_train / total_train)  # Store total training accuracy

    # Calculate precision, recall, and F1 score for training
    train_precision = precision_score(true_labels_train, predicted_labels_train, average='weighted')
    train_recall = recall_score(true_labels_train, predicted_labels_train, average='weighted')
    train_f1 = f1_score(true_labels_train, predicted_labels_train, average='weighted')

    ########################### Validation loop #######################################################################
    model.eval()
    start_time = time.time()  # Start the timer
    validation_loss = 0.0
    correct_val = 0
    total_val = 0
    true_labels_val = []
    predicted_labels_val = []

    # Reset class_val_correct and class_val_total for each epoch
    class_val_correct = [0] * num_classes
    class_val_total = [0] * num_classes

    # class_metrics3 holds the TP/FP/TN/FN counts of the current epoch only
    class_metrics3 = {class_name: {"FP": [], "FN": [], "TP": [], "TN": []} for class_name in train_dataset.classes}

    # No gradients are needed for validation; this avoids building the autograd graph
    with torch.no_grad():
        for images, labels in validation_loader:
            images, labels = images.to(device), labels.to(device)  # Move data to the training device
            outputs = model(images)
            validation_loss += criterion(outputs, labels).item()

            labels_cpu = labels.cpu()
            predicted_cpu = outputs.argmax(dim=1).cpu()
            hits = predicted_cpu == labels_cpu
            total_val += labels_cpu.size(0)

            # Calculate TP, FP, TN, FN for each class (one entry per batch)
            for class_name in train_dataset.classes:
                class_index = train_dataset.class_to_idx[class_name]
                predicted_is_class = predicted_cpu == class_index
                label_is_class = labels_cpu == class_index
                true_positive = int((predicted_is_class & label_is_class).sum())
                false_positive = int((predicted_is_class & ~label_is_class).sum())
                true_negative = int((~predicted_is_class & ~label_is_class).sum())
                false_negative = int((~predicted_is_class & label_is_class).sum())

                # class_metrics2 accumulates over all epochs, class_metrics3 over this epoch only
                for store in (class_metrics2, class_metrics3):
                    store[class_name]["TP"].append(true_positive)
                    store[class_name]["FP"].append(false_positive)
                    store[class_name]["TN"].append(true_negative)
                    store[class_name]["FN"].append(false_negative)

            correct_val += int(hits.sum())
            true_labels_val.extend(labels_cpu.tolist())
            predicted_labels_val.extend(predicted_cpu.tolist())

            # Update class-wise counts for validation
            batch_total = torch.bincount(labels_cpu, minlength=num_classes).tolist()
            batch_correct = torch.bincount(labels_cpu[hits], minlength=num_classes).tolist()
            for class_index in all_class_indices:
                class_val_correct[class_index] += batch_correct[class_index]
                class_val_total[class_index] += batch_total[class_index]

    # Append validation accuracy and loss
    validation_losses.append(validation_loss / len(validation_loader))
    validation_accuracies.append(100 * correct_val / total_val)

    # Per-class validation precision, recall and F1, computed once for all classes
    per_class_precision = precision_score(true_labels_val, predicted_labels_val, labels=all_class_indices, average=None)
    per_class_recall = recall_score(true_labels_val, predicted_labels_val, labels=all_class_indices, average=None)
    per_class_f1 = f1_score(true_labels_val, predicted_labels_val, labels=all_class_indices, average=None)

    # Calculate and append the per-class accuracies and scores for this epoch
    for class_name in train_dataset.classes:
        class_index = train_dataset.class_to_idx[class_name]
        # Training accuracy for the class (0 when the class had no samples)
        train_acc = (class_train_correct[class_index] / class_train_total[class_index]) * 100 if class_train_total[class_index] != 0 else 0
        class_metrics[class_name]["train"].append(train_acc)
        # Validation accuracy for the class (0 when the class had no samples)
        val_acc = (class_val_correct[class_index] / class_val_total[class_index]) * 100 if class_val_total[class_index] != 0 else 0
        class_metrics[class_name]["val"].append(val_acc)
        # Stored as plain floats so the histories plot and serialise cleanly
        class_metrics[class_name]["precision"].append(float(per_class_precision[class_index]))
        class_metrics[class_name]["recall"].append(float(per_class_recall[class_index]))
        class_metrics[class_name]["f1"].append(float(per_class_f1[class_index]))

    # Calculate weighted precision, recall, and F1 for the entire validation pass
    val_precision = precision_score(true_labels_val, predicted_labels_val, average='weighted')
    val_recall = recall_score(true_labels_val, predicted_labels_val, average='weighted')
    val_weighted_f1 = f1_score(true_labels_val, predicted_labels_val, average='weighted')
    val_macro_f1 = f1_score(true_labels_val, predicted_labels_val, average='macro')

    ### Scheduler step
    prev_lr = optimizer.param_groups[0]['lr']
    if LRScheduler == "ROP":
        # ReduceLROnPlateau is driven by the average of the two F1 scores
        scheduler.step((val_weighted_f1 + val_macro_f1) / 2)
    else:
        scheduler.step()
    new_lr = optimizer.param_groups[0]['lr']

    # Print only if learning rate has changed
    if new_lr < prev_lr:
        print(f"\033[93mLearning Rate Reduced: {prev_lr:.6f} → {new_lr:.6f}\033[0m")

    # Append label values
    true_labels_list.extend(true_labels_val)
    predicted_labels_list.extend(predicted_labels_val)

    # Store F1 values
    macro_f1_scores.append(val_macro_f1)
    weighted_f1_scores.append(val_weighted_f1)
    epochs_list.append(epoch)

    # End the timer and calculate elapsed time
    end_time = time.time()
    epoch_time = end_time - start_time
    print(f"Time taken for validation epoch and calculations {epoch}: {epoch_time:.0f} seconds")

    # Print metrics for individual epochs here:
    print(f'Epoch [{epoch}/{num_epochs}]')

    # Epoch totals per class from class_metrics3; only needed by the optional
    # per-class report below, which is disabled by default.
    for class_name in train_dataset.classes:
        TP = sum(class_metrics3[class_name]["TP"])
        FP = sum(class_metrics3[class_name]["FP"])
        TN = sum(class_metrics3[class_name]["TN"])
        FN = sum(class_metrics3[class_name]["FN"])

        # Calculate precision, recall, and accuracy for each class
        #precision = TP / (TP + FP) if (TP + FP) != 0 else 0
        #recall = TP / (TP + FN) if (TP + FN) != 0 else 0
        #accuracy = 100 * (TP + TN) / (TP + TN + FP + FN) if (TP + TN + FP + FN) != 0 else 0
        #print(f'Class: {class_name}, Training Accuracy: {class_metrics[class_name]["train"][-1]:.2f}%, Validation Accuracy: {accuracy:.4f}%, '
        #      f'Precision: {precision:.4f}, Recall: {recall:.4f}')

    # Early stopping and best model tracking
    if val_weighted_f1 > best_weighted_f1 and val_macro_f1 > best_macro_f1:  # If using "and", both must improve!
        best_weighted_f1 = val_weighted_f1
        best_macro_f1 = val_macro_f1
        # state_dict() returns references to the live tensors, so a deep copy is
        # required; otherwise later epochs would silently overwrite the "best" weights
        best_model_wts = copy.deepcopy(model.state_dict())
        best_epoch = epoch
        epochs_without_improvement = 0  # Reset counter when improvement is found
        print(f"\033[92mNew best model detected! Weighted/macro F1-score: {best_weighted_f1:.2f}/{best_macro_f1:.2f} at epoch {epoch}\033[0m")
    else:
        epochs_without_improvement += 1
        print(f"\033[93mNo improvements detected. Weighted/macro F1-score:{val_weighted_f1:.2f}/{val_macro_f1:.2f}\nEpochs without improvement: {epochs_without_improvement}\033[0m")

    # Check early stopping condition
    if epochs_without_improvement >= patience:
        print(f"\033[91mEarly stopping triggered at epoch {epoch} due to no improvement in {patience} epochs.\033[0m")
        # Report the epoch the best model was found at, not the current epoch
        print(f"\033[92mBest model had: Weighted/macro F1-score: {best_weighted_f1:.2f}/{best_macro_f1:.2f} at epoch {best_epoch}\033[0m")
        break  # Exit training loop

    print()

# Save the model

In [None]:
num_classes = len(train_dataset.classes)
# Build the path with os.path so it also works outside Windows, and make sure
# the target folder exists before torch.save tries to write into it
save_path = os.path.join(".", "Models", f"{num_classes}.pth")
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# Save the weights tracked as best during training; fall back to the current
# weights when no best model was recorded
weights_to_save = best_model_wts if best_model_wts is not None else model.state_dict()

torch.save({
    'epoch': epoch,  # last epoch that was run
    'model_state_dict': weights_to_save,
    'optimizer_state_dict': optimizer.state_dict(),
    'training_loss': train_loss,  # summed batch loss of the last epoch
    'validation_loss': validation_loss,  # summed batch loss of the last epoch
    'training_accuracies': train_accuracies,
    'validation_accuracies': validation_accuracies,
    # Plain Python ints keep the checkpoint free of pickled numpy scalars
    'true_labels_list': [int(label) for label in true_labels_list],
    'predicted_labels_list': [int(label) for label in predicted_labels_list],
    'class_mapping': train_dataset.class_to_idx,  # Save the class mapping
    'num_classes': num_classes,  # Save the number of classes
    'normalization': {  # Save normalization values
        'mean': calculated_mean,
        'std': calculated_std
    }
}, save_path)

print(f"Model and additional data saved as {save_path}")

# Load saved model

In [None]:
# Load the model checkpoint.
# The checkpoint stores more than tensors (label lists, class mapping, ...), so
# the restricted loader that newer PyTorch versions use by default can reject it.
# NOTE: weights_only=False unpickles arbitrary objects - only load checkpoint
# files that you created yourself or otherwise trust.
checkpoint = torch.load(save_path, map_location=device, weights_only=False)

# Reverse mapping for index-to-class name
model_class_mapping = {index: class_name for class_name, index in checkpoint['class_mapping'].items()}

# Load the saved weights into the existing model and switch to inference mode
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

print("Model loaded successfully.")

# Final Validation and Confusion matrix

In [None]:
# Track predictions and labels
all_preds = []
all_labels = []

# Make sure dropout/batch-norm layers are in inference mode
model.eval()
with torch.no_grad():
    for images, labels in validation_loader:
        preds = model(images.to(device)).argmax(dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.tolist())

print("Test done.")

# Class names come straight from the validation dataset.
# NOTE(review): this assumes the validation folders map to the same class
# indices as the training folders - confirm both splits contain the same classes.
class_names = validation_loader.dataset.classes
class_indices = list(range(len(class_names)))

# Passing `labels` pins the report and the matrix to the full class list, so a
# class that is missing from the validation data cannot cause a shape mismatch
report = classification_report(all_labels, all_preds, labels=class_indices, target_names=class_names)
print(report)

# Confusion matrix calculated with % of each class.
cm = confusion_matrix(all_labels, all_preds, labels=class_indices)

# Normalize each row to percentages; rows without any samples stay at 0
# instead of becoming NaN through a division by zero
row_totals = cm.sum(axis=1, keepdims=True)
cm_percentage = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
) * 100

# Create a heatmap with percentages
plt.figure(figsize=(24, 20))
sns.heatmap(cm_percentage, annot=True, fmt=".0f", cmap="Blues", xticklabels=class_names, yticklabels=class_names)

# Add labels and title
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix (Percentage)")

# Save the figure to a file
#plt.savefig(rf"C:\Users\kriss\Documents\HI Jobb\Confusion matrix\Run{run_nr}.png", dpi=300, bbox_inches='tight')
plt.show()

# Plot Precision, Recall, Accuracy, Loss and F1 scores per class.

In [None]:
# Classes to plot, in dataset order
classes = train_dataset.classes

# One figure per class showing its precision and recall values across epochs
for class_name in classes:
    per_class_history = class_metrics[class_name]
    precision_values = per_class_history["precision"]
    recall_values = per_class_history["recall"]

    # Number of epochs actually recorded for this class
    actual_epochs = len(precision_values)
    epoch_axis = range(1, actual_epochs + 1)

    plt.figure(figsize=(10, 6))

    # Precision and recall share the same epoch axis
    plt.plot(epoch_axis, precision_values, label='Precision', marker='o')
    plt.plot(epoch_axis, recall_values, label='Recall', marker='x')

    # Titles, axis labels and layout
    plt.title(f'Precision and Recall for Class: {class_name}')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    # Render the figure for this class before moving on to the next one
    plt.show()

In [None]:
plt.figure(figsize=(10, 6))

# Draw the macro and weighted F1 histories against the recorded epoch numbers
for scores, series_label, series_marker in (
    (macro_f1_scores, "Macro F1", "o"),
    (weighted_f1_scores, "Weighted F1", "s"),
):
    plt.plot(epochs_list, scores, label=series_label, linestyle="-", marker=series_marker)

# Axis labels, title, legend and grid
plt.xlabel("Epoch")
plt.ylabel("F1 Score")
plt.title("Macro & Weighted F1 Score Over Training")
plt.legend()
plt.grid(True)

# Render the figure
plt.show()

In [None]:
# Plot training and validation losses. Big gap indicates overfitting.
# The x-axes are derived from the recorded histories themselves, so this cell
# does not depend on a variable left behind by another plotting cell and still
# works when training stopped early.
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
plt.plot(range(1, len(validation_losses) + 1), validation_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss Curves')
plt.show()