# Imports

This notebook compares and evaluates the "best model" we obtained for each of the DANN, CORAL and BASE models (BASE being VGG-16 trained on the generated data). All models are trained and tuned in their corresponding files (in the Models folder). This notebook evaluates them and compares the best models obtained after hyperparameter tuning using multiple metrics: testing accuracy, f1-score, confusion matrix and average AUPRC.

In [5]:
import torch 
import os
import torchmetrics
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torchvision
from PIL import Image
from sklearn.metrics import f1_score, confusion_matrix, precision_recall_curve, auc

# If VSCode doesn't pick up this import, see answer here: 
# https://stackoverflow.com/questions/65252074/import-path-to-own-script-could-not-be-resolved-pylance-reportmissingimports
import sys
sys.path.append("../../Datasets/")
from Custom_Dataset import * 

We define the device to use

In [6]:
# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    # `torch.has_mps` is deprecated; this is the supported availability check.
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Release cached CUDA memory (e.g. left over from an earlier run in this kernel).
# Compare the device *type*: a torch.device never compares equal to the string "cuda".
if DEVICE.type == "cuda":
    torch.cuda.empty_cache()

print("Device:", DEVICE)

Device: mps


# Plot Training and Validation Curves

We plot all the metrics collected during the training of the models. These metrics were stored in the corresponding JSON files. We will only plot the metrics for the best models.

In [8]:
# Import the JSON files containing the results


### DANN

### CORAL

### BASE

# Model Classes

We now define the classes for each model. The base model and CORAL model have one class, while the DANN has two classes: one for the feature extractor and one for the classifier.

### DANN

In [9]:
class Classifier(nn.Module):
    """DANN label classifier: an MLP head mapping flattened features to class scores.

    Args:
        input_size: Length of each flattened feature vector fed to the head.
        num_classes: Number of output scores (one per class).
    """

    def __init__(self, input_size=4608, num_classes=13):
        super().__init__()
        hidden = (1024, 512)
        # The attribute name ``layer`` and the layer order define the state_dict
        # keys ("layer.0.weight", ...) that load_state_dict matches against.
        self.layer = nn.Sequential(
            nn.Linear(input_size, hidden[0]),
            nn.ReLU(inplace=True),
            nn.Linear(hidden[0], hidden[1]),
            nn.ReLU(inplace=True),
            nn.Linear(hidden[1], num_classes),
        )

    def forward(self, h):
        """Return the unnormalised class scores for the feature batch ``h``."""
        return self.layer(h)
    
