In [None]:
import os
import copy
import torch
import itertools
import numpy as np
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, random_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

In [None]:
# Pick the compute backend: CUDA first, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

## Data processing

In [None]:
class processDataset():
    """Load an ImageFolder dataset and expose stratified train/validation loaders.

    Attributes:
        fullDataset: The `datasets.ImageFolder` covering every image under
            `datasetPath`, with resize / to-tensor / normalise transforms applied.
        class_weights: Float32 tensor (moved to the module-level `device`) of
            "balanced" class weights computed from the training labels.
        train_loader: Shuffled `DataLoader` over the 80% training subset.
        val_loader: Unshuffled `DataLoader` over the 20% validation subset.
    """

    def __init__(self, datasetPath, resize_size, batch_size=10):
        """Build the dataset, the stratified split, class weights and loaders.

        Args:
            datasetPath: Root directory in ImageFolder layout (one sub-folder per class).
            resize_size: Side length every image is resized to (square).
            batch_size: Batch size used by both loaders. Defaults to 10, the
                value that was previously hard-coded.
        """
        # Data augmentation and normalisation
        transform = transforms.Compose([
            transforms.Resize((resize_size, resize_size)),                      # Resizing for VGG input
            transforms.ToTensor(),                                              # Convert to tensor
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # Normalisation (standard ImageNet channel mean/std)
        ])

        # Use ImageFolder to load dataset:
        self.fullDataset = datasets.ImageFolder(root=datasetPath, transform=transform)

        # Read labels from ImageFolder's `targets` list. Indexing the dataset
        # instead would decode and transform every image just to obtain an int.
        labels = list(self.fullDataset.targets)

        # Perform stratified split (fixed seed so the split is reproducible)
        train_indices, val_indices = train_test_split(
            range(len(self.fullDataset)), test_size=0.2, stratify=labels, random_state=42
        )

        # Create subsets
        train_dataset = torch.utils.data.Subset(self.fullDataset, train_indices)
        val_dataset = torch.utils.data.Subset(self.fullDataset, val_indices)

        # Compute class weights using scikit-learn, from the training labels only
        train_labels = [labels[idx] for idx in train_indices]
        self.class_weights = compute_class_weight(
            class_weight="balanced",
            classes=np.unique(train_labels),
            y=train_labels
        )

        self.class_weights = torch.tensor(self.class_weights, dtype=torch.float32).to(device)
        print("Class weights:", self.class_weights)

        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Root of the ImageFolder-style dataset; left blank here and must be set before running.
datasetPath = r""
# Eagerly build the dataset at 224x224. The grid-search cell rebuilds `dataset`
# whenever the resize size changes.
dataset = processDataset(datasetPath=datasetPath, resize_size=224)

## Model architecture

