In [None]:
import os
import copy
import torch
import itertools
import numpy as np
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

In [None]:
# Select the compute backend: prefer CUDA, then Apple's MPS backend, else fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using {} device".format(device))

# Data processing

In [9]:
class processDataset():
    """Load a folder-per-class frame dataset into memory and split it into train/val tensors.

    Every sub-directory of ``datasetPath`` is treated as one class. Image files
    (.jpg/.png/.jpeg) inside a class folder are grouped by the part of the file
    name that precedes the last underscore. A group is kept only if it contains
    exactly ``NumFramesPerVid`` files; its frames are resized, normalised and
    concatenated along the channel axis into a single tensor of shape
    ``[3 * NumFramesPerVid, resize_size, resize_size]``.

    The samples are split 80/20 (``random_state=42``) and "balanced" class
    weights are computed from the training labels and moved to the
    module-level ``device``.

    Args:
        datasetPath: Directory that contains one sub-directory per class.
        resize_size: Side length in pixels that each frame is resized to (square).

    Raises:
        ValueError: If no complete frame group is found under ``datasetPath``.

    Attributes:
        uniqueClassNames: Alphabetically sorted list of class folder names.
        classNameToIndex / IndexToClassName: Mapping between names and integer labels.
        imagesAndLabels: List of ``{"images": tensor, "label": className}`` dicts.
        allImages / allLabels: Parallel lists of sample tensors and integer labels.
        train_images / val_images: Stacked sample tensors ``[N, C, H, W]``.
        train_labels / val_labels: ``torch.long`` label tensors.
        class_weights: Float tensor with one weight per class, on ``device``.
    """

    def __init__(self, datasetPath, resize_size):
        self.uniqueClassNames = []   # [class1, class2, class3, ...] (sorted)
        self.imagesAndLabels = []    # [{images: tensor(C x H x W), label: class1}, ...]
        self.classNameToIndex = {}   # {class1: 0, class2: 1, class3: 2, ...}
        self.IndexToClassName = {}   # {0: class1, 1: class2, 2: class3, ...}
        self.allImages = []          # [tensor(C x H x W), ...]
        self.allLabels = []          # [0, 2, 2, ...] (integer class indices)
        self.transform = None
        self.NumFramesPerVid = 10

        # Every sub-directory is a class. Sorting (instead of building a set)
        # keeps the name <-> index mapping identical across interpreter runs,
        # so labels stored with a saved model stay meaningful.
        self.uniqueClassNames = sorted(
            entry for entry in os.listdir(datasetPath)
            if os.path.isdir(os.path.join(datasetPath, entry))
        )

        # (className -> index) & (index -> className) dicts
        self.classNameToIndex = {className: i for i, className in enumerate(self.uniqueClassNames)}
        self.IndexToClassName = {i: className for i, className in enumerate(self.uniqueClassNames)}

        # Per-frame preprocessing: resize, convert to tensor, normalise per channel
        self.transform = transforms.Compose([
            transforms.Resize((resize_size, resize_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Collect image groups class by class
        for className in self.uniqueClassNames:
            currClassFolder = os.path.join(datasetPath, className)
            allImagesInFolder = sorted(
                f for f in os.listdir(currClassFolder)
                if f.lower().endswith(('.jpg', '.png', '.jpeg'))
            )

            # Group images by their file prefix (everything before the last '_').
            # Because the file list is already sorted, each group stays in
            # lexicographic (frame) order.
            imageGroups = {}
            for currImageName in allImagesInFolder:
                prefix = '_'.join(currImageName.split('_')[:-1])
                imageGroups.setdefault(prefix, []).append(currImageName)

            # Only groups with exactly NumFramesPerVid frames are usable
            for prefix, sortedImageList in imageGroups.items():
                if len(sortedImageList) != self.NumFramesPerVid:
                    print(f"Warning: Skipping {prefix} because it does not have exactly {self.NumFramesPerVid} images (found {len(sortedImageList)}).")
                    continue

                stackedImages = []
                for imName in sortedImageList:
                    imPath = os.path.join(currClassFolder, imName)
                    # Context manager guarantees the file handle is released
                    with Image.open(imPath) as rawImage:
                        frame = self.transform(rawImage.convert('RGB'))   # [3, H, W]
                    stackedImages.append(frame)

                # Concatenate along the channel dimension
                # => [3 * NumFramesPerVid, resize_size, resize_size]
                stackedTensor = torch.cat(stackedImages, dim=0)

                self.imagesAndLabels.append({
                    "images": stackedTensor,
                    "label": className
                })

        if not self.imagesAndLabels:
            raise ValueError(
                f"No complete groups of {self.NumFramesPerVid} frames were found under {datasetPath!r}."
            )

        for item in self.imagesAndLabels:
            self.allImages.append(item["images"])
            self.allLabels.append(self.classNameToIndex[item["label"]])

        train_images, val_images, train_labels, val_labels = train_test_split(
            self.allImages,
            self.allLabels,
            test_size=0.2,
            random_state=42
        )

        self.train_images = torch.stack(train_images, dim=0)  # [N_train, C, H, W]
        self.val_images   = torch.stack(val_images,   dim=0)  # [N_val,   C, H, W]
        self.train_labels = torch.tensor(train_labels, dtype=torch.long)
        self.val_labels   = torch.tensor(val_labels,   dtype=torch.long)

        # Compute class weights.
        # The weight vector must have one entry per class, because the loss
        # indexes it by class id. Classes that did not land in the training
        # split get a neutral weight of 1.0 instead of being left out, which
        # would make the vector shorter than the model's output dimension.
        train_labels_np = self.train_labels.numpy()
        num_classes = len(self.uniqueClassNames)
        present_classes = np.unique(train_labels_np)
        present_weights = compute_class_weight(
            class_weight='balanced',
            classes=present_classes,
            y=train_labels_np
        )
        class_weights = np.ones(num_classes, dtype=np.float64)
        class_weights[present_classes] = present_weights

        # Number of training samples per class index (0 for absent classes)
        label_count = np.bincount(train_labels_np, minlength=num_classes)

        # Convert to a PyTorch tensor and move to device
        self.class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

        print("Class Weights (on device):", self.class_weights)
        print("Class:", label_count)

## Model Architecture

In [None]:
class VisualGeometryGroup(nn.Module):
    """VGG-style CNN operating on a stack of frames concatenated along the channel axis.

    The network returns raw, unnormalised class scores (logits). It is trained
    with ``nn.CrossEntropyLoss``, which applies log-softmax internally, so a
    Softmax layer must not be part of the model: applying softmax twice
    flattens the gradients and distorts the loss. Use
    ``torch.softmax(model(x), dim=1)`` when probabilities are needed;
    ``argmax`` over the logits yields the same predicted class.

    Args:
        output_dim: Number of classes (width of the final linear layer).
        resize_size: Side length of the square input images, at least 32.
            The five 2x2 max-pool layers reduce it to ``resize_size // 32``.
        in_channels: Number of input channels. Defaults to 30
            (10 frames x 3 RGB channels).

    Raises:
        ValueError: If ``resize_size`` is smaller than 32, which would leave an
            empty feature map after the pooling layers.
    """

    def __init__(self, output_dim, resize_size, in_channels=30):
        super(VisualGeometryGroup, self).__init__()

        # Each of the five max-pools halves (and floors) the spatial size.
        pooled_side = int(resize_size) // 32
        if pooled_side < 1:
            raise ValueError(f"resize_size must be at least 32, got {resize_size}")

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),  # Conv64
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),         # MaxPool

            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # Conv128
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),         # MaxPool

            nn.Conv2d(128, 256, kernel_size=3, padding=1), # Conv256
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1), # Conv256
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),         # MaxPool

            nn.Conv2d(256, 512, kernel_size=3, padding=1), # Conv512
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1), # Conv512
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),         # MaxPool

            nn.Conv2d(512, 512, kernel_size=3, padding=1), # Conv512
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1), # Conv512
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),         # MaxPool
        )

        # No Softmax at the end: the classifier emits logits (see class docstring).
        # Layer indices 0..6 are unchanged, so existing state_dict keys still match.
        self.classifier = nn.Sequential(
            nn.Linear(512 * pooled_side * pooled_side, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, output_dim),
        )

    def forward(self, x):
        """Return class logits of shape ``[batch, output_dim]`` for input ``[batch, in_channels, H, W]``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten everything except the batch dimension
        x = self.classifier(x)
        return x

## Training and Validation pipeline

In [None]:
# Hyperparameter values to sweep; every combination is trained once.
batch_sizes = [8, 16, 32]
resize_sizes = [128, 224, 256, 320]
learning_rates = [0.01, 0.001, 0.0001, 0.00001]
epochs = 50     # maximum number of epochs per configuration
patience = 10   # epochs without validation-accuracy improvement before stopping early

# Generate all combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(batch_sizes, resize_sizes, learning_rates))

# Dataset root (one sub-folder per class) and base directory for output folders.
# Both must be filled in before running the notebook.
datasetPath = r""
base_output_dir = r""

# Fail early with a clear message: an empty path would otherwise surface as a
# cryptic FileNotFoundError from os.makedirs("").
if not datasetPath or not base_output_dir:
    raise ValueError("Set `datasetPath` and `base_output_dir` before running the grid search.")
os.makedirs(base_output_dir, exist_ok=True)

# Global trackers for best result
maxValAcc = 0.0
maxValAccFolder = None

# resize_size of the dataset currently held in memory (0 = nothing loaded yet)
prevResize_Size = 0

# Iterate over hyperparameter combinations.
# For each combination: train with early stopping, restore the best weights,
# evaluate them on the validation set, and write a confusion-matrix image,
# a combined text report and the model weights into a numbered folder.
for index, (batch_size, resize_size, lr) in enumerate(hyperparameter_combinations):
    folder_name = f"{index + 1:03d}"  # e.g., '001', '002', etc.
    newFolderPath = os.path.join(base_output_dir, folder_name)
    os.makedirs(newFolderPath, exist_ok=True)

    # Per-epoch log lines for this configuration (written into the report)
    epochLogs = []

    print(f"Running combination {index + 1}: Batch Size={batch_size}, Resize={resize_size}, LR={lr}")

    # Initialise/re-initialise dataset only when resize_size changes
    if prevResize_Size != resize_size:
        dataset = processDataset(datasetPath, resize_size) # saves time by executing only when needed
    prevResize_Size = resize_size

    # Initialise model and lossFn, optimiser
    model = VisualGeometryGroup(output_dim=len(dataset.uniqueClassNames), resize_size=resize_size).to(device)
    lossFn = nn.CrossEntropyLoss(weight=dataset.class_weights)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Variables for early stopping
    bestValAccuracyForConfig = 0.0
    bestModelStateForConfig = copy.deepcopy(model.state_dict())
    epochsNoImprove = 0

    # Train and validate model
    for epoch in range(epochs):
        model.train()
        train_correct = 0
        train_total = 0
        train_loss_total = 0.0
        train_batches = 0  # actual number of batches (the last one may be partial)

        for i in range(0, len(dataset.train_images), batch_size):
            Xbatch = dataset.train_images[i : i + batch_size].to(device)
            ybatch = dataset.train_labels[i : i + batch_size].to(device)

            optimizer.zero_grad()           # Clear old gradients
            y_pred = model(Xbatch)          # Forward pass
            loss = lossFn(y_pred, ybatch)   # Compute loss
            loss.backward()                 # Backpropagation
            optimizer.step()                # Update parameters

            # Accumulate training accuracy / loss
            predictions = torch.argmax(y_pred, dim=1)
            train_correct += (predictions == ybatch).sum().item()
            train_total += ybatch.size(0)
            train_loss_total += loss.item()
            train_batches += 1

        # Training accuracy and mean per-batch loss for the epoch.
        # Dividing by the real batch count avoids the bias that
        # len(data) / batch_size introduces when the last batch is partial.
        train_accuracy = 100 * train_correct / train_total
        train_loss_avg = train_loss_total / train_batches

        # Validate
        model.eval()
        val_correct = 0
        numValSamples = 0
        val_loss_total = 0.0  # loss summed over batches
        val_batches = 0

        with torch.no_grad():
            for j in range(0, len(dataset.val_images), batch_size):
                Xval = dataset.val_images[j : j + batch_size].to(device)
                yval = dataset.val_labels[j : j + batch_size].to(device)

                y_pred_val = model(Xval)

                # Accumulate validation loss & accuracy
                val_loss_total += lossFn(y_pred_val, yval).item()
                val_predictions = torch.argmax(y_pred_val, dim=1)
                val_correct += (val_predictions == yval).sum().item()
                numValSamples += yval.size(0)
                val_batches += 1

        val_accuracy = 100 * val_correct / numValSamples
        val_loss_avg = val_loss_total / val_batches

        epoch_line = (f"Epoch {epoch+1}/{epochs}, "
                      f"Train Acc: {train_accuracy:.4f}%, Train loss: {train_loss_avg:.4f}, "
                      f"Val Acc: {val_accuracy:.4f}%, Val loss: {val_loss_avg:.4f}")
        print(epoch_line)
        epochLogs.append(epoch_line)

        # Check if this is the best validation accuracy so far for curr config
        if val_accuracy > bestValAccuracyForConfig:
            bestValAccuracyForConfig = val_accuracy
            bestModelStateForConfig = copy.deepcopy(model.state_dict())
            epochsNoImprove = 0
        else:
            epochsNoImprove += 1

        # Early stopping check
        if epochsNoImprove >= patience:
            print(f"Early stopping triggered at epoch {epoch+1} (no improvement in {patience} consecutive epochs).")
            break

    # Inference: restore the best weights seen for this configuration
    model.load_state_dict(bestModelStateForConfig)

    # Evaluate on the validation set using the best model. Loss and accuracy
    # are recomputed here so that the report describes the weights that are
    # actually saved, not those of the last training epoch.
    model.eval()
    valPredsList = []  # predictions over multiple batches
    valLabelsList = []
    bestValLossTotal = 0.0
    bestValBatches = 0

    with torch.no_grad():
        for j in range(0, len(dataset.val_images), batch_size):
            Xval = dataset.val_images[j : j + batch_size].to(device)
            yval = dataset.val_labels[j : j + batch_size].to(device)

            y_pred_val = model(Xval)
            val_predictions = torch.argmax(y_pred_val, dim=1)

            bestValLossTotal += lossFn(y_pred_val, yval).item()
            bestValBatches += 1

            valPredsList.extend(val_predictions.cpu().numpy())
            valLabelsList.extend(yval.cpu().numpy())

    # Confusion matrix & classification report
    valPredsArray  = np.array(valPredsList)
    valLabelsArray = np.array(valLabelsList)

    bestValLossAvg = bestValLossTotal / bestValBatches
    bestValAccuracy = 100 * float((valPredsArray == valLabelsArray).mean())

    # Identify which classes actually appear in validation.
    # NOTE: predictions for classes that never occur in the validation labels
    # are not counted in the confusion matrix built below.
    valClassesInUse = np.unique(valLabelsArray)

    # Map numeric labels (0,1,2,...) to string names
    matchedValLabelsNameArray = [dataset.IndexToClassName[i] for i in valClassesInUse]

    # Order classes alphabetically by name; matrix rows/columns follow this order
    sorted_indices = np.argsort(matchedValLabelsNameArray)
    valClassNamesSorted = [matchedValLabelsNameArray[idx] for idx in sorted_indices]
    valClassesInUse_sorted = valClassesInUse[sorted_indices]

    valConfMatrix = confusion_matrix(
        valLabelsArray,
        valPredsArray,
        labels=valClassesInUse_sorted
    )

    # Plot confusion matrix on an explicitly created axes so that exactly one
    # figure is opened, saved and closed per configuration.
    fig, ax = plt.subplots(figsize=(8, 8))  # Increase figure size for large # of classes
    disp = ConfusionMatrixDisplay(
        confusion_matrix=valConfMatrix,
        display_labels=valClassNamesSorted
    )
    disp.plot(include_values=True,
              cmap=plt.cm.Blues,
              xticks_rotation='vertical',
              ax=ax
    )
    confusionMatrixPath = os.path.join(newFolderPath, "confusion_matrix.png")
    fig.savefig(confusionMatrixPath, bbox_inches='tight')
    plt.close(fig)

    # Compute Per-Class TP, FP, FN, TN.
    # Iterate in the same (name-sorted) order that was used to build the
    # matrix so that row/column i really belongs to the reported class.
    totalSamples = valConfMatrix.sum()
    tp_fp_fn_tn_lines = ["Per-Class TP/FP/FN/TN:"]

    for i, class_label in enumerate(valClassesInUse_sorted):
        TP = valConfMatrix[i, i]
        FP = valConfMatrix[:, i].sum() - TP
        FN = valConfMatrix[i, :].sum() - TP
        TN = totalSamples - (TP + FP + FN)

        line = (f"Class {class_label} --> "
                f"TP: {TP}, FP: {FP}, FN: {FN}, TN: {TN}")
        tp_fp_fn_tn_lines.append(line)

    tp_fp_fn_tn_report = "\n".join(tp_fp_fn_tn_lines)

    # Update global best if improved
    if bestValAccuracyForConfig > maxValAcc:
        maxValAcc = bestValAccuracyForConfig
        maxValAccFolder = index + 1 # +1 so it matches folder numbering

    # Combine everything into a SINGLE text file
    reportText = []
    reportText.append("Hyperparameters and Results\n" + "="*40)
    reportText.append(f"Batch Size: {batch_size}")
    reportText.append(f"Image Resize: {resize_size}")
    reportText.append(f"Learning Rate: {lr}")
    reportText.append(f"Validation Loss: {bestValLossAvg:.4f}")
    reportText.append(f"Validation Accuracy: {bestValAccuracy:.2f}%\n")

    # Add epochLogs
    reportText.append("Epoch Logs:")
    reportText.extend(epochLogs)

    reportText.append(f"\nGlobal Best Accuracy So Far: {maxValAcc:.4f}%")
    reportText.append(f"Folder with Global Best Accuracy So Far: {maxValAccFolder}\n")

    # Confusion matrix (text-version)
    reportText.append("Confusion Matrix (text version):")

    maxLabelLength = max(len(lbl) for lbl in valClassNamesSorted)
    # Columns must be wide enough for both the counts and the class names,
    # otherwise header and rows drift apart.
    colWidth = max(5, maxLabelLength)

    # Header row: blank padding as wide as the row-label column, then class names
    header = " " * (maxLabelLength + 1)
    for label in valClassNamesSorted:
        header += f"{label:>{colWidth}} "
    reportText.append(header)

    # One row per true class, in the same sorted order as the matrix
    for row_label, row_values in zip(valClassNamesSorted, valConfMatrix):
        row_str = f"{row_label:<{maxLabelLength}} "  # Left-align class label
        for val in row_values:
            row_str += f"{val:>{colWidth}} "  # Right-align each value
        reportText.append(row_str)
    reportText.append("")  # Blank line for readability

    # Add the classification report (keyword is `target_names`)
    reportText.append("Classification Report:")
    reportText.append(classification_report(
        valLabelsArray,
        valPredsArray,
        labels=valClassesInUse_sorted,
        target_names=valClassNamesSorted,
        zero_division=0
    ))

    # Per-class TP/FP/FN/TN
    reportText.append(tp_fp_fn_tn_report)

    finalReport = "\n".join(reportText)

    # Save to a text file
    combinedReportPath = os.path.join(newFolderPath, "combined_report.txt")
    with open(combinedReportPath, "w") as f:
        f.write(finalReport)

    # Save the model (best weights for this configuration are currently loaded)
    modelPath = os.path.join(newFolderPath, "model.pth")
    torch.save(model.state_dict(), modelPath)

print("Grid search complete!")
print(f"Best accuracy overall: {maxValAcc:.2f}%")
print(f"Best folder overall: {maxValAccFolder}")