class FeatureExtractor(nn.Module):
    """DANN feature extractor: the convolutional trunk of VGG-16.

    The trunk is initialised with ImageNet weights and its output feature map
    is flattened to one vector per image.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``conv`` defines the state_dict keys ("conv.0.weight", ...).
        self.conv = torchvision.models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1').features
        # Freeze the parameters of modules 0-24 (slice [:25]); modules from
        # index 25 onward stay trainable.
        self.conv[:25].requires_grad_(False)

    def forward(self, x):
        """Return the flattened VGG-16 feature map of the image batch ``x``."""
        return torch.flatten(self.conv(x), start_dim=1)

### CORAL

In [10]:
class CoralModel(nn.Module):
    """CORAL model: VGG-16 convolutional trunk followed by an MLP classification head.

    ``forward`` returns both the flattened features and the class scores.

    Args:
        num_classes: Number of output scores (one per class).
    """

    def __init__(self, num_classes=13):
        super().__init__()
        # Attribute names ``features``/``classifier`` and the layer order define
        # the state_dict keys matched when loading the checkpoint.
        self.features = torchvision.models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1').features
        # 4608 presumably = 512 channels x 3 x 3 spatial (i.e. 96x96 inputs) — TODO confirm
        head_layers = [
            nn.Linear(4608, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # Freeze the parameters of trunk modules 0-24 (slice [:25]); modules from
        # index 25 onward stay trainable.
        self.features[:25].requires_grad_(False)

    def forward(self, x):
        """Return ``(flattened_features, class_scores)`` for the image batch ``x``."""
        flat = torch.flatten(self.features(x), 1)
        return flat, self.classifier(flat)

### BASE

In [11]:
# Define the class for the Base Model
class BaseModel(nn.Module):
    """Base model: fully frozen VGG-16 convolutional trunk plus a trainable MLP head with dropout.

    Args:
        num_classes: Number of output scores (one per class).
        dropout_rate: Drop probability of the two dropout layers in the head.
    """

    def __init__(self, num_classes=13, dropout_rate=0.5):
        super().__init__()
        # Attribute names ``features``/``classifier`` and the layer order define
        # the state_dict keys matched when loading the checkpoint.
        self.features = torchvision.models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1').features
        head_layers = [
            nn.Linear(4608, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # The whole convolutional trunk is frozen; only the head is trainable.
        self.features.requires_grad_(False)

    def forward(self, x):
        """Return the unnormalised class scores for the image batch ``x``."""
        flat = torch.flatten(self.features(x), 1)
        return self.classifier(flat)


We now load the models we saved as the best ones for each model type (DANN, CORAL, and Base Model).

In [13]:
DANN_path = "../../Models/DANN/"
DANN_C_path = DANN_path + "best_model_C.ckpt"
DANN_F_path = DANN_path + "best_model_F.ckpt"
CORAL_path = "../../Models/CORAL/best_CORAL_model.ckpt"
BASE_path = "../../Models/BASE/TRAIN_GEN/best_BASE_TRAIN_GEN_model.ckpt"


def load_for_eval(model, checkpoint_path):
    """Load the state dict stored at ``checkpoint_path`` into ``model`` and switch it to eval mode.

    ``torch.load`` unpickles the file, so only point it at checkpoints you trust.

    Returns:
        The same ``model`` object, so construction and loading can be chained.
    """
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# DANN is split in two stages: the feature extractor's output feeds the classifier.
DANN_F = load_for_eval(FeatureExtractor().to(DEVICE), DANN_F_path)
DANN_C = load_for_eval(Classifier().to(DEVICE), DANN_C_path)
coral_model = load_for_eval(CoralModel().to(DEVICE), CORAL_path)
base_model = load_for_eval(BaseModel().to(DEVICE), BASE_path)




# Metric Computation

### Plotting and metric collection functions

We will compare the performance of the best models using the testing accuracy, f1-score, confusion matrix and average AUPRC metrics. 

* Specifically, for the testing accuracy, we will use the balanced accuracy score. Instead of counting correct predictions over the whole test set, it averages the recall (the fraction of correctly classified samples) of each class, so rare classes weigh as much as frequent ones. The balanced accuracy score will be computed on both the real-life dataset and the generated dataset; however, our ultimate goal is to predict labels for real-life images.

* As a reminder, the f1-score is a performance metric that provides a balance between the precision and recall of the model. The f1-score ranges from 0 to 1, with a higher score indicating a better performance of the model. The equation for f1-score is: $$ F_1 = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}} $$ where $\text{precision}$ is the number of true positives (correctly predicted positive samples) divided by the total number of predicted positive samples, and $\text{recall}$ is the number of true positives divided by the total number of actual positive samples. In other words, f1-score takes into account both the model's ability to correctly identify positive samples (recall) and its tendency to not mislabel negative samples as positive (precision). Since this is a multiclass problem, the f1-score is computed for each class and then averaged, weighting each class by its number of samples.

* The confusion matrix compares the predicted class labels with the true class labels. Entry $(i, j)$ counts the samples of true class $i$ that were predicted as class $j$, so the diagonal holds the correct predictions and the off-diagonal entries hold the errors. We normalize each row by the number of samples of its true class, so the entries read as percentages. It can help identify which classes the model has difficulty classifying correctly, and which classes it can predict with ease.

* Finally, the AUPRC measures the overall quality of the model's predictions, taking into account both precision and recall across all possible classification thresholds. A high average AUPRC indicates that the model has a good balance of precision and recall across all possible classification thresholds, and is able to accurately distinguish between positive and negative cases. Conversely, a low average AUPRC indicates poor performance, and suggests that the model is making many incorrect predictions or missing many true positive cases. In this multiclass setting, the AUPRC is computed for each class in a one-vs-rest fashion, and the per-class values are then averaged to give the average AUPRC.








We first define the balanced accuracy function as follows:

In [14]:
# Define the balanced accuracy metric: the mean of the per-class recalls,
# which is torchmetrics' "macro" average. ("weighted" weights each class's
# recall by its support, which collapses back to plain accuracy.)
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=13, average="macro").to(DEVICE)

To compare the final testing accuracies of each model (DANN, CORAL, and Base model), we evaluate their performance on both the real-life dataset and the generated dataset. Additionally, we analyze the metrics with and without oversampling. For this purpose, we use a function that takes the model(s), batch size, dataset and an oversampling flag as input, and computes the testing accuracy, f1-score, confusion matrix and average AUPRC.

It's important to note that the most important metrics for the comparison are the ones computed on the real-life dataset without oversampling (i.e., with the natural class distribution of chess boards).

In [None]:
# Define the testing accuracy evaluation function
def evaluate(model, batch_size, dataset, model2=None, oversampling=True, coral=False, num_classes=13):
    """Evaluate a trained model on the test split of ``dataset``.

    Parameters
    ----------
    model : nn.Module
        Model applied to the images first. For DANN this is the feature extractor.
    batch_size : int
        Batch size of the test loader.
    dataset : str
        Dataset name forwarded to ``CustomDataset`` (e.g. "Real Life" or "Generated").
    model2 : nn.Module, optional
        Second stage applied to the output of ``model``. For DANN this is the classifier.
    oversampling : bool
        Forwarded to ``CustomDataset`` as ``balance``.
    coral : bool
        Set for ``CoralModel``, whose forward returns ``(features, scores)``.
    num_classes : int
        Number of classes. It fixes the confusion-matrix size and the classes used for AUPRC.

    Returns
    -------
    tuple
        ``(accuracy, f1, confusion_matrix, average_auprc)``:

        - accuracy: 0-d tensor from the module-level ``accuracy`` metric,
          accumulated over the whole test set.
        - f1: support-weighted F1 score.
        - confusion_matrix: ``(num_classes, num_classes)`` array, normalised over the
          true labels (rows).
        - average_auprc: mean one-vs-rest area under the precision-recall curve,
          over the classes present in the test set (NaN if none).

    Raises
    ------
    ValueError
        If the test loader yields no batches.
    """
    # We use our custom dataset and loader defined in the "Datasets" folder.
    test_dataset = CustomDataset(dataset, "test", balance=oversampling)
    loader = get_loader(test_dataset, batch_size=batch_size)

    # Put every stage in evaluation mode.
    model.eval()
    if model2 is not None:
        model2.eval()

    # The metric object keeps state between calls. Start clean and accumulate
    # over the whole test set: a mean of per-batch values is biased by a smaller
    # last batch and is not the dataset-level value for non-micro averages.
    accuracy.reset()

    y_true, y_hat, y_score = [], [], []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            output = model(images)
            if coral:
                # CoralModel.forward returns (features, scores)
                _, output = output
            if model2 is not None:
                # DANN: the feature extractor's output goes through the classifier
                output = model2(output)

            y_pred = torch.argmax(output, dim=1)
            accuracy.update(y_pred, labels)

            y_true.append(labels.cpu().numpy())
            y_hat.append(y_pred.cpu().numpy())
            # Per-class probabilities: the PR curves need scores to sweep
            # thresholds over, not hard 0/1 predictions.
            y_score.append(torch.softmax(output, dim=1).cpu().numpy())

    if not y_true:
        raise ValueError(f"The {dataset!r} test set is empty; no metrics can be computed.")

    final_accuracy = accuracy.compute()

    y_true = np.concatenate(y_true)
    y_hat = np.concatenate(y_hat)
    y_score = np.concatenate(y_score)

    # Compute the f1 score
    f1 = f1_score(y_true, y_hat, average='weighted')

    # Pin the label set so the matrix is always num_classes x num_classes,
    # even when a class never occurs in the labels or the predictions.
    cm = confusion_matrix(y_true, y_hat, labels=list(range(num_classes)), normalize='true')

    # One-vs-rest AUPRC per class. A PR curve is undefined for a class with no
    # positive sample, so such classes are skipped.
    auprcs = []
    for c in range(num_classes):
        is_c = y_true == c
        if not is_c.any():
            continue
        precision, recall, _ = precision_recall_curve(is_c, y_score[:, c])
        auprcs.append(auc(recall, precision))
    average_auprc = float(np.mean(auprcs)) if auprcs else float("nan")

    # Return all metrics computed
    return final_accuracy, f1, cm, average_auprc


We define a last function, `get_metrics`, which computes all the metrics defined above for every model on a given dataset, with or without oversampling.

In [None]:
# Batch size shared by every evaluation below.
batch_size: int = 32

In [None]:
def get_metrics(dataset, oversampling, models=None):
    """Evaluate every model and collect testing accuracy, f1 score, confusion matrix and average AUPRC.

    Parameters
    ----------
    dataset : str
        Dataset to evaluate on ("Real Life" or "Generated"); forwarded to ``evaluate``.
    oversampling : bool
        Whether the test data is oversampled; forwarded to ``evaluate``.
    models : dict, optional
        Ordered mapping ``name -> (model, second_stage, coral)``, where the
        tuple holds the arguments passed to ``evaluate``. Defaults to DANN,
        CORAL and BASE, in that order.

    Returns
    -------
    tuple
        ``(accuracies, f1s, confusion_matrices, auprcs)``, in the same order as
        ``evaluate``'s results.

        - ``accuracies``, ``f1s`` and ``auprcs`` are dicts keyed by model name,
          holding plain floats so they can be dumped to JSON.
        - ``confusion_matrices`` is an array of shape ``(len(models), C, C)``,
          stacked in the same order as the dict keys.

    Raises
    ------
    ValueError
        If ``models`` is empty.
    """
    if models is None:
        models = {
            "DANN": (DANN_F, DANN_C, False),
            "CORAL": (coral_model, None, True),
            "BASE": (base_model, None, False),
        }
    if not models:
        raise ValueError("get_metrics needs at least one model to evaluate.")

    accuracies, f1s, auprcs = {}, {}, {}
    # Confusion matrices are not JSON material; they are kept as arrays and
    # saved as heatmaps with matplotlib.
    matrices = []

    for name, (model, second_stage, is_coral) in models.items():
        acc, f1, cm, auprc = evaluate(model, batch_size, dataset, model2=second_stage,
                                      oversampling=oversampling, coral=is_coral)
        # float() accepts 0-d tensors, numpy scalars and Python floats alike.
        accuracies[name] = float(acc)
        f1s[name] = float(f1)
        auprcs[name] = float(auprc)
        matrices.append(cm)

    return accuracies, f1s, np.stack(matrices), auprcs

We also specify a function to save the confusion matrices obtained. This function takes as input the matrix to save, the model and dataset it corresponds to and if the data was oversampled or not.

In [None]:
def save_confusion_matrix(confusion_matrix, model, dataset, oversampling):
    """Render a confusion matrix as an annotated heatmap and save it as a PNG.

    Parameters
    ----------
    confusion_matrix : np.ndarray
        13 x 13 matrix of fractions (presumably row-normalised over the true
        labels); each cell is annotated as a percentage.
    model : str
        Model name. It is used in the title and as the output folder ``../{model}/``.
    dataset : str
        Dataset name. It is used in the title and the file name.
    oversampling : bool
        Whether the evaluation data was oversampled. It selects the title
        wording and the file-name suffix.
    """
    class_names = ['Empty Square', 'White Pawn', 'White Knight', 'White Bishop', 'White Rook', 'White Queen', 'White King',
                   'Black Pawn', 'Black Knight', 'Black Bishop', 'Black Rook', 'Black Queen', 'Black King']
    n = len(class_names)
    ticks = np.arange(n)

    fig, ax = plt.subplots(figsize=(8, 8))
    heatmap = ax.imshow(confusion_matrix, cmap='Greens')
    fig.colorbar(heatmap, ax=ax)

    # Class names on both axes: vertical on x so long names do not overlap,
    # right-aligned on y so they end next to the matrix.
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(class_names, fontsize=10, rotation=90, ha='center')
    ax.set_yticklabels(class_names, fontsize=10, rotation=0, ha='right')

    # Annotate every cell with its value in percent (one decimal).
    for i in range(n):
        for j in range(n):
            ax.text(j, i, '{:.1f}'.format(confusion_matrix[i, j] * 100),
                    ha="center", va="center", color="black", fontsize=8)

    sampling = "with oversampling" if oversampling else "without oversampling"
    ax.set_title(f"Confusion Matrix on {dataset} {sampling} (in %) - {model}", fontsize=12)
    ax.set_xlabel("Predicted label", fontsize=10)
    ax.set_ylabel("True label", fontsize=10)

    suffix = "Oversample" if oversampling else "No Oversample"
    # NOTE(review): the folder is "Real Life" for every dataset; only the file
    # name carries the dataset name — confirm this is intended.
    path = f"../{model}/Confusion Matrices/Real Life/{dataset} {suffix}.png"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # bbox_inches="tight" keeps the rotated tick labels inside the saved image.
    fig.savefig(path, bbox_inches="tight")
    # This function runs once per model and setting; release each figure so
    # they do not pile up in memory.
    plt.close(fig)

### Metrics computation

This section computes the metrics, which we will later save in JSON format or plot (confusion matrices).

In [None]:
# Result containers, one entry per evaluation setting
# (e.g. accuracies["Real Life Oversample"] -> {model name: value}).
accuracies, f1s, auprcs = {}, {}, {}

In [None]:
# Evaluate every model on the real-life test set with oversampling: store the
# scalar metrics under "Real Life Oversample" and save one heatmap per model.
accuracies_real_life_oversample, f1s_real_life_oversample, confusion_matrices_real_life_oversample, auprcs_real_life_oversample = get_metrics("Real Life", True)

accuracies["Real Life Oversample"] = accuracies_real_life_oversample
f1s["Real Life Oversample"] = f1s_real_life_oversample
auprcs["Real Life Oversample"] = auprcs_real_life_oversample

# The matrices are stacked in the same order as the model names (the keys of
# the accuracy dict), so pair them instead of hard-coding indices.
for model_name, matrix in zip(accuracies_real_life_oversample, confusion_matrices_real_life_oversample):
    save_confusion_matrix(matrix, model_name, "Real Life", True)

In [None]:
# Evaluate every model on the real-life test set without oversampling: store
# the scalar metrics under "Real Life No Oversample" and save one heatmap per model.
# The dataset name must match the "Real Life" spelling used everywhere else.
accuracies_real_life_no_oversample, f1s_real_life_no_oversample, confusion_matrices_real_life_no_oversample, auprcs_real_life_no_oversample = get_metrics("Real Life", False)

accuracies["Real Life No Oversample"] = accuracies_real_life_no_oversample
f1s["Real Life No Oversample"] = f1s_real_life_no_oversample
auprcs["Real Life No Oversample"] = auprcs_real_life_no_oversample

# The matrices are stacked in the same order as the model names (the keys of
# the accuracy dict), so pair them instead of hard-coding indices.
for model_name, matrix in zip(accuracies_real_life_no_oversample, confusion_matrices_real_life_no_oversample):
    save_confusion_matrix(matrix, model_name, "Real Life", False)

In [None]:
# Evaluate every model on the generated test set with oversampling: store the
# scalar metrics under "Generated Oversample" and save one heatmap per model.
accuracies_generated_oversample, f1s_generated_oversample, confusion_matrices_generated_oversample, auprcs_generated_oversample = get_metrics("Generated", True)

accuracies["Generated Oversample"] = accuracies_generated_oversample
f1s["Generated Oversample"] = f1s_generated_oversample
auprcs["Generated Oversample"] = auprcs_generated_oversample

# The matrices are stacked in the same order as the model names (the keys of
# the accuracy dict), so pair them instead of hard-coding indices.
for model_name, matrix in zip(accuracies_generated_oversample, confusion_matrices_generated_oversample):
    save_confusion_matrix(matrix, model_name, "Generated", True)

In [None]:
# Evaluate every model on the generated test set without oversampling: store
# the scalar metrics under "Generated No Oversample" and save one heatmap per model.
accuracies_generated_no_oversample, f1s_generated_no_oversample, confusion_matrices_generated_no_oversample, auprcs_generated_no_oversample = get_metrics("Generated", False)

accuracies["Generated No Oversample"] = accuracies_generated_no_oversample
f1s["Generated No Oversample"] = f1s_generated_no_oversample
auprcs["Generated No Oversample"] = auprcs_generated_no_oversample

# The matrices are stacked in the same order as the model names (the keys of
# the accuracy dict), so pair them instead of hard-coding indices.
for model_name, matrix in zip(accuracies_generated_no_oversample, confusion_matrices_generated_no_oversample):
    save_confusion_matrix(matrix, model_name, "Generated", False)