In [None]:
class VisualGeometryGroup(nn.Module):
    """VGG-style CNN: five conv stages, each closed by 2x2 max-pooling, plus a 3-layer MLP head.

    Args:
        output_dim: Number of logits produced by the final linear layer.
        resize_size: Side length of the (square) input images; it determines
            the input width of the first fully connected layer.
    """

    # Feature-extractor layout: an int adds a 3x3 conv (padding 1) with that
    # many output channels followed by an in-place ReLU; "M" adds a 2x2,
    # stride-2 max-pool.
    _LAYOUT = (64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M")

    def __init__(self, output_dim, resize_size):
        super().__init__()

        stages = []
        in_channels = 3  # RGB input
        for spec in self._LAYOUT:
            if spec == "M":
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stages.append(nn.Conv2d(in_channels, spec, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
            in_channels = spec
        self.features = nn.Sequential(*stages)

        # Five stride-2 pools shrink each spatial side by a factor of 32.
        side = int(resize_size // 32)
        hidden_units = 4096
        self.classifier = nn.Sequential(
            nn.Linear(in_channels * side * side, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden_units, output_dim),
        )

    def forward(self, x):
        """Return class logits for a batch of images `x`."""
        feature_maps = self.features(x)
        flattened = feature_maps.view(feature_maps.size(0), -1)  # one row per sample
        return self.classifier(flattened)

## Training and validation pipeline

In [None]:
# Hyper-parameters
batch_sizes = [8, 16, 32]
resize_sizes = [128, 224, 256, 320]
learning_rates = [0.01, 0.001, 0.0001, 0.00001]
epochs = 20    # maximum number of epochs per configuration
patience = 3   # epochs without validation-accuracy improvement before early stopping

# Generate all combinations (batch size varies slowest, learning rate fastest)
hyperparameter_combinations = list(itertools.product(batch_sizes, resize_sizes, learning_rates))

# Base directory where output folders will be created
datasetPath = r""
base_output_dir = r""
# os.makedirs("") raises FileNotFoundError. An empty base directory already
# resolves to the current working directory when joined with a run folder
# name, so only create it when a path was actually given.
if base_output_dir:
    os.makedirs(base_output_dir, exist_ok=True)

# Global trackers for the best result
maxValAcc = 0.0
maxValAccFolder = None
# 0 never equals a real resize size, so the first configuration always builds the dataset.
prevResize_Size = 0

# Iterate over hyper-parameter combinations.
# For each (batch size, resize size, learning rate) triple: train with early
# stopping on validation accuracy, restore the best weights, then write a
# confusion-matrix image, a combined text report and the model weights into a
# numbered folder under base_output_dir.
for index, (batch_size, resize_size, lr) in enumerate(hyperparameter_combinations):
    folderName = f"{index + 1:03d}"  # e.g., '001', '002', etc.
    newFolderPath = os.path.join(base_output_dir, folderName)
    os.makedirs(newFolderPath, exist_ok=True)

    # Logging
    epochLogs = []
    reportText = []
    reportText.append("Hyperparameters and Results\n" + "="*40)
    reportText.append(f"Batch Size: {batch_size}")
    reportText.append(f"Learning Rate: {lr}")
    reportText.append(f"Image resize: {resize_size}\n")

    print(f"Running combination {index + 1}: Batch Size={batch_size}, LR={lr}, resize={resize_size}")

    # Re-initialise dataset only when resize_size changes (saves time by
    # rebuilding the split and class weights only when needed)
    if prevResize_Size != resize_size:
        dataset = processDataset(datasetPath, resize_size)
    prevResize_Size = resize_size

    # The loaders stored on `dataset` use a fixed batch size, so wrap the same
    # train/validation subsets in loaders that honour this combination's
    # batch_size. Otherwise the batch-size axis of the grid has no effect.
    train_loader = DataLoader(dataset.train_loader.dataset, batch_size=batch_size,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(dataset.val_loader.dataset, batch_size=batch_size,
                            shuffle=False, num_workers=2)

    # Initialise model, lossFn and optimiser. The number of outputs follows the
    # dataset's class list so it always matches the class-weight vector.
    num_classes = len(dataset.fullDataset.classes)
    model = VisualGeometryGroup(output_dim=num_classes, resize_size=resize_size).to(device)
    lossFn = nn.CrossEntropyLoss(weight=dataset.class_weights)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Variables for early stopping
    bestValAccuracyForConfig = 0.0
    bestModelStateForConfig = copy.deepcopy(model.state_dict())
    epochsNoImprove = 0

    # Train and validate model
    for epoch in range(epochs):
        model.train()
        train_correct = 0
        train_total = 0
        train_loss_total = 0.0

        for Xbatch, ybatch in train_loader:
            Xbatch, ybatch = Xbatch.to(device), ybatch.to(device)

            optimizer.zero_grad()
            y_pred = model(Xbatch)
            loss = lossFn(y_pred, ybatch)
            loss.backward()
            optimizer.step()

            # Calculate training accuracy
            predictions = torch.argmax(y_pred, dim=1)
            train_correct += (predictions == ybatch).sum().item()
            train_total += ybatch.size(0)
            train_loss_total += loss.item()

        # Training accuracy (percent) and mean per-batch loss for the epoch
        train_accuracy = (train_correct / train_total) * 100.0
        train_loss_avg = train_loss_total / len(train_loader)

        # Validate
        model.eval()
        val_correct = 0
        numValSamples = 0
        val_loss_total = 0.0

        with torch.no_grad():
            for Xval, yval in val_loader:
                Xval, yval = Xval.to(device), yval.to(device)

                y_pred_val = model(Xval)

                # Calculate validation loss & accuracy
                val_loss_total += lossFn(y_pred_val, yval).item()
                val_predictions = torch.argmax(y_pred_val, dim=1)
                val_correct += (val_predictions == yval).sum().item()
                numValSamples += yval.size(0)

        val_accuracy = 100.0 * val_correct / numValSamples
        val_loss_avg = val_loss_total / len(val_loader)

        # The same line is shown on screen and stored for the report
        epoch_line = (f"Epoch {epoch+1}/{epochs}, "
                      f"Train Acc: {train_accuracy:.4f}%, " f"Train loss: {train_loss_avg:.4f}, "
                      f"Val Acc: {val_accuracy:.4f}%, " f"Val loss: {val_loss_avg:.4f}")
        print(epoch_line)
        epochLogs.append(epoch_line)

        # Check if this is the best validation accuracy so far for curr config
        if val_accuracy > bestValAccuracyForConfig:
            bestValAccuracyForConfig = val_accuracy
            bestModelStateForConfig = copy.deepcopy(model.state_dict())
            epochsNoImprove = 0
        else:
            epochsNoImprove += 1

        # Early stopping check
        if epochsNoImprove >= patience:
            print(f"Early stopping triggered at epoch {epoch+1} (no improvement in {patience} consecutive epochs).")
            break

    # Inference: restore the weights from the best validation epoch
    model.load_state_dict(bestModelStateForConfig)

    # Evaluate on the validation set using the best model
    model.eval()
    valPredsList = []
    valLabelsList = []

    with torch.no_grad():
        for Xval, yval in val_loader:
            Xval, yval = Xval.to(device), yval.to(device)

            y_pred_val = model(Xval)
            val_predictions = torch.argmax(y_pred_val, dim=1)

            valPredsList.extend(val_predictions.cpu().numpy())
            valLabelsList.extend(yval.cpu().numpy())

    # Confusion matrix & classification report
    valPredsArray = np.array(valPredsList)
    valLabelsArray = np.array(valLabelsList)

    # Identify which classes actually appear in validation
    valClassesInUse = np.unique(valLabelsArray)

    # Map numeric labels (0,1,2,...) to string names. After this,
    # matchedValLabelsNameArray is aligned position-by-position with
    # valClassesInUse (and therefore with the confusion-matrix rows/columns).
    valLabelsNameArray = dataset.fullDataset.classes
    matchedValLabelsNameArray = [valLabelsNameArray[i] for i in valClassesInUse]

    valConfMatrix = confusion_matrix(
        valLabelsArray,
        valPredsArray,
        labels=valClassesInUse
    )

    # Plot confusion matrix. Draw onto an explicitly created axes so the
    # figure size is applied and exactly one figure is opened and closed per
    # configuration.
    fig, ax = plt.subplots(figsize=(8, 8))
    disp = ConfusionMatrixDisplay(
        confusion_matrix=valConfMatrix,
        display_labels=matchedValLabelsNameArray
    )
    disp.plot(ax=ax, cmap=plt.cm.Blues, xticks_rotation='vertical')
    confusionMatrixPath = os.path.join(newFolderPath, "confusion_matrix.png")
    fig.savefig(confusionMatrixPath, bbox_inches='tight')
    plt.close(fig)

    # Classification report. The name list is already filtered and ordered
    # like valClassesInUse, so it is passed as-is; indexing it again by class
    # id would pick wrong names (or raise) when ids are not 0..k-1.
    class_report = classification_report(
        valLabelsArray,
        valPredsArray,
        labels=valClassesInUse,
        target_names=matchedValLabelsNameArray,
        zero_division=0
    )

    # Compute per-class TP/FP/FN/TN from the confusion matrix
    # (rows = true class, columns = predicted class)
    total_samples = valConfMatrix.sum()
    tp_fp_fn_tn_lines = ["Per-Class TP/FP/FN/TN:"]

    for i, class_label in enumerate(valClassesInUse):
        TP = valConfMatrix[i, i]
        FP = valConfMatrix[:, i].sum() - TP
        FN = valConfMatrix[i, :].sum() - TP
        TN = total_samples - (TP + FP + FN)

        line = (f"Class {class_label} --> "
                f"TP: {TP}, FP: {FP}, FN: {FN}, TN: {TN}")
        tp_fp_fn_tn_lines.append(line)

    tp_fp_fn_tn_report = "\n".join(tp_fp_fn_tn_lines)

    # Update global best if improved
    if bestValAccuracyForConfig > maxValAcc:
        maxValAcc = bestValAccuracyForConfig
        maxValAccFolder = index + 1  # +1 so it matches folder numbering

    # Combine everything into a SINGLE text file
    reportText.append("Epoch Logs:")
    reportText.extend(epochLogs)
    reportText.append("\n")

    reportText.append(f"Global Best Accuracy So Far: {maxValAcc:.4f}%")
    reportText.append(f"Folder with Global Best Accuracy So Far: {maxValAccFolder}\n")

    # Confusion matrix as text: one row per true class, label left-aligned,
    # counts right-aligned in 5-character columns
    maxLabelLength = max(len(lbl) for lbl in matchedValLabelsNameArray)
    for row_label, row_values in zip(matchedValLabelsNameArray, valConfMatrix):
        row_str = f"{row_label:<{maxLabelLength}} "
        for val in row_values:
            row_str += f"{val:>5} "
        reportText.append(row_str)

    reportText.append("")  # blank line

    # Classification report
    reportText.append("Classification Report:")
    reportText.append(class_report)

    # Per-class TP/FP/FN/TN
    reportText.append(tp_fp_fn_tn_report)

    final_report = "\n".join(reportText)

    # Save to a text file (explicit encoding so non-ASCII class names are
    # written the same way on every platform)
    combined_report_path = os.path.join(newFolderPath, "combinedReport.txt")
    with open(combined_report_path, "w", encoding="utf-8") as f:
        f.write(final_report)

    # Save the best weights for this configuration
    model_path = os.path.join(newFolderPath, "model.pth")
    torch.save(bestModelStateForConfig, model_path)

# Final summary of the whole grid search
for summary_line in (
    "Grid search complete!",
    f"Best accuracy overall: {maxValAcc:.4f}%",
    f"Best folder overall: {maxValAccFolder}",
):
    print(summary_line)