# Imports

This notebook is made to compare and evaluate each "best model" we obtained for DANN, CORAL and BASE model (VGG-16 trained on the generated data). All models are trained and tuned in their corresponding files (in the Models folder). This file evaluates them, and compares the best models obtained after hyperparameter tuning based on the result of multiple metrics: testing accuracy, f1-score, confusion matrix and average AUPRC.

In [51]:
import torch 
import os
import torchmetrics
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torchvision
from PIL import Image
from sklearn.metrics import f1_score, confusion_matrix, precision_recall_curve, auc
import json

# If VSCode doesn't pick up this import, see answer here: 
# https://stackoverflow.com/questions/65252074/import-path-to-own-script-could-not-be-resolved-pylance-reportmissingimports
import sys
sys.path.append("../../Datasets/")
from Custom_Dataset import * 

We define the device to use

In [52]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    # Runtime availability check (replaces the deprecated `torch.has_mps` build flag)
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Compare on `.type`: a torch.device object does not compare equal to a plain
# string, so `DEVICE == "cuda"` was always False and the cache was never emptied.
if DEVICE.type == "cuda":
    torch.cuda.empty_cache()

print("Device:", DEVICE)

Device: cuda


# Plot Training and Validation Curves

We plot all the metrics collected during the training of the models. These metrics were stored in the corresponding JSON files. We will only plot the metrics for the best models.

### DANN

In [53]:
def plot_metrics_DANN(metrics_data, hyperparameters, sourcecolor="blue", targetcolor="orange"):
    """
    Save the training and validation curves of one DANN run as PNG files.

    The figures are written to
    "../DANN/Plots/lambda_<lambda_DA>_lr_<learning_rate>_gamma_<gamma_focal_loss>",
    which is created when it does not exist yet.

    Args:
        metrics_data: mapping from metric name (e.g. "gen_training_accs/iteration")
            to the sequence of values recorded for that metric.
        hyperparameters: mapping providing "lambda_DA", "learning_rate",
            "gamma_focal_loss" (used for the folder name) and "n_validation"
            (spacing, in iterations, between two validation points).
        sourcecolor: matplotlib colour of the source ("gen_*") curves.
        targetcolor: matplotlib colour of the target ("real_*") curves.

    Raises:
        KeyError: if a required metric or hyperparameter is missing.
    """
    folder_path = f"../DANN/Plots/lambda_{hyperparameters['lambda_DA']}_lr_{hyperparameters['learning_rate']}_gamma_{hyperparameters['gamma_focal_loss']}"
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(folder_path, exist_ok=True)

    def _save_figure(file_name, title, ylabel, curves, xlabel="Iteration", x=None, legend=False):
        """Draw `curves` (values, colour, label) on a new figure, save it in `folder_path`, close it."""
        plt.figure(figsize=(10, 5))
        for values, color, label in curves:
            if x is None:
                plt.plot(values, color=color, label=label)
            else:
                plt.plot(x, values, color=color, label=label)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        if legend:
            plt.legend()
        plt.savefig(f"{folder_path}/{file_name}", dpi=500)
        # Close the figure so that looping over many runs does not pile up open figures
        plt.close()

    # Extract every metric up front so that a missing key fails before any file is written
    gen_training_discriminator_losses = metrics_data["gen_training_discriminator_losses/iteration"]
    gen_training_classifier_losses = metrics_data["gen_training_classifier_losses/iteration"]
    gen_training_feature_extractor_losses = metrics_data["gen_training_feature_extractor_losses/iteration"]
    real_training_discriminator_losses = metrics_data["real_training_discriminator_losses/iteration"]
    gen_training_accs = metrics_data["gen_training_accs/iteration"]
    gen_training_f1s = metrics_data["gen_training_f1s/iteration"]
    gen_validation_discriminator_losses = metrics_data["gen_validation_discriminator_losses/n_validation"]
    gen_validation_classifier_losses = metrics_data["gen_validation_classifier_losses/n_validation"]
    gen_validation_feature_extractor_losses = metrics_data["gen_validation_feature_extractor_losses/n_validation"]
    real_validation_discriminator_losses = metrics_data["real_validation_discriminator_losses/n_validation"]
    real_validation_classifier_losses = metrics_data["real_validation_classifier_losses/n_validation"]
    real_validation_feature_extractor_losses = metrics_data["real_validation_feature_extractor_losses/n_validation"]
    gen_validation_accs = metrics_data["gen_validation_accs/n_validation"]
    gen_validation_f1s = metrics_data["gen_validation_f1s/n_validation"]
    real_validation_accs = metrics_data["real_validation_accs/n_validation"]
    real_validation_f1s = metrics_data["real_validation_f1s/n_validation"]
    real_validation_accs_full = metrics_data["real_validation_accs_full/epoch"]

    # Validation points are n_validation iterations apart: 0, n_validation, 2 * n_validation, ...
    n_validation = hyperparameters["n_validation"]
    validation_x_axis = np.arange(0, len(gen_validation_discriminator_losses) * n_validation, n_validation)

    # ---- Training curves (one point per iteration) ----
    # The file name of the first figure is kept as-is so existing references to it still work
    _save_figure("source_training_losses.png", "Source Training Discriminator Losses", "Loss",
                 [(gen_training_discriminator_losses, sourcecolor, None)])
    _save_figure("source_training_classifier_losses.png", "Source Training Classifier Losses", "Loss",
                 [(gen_training_classifier_losses, sourcecolor, None)])
    _save_figure("source_training_feature_extractor_losses.png", "Source Training Feature Extractor Loss", "Loss",
                 [(gen_training_feature_extractor_losses, sourcecolor, None)])
    # Target-domain curve: drawn with the target colour, like every other target curve
    _save_figure("target_training_discriminator_losses.png", "Target Training Discriminator Loss", "Loss",
                 [(real_training_discriminator_losses, targetcolor, None)])
    _save_figure("source_training_accs.png", "Source Training Accuracy", "Accuracy",
                 [(gen_training_accs, sourcecolor, "Accuracy")])
    _save_figure("source_training_f1s.png", "Source Training F1 score", "F1 score",
                 [(gen_training_f1s, sourcecolor, "F1 score")])

    # ---- Validation curves (source vs. target, one point every n_validation iterations) ----
    _save_figure("validation_discriminator_losses.png", "Validation Discriminator losses (Source and Target)", "Loss",
                 [(gen_validation_discriminator_losses, sourcecolor, "Source discriminator loss"),
                  (real_validation_discriminator_losses, targetcolor, "Target discriminator loss")],
                 x=validation_x_axis, legend=True)
    _save_figure("validation_feature_extractor_losses.png", "Validation Feature Extractor losses (Source and Target)", "Loss",
                 [(gen_validation_feature_extractor_losses, sourcecolor, "Source feature extractor loss"),
                  (real_validation_feature_extractor_losses, targetcolor, "Target feature extractor loss")],
                 x=validation_x_axis, legend=True)
    _save_figure("validation_classifier_losses.png", "Validation Classifier losses (Source and Target)", "Loss",
                 [(gen_validation_classifier_losses, sourcecolor, "Source classifier loss"),
                  (real_validation_classifier_losses, targetcolor, "Target classifier loss")],
                 x=validation_x_axis, legend=True)
    _save_figure("validation_accs.png", "Validation Accuracy (Source and Target)", "Accuracy",
                 [(gen_validation_accs, sourcecolor, "Source accuracy"),
                  (real_validation_accs, targetcolor, "Target accuracy")],
                 x=validation_x_axis, legend=True)
    _save_figure("validation_f1s.png", "Validation F1 score (Source and Target)", "F1 score",
                 [(gen_validation_f1s, sourcecolor, "Source F1 score"),
                  (real_validation_f1s, targetcolor, "Target F1 score")],
                 x=validation_x_axis, legend=True)

    # ---- Target accuracy on the full validation set (one point per epoch) ----
    _save_figure("target_validation_accs_full.png", "Target validation accuracy (Using the full dataset)", "Accuracy",
                 [(real_validation_accs_full, targetcolor, "Accuracy")],
                 xlabel="Epoch", legend=True)


In [54]:
# Plot the DANN metrics stored in each JSON file of the tuning results folder
folder_path = "../../Models/DANN/HP tuning results/"
for file_name in os.listdir(folder_path):
    # Anything that is not a JSON file is ignored
    if not file_name.endswith(".json"):
        continue

    # Parse the file; it is closed again as soon as its content is in memory
    file_path = os.path.join(folder_path, file_name)
    with open(file_path, "r") as json_file:
        json_data = json.load(json_file)

    hyperparameters = json_data["hyperparameters"]
    metrics_data = json_data["metrics"]
    plot_metrics_DANN(metrics_data, hyperparameters, sourcecolor="saddlebrown", targetcolor="orange")

### CORAL

In [55]:
def plot_metrics_CORAL(metrics_data, hyperparameters, sourcecolor="blue", targetcolor="orange"):
    """
    Save the training and validation curves of one CORAL run as PNG files.

    The figures are written to
    "../CORAL/Plots/lr_<learning_rate>_lambdamax_<lambda_max_DA>_gamma_<gamma_focal_loss>",
    which is created when it does not exist yet.

    Args:
        metrics_data: mapping from metric name (e.g. "training_coral_losses/iteration")
            to the sequence of values recorded for that metric.
        hyperparameters: mapping providing "learning_rate", "lambda_max_DA",
            "gamma_focal_loss" (used for the folder name) and "n_validation"
            (spacing, in iterations, between two validation points).
        sourcecolor: matplotlib colour of the first curve of each figure
            (source / total losses).
        targetcolor: matplotlib colour of the second curve of each figure
            (target / CORAL losses).

    Raises:
        KeyError: if a required metric or hyperparameter is missing.
    """
    folder_path = f"../CORAL/Plots/lr_{hyperparameters['learning_rate']}_lambdamax_{hyperparameters['lambda_max_DA']}_gamma_{hyperparameters['gamma_focal_loss']}"
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(folder_path, exist_ok=True)

    def _save_figure(file_name, title, ylabel, curves, xlabel='Iteration', x=None, legend=False):
        """Draw `curves` (values, colour, label) on a new figure, save it in `folder_path`, close it."""
        plt.figure(figsize=(10, 6))
        for values, color, label in curves:
            if x is None:
                plt.plot(values, color=color, label=label)
            else:
                plt.plot(x, values, color=color, label=label)
        if legend:
            plt.legend()
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(f"{folder_path}/{file_name}", dpi=500)
        # Close the figure so that looping over many runs does not pile up open figures
        plt.close()

    # Extract every metric up front so that a missing key fails before any file is written
    training_losses = metrics_data["training_total_losses/iteration"]
    coral_losses = metrics_data["training_coral_losses/iteration"]
    gen_classification_losses = metrics_data["gen_training_classification_losses/iteration"]
    gen_accs = metrics_data["gen_training_accs/iteration"]
    gen_f1s = metrics_data["gen_training_f1s/iteration"]
    gen_val_accs = metrics_data["gen_validation_accs/n_validation"]
    real_val_accs = metrics_data["real_validation_accs/n_validation"]
    gen_val_f1s = metrics_data["gen_validation_f1s/n_validation"]
    real_val_f1s = metrics_data["real_validation_f1s/n_validation"]
    val_losses = metrics_data["validation_total_losses_using_gen_classification_losses/n_validation"]
    coral_val_losses = metrics_data["validation_CORAL_losses/n_validation"]
    gen_val_classification_losses = metrics_data["gen_validation_classification_losses/n_validation"]
    real_val_classification_losses = metrics_data["real_validation_classification_losses/n_validation"]
    real_val_accs_full = metrics_data["real_validation_accs_full/epoch"]

    # Validation points are n_validation iterations apart: 0, n_validation, 2 * n_validation, ...
    n_validation = hyperparameters["n_validation"]
    validation_x_axis = np.arange(0, len(val_losses) * n_validation, n_validation)

    # ---- Training curves (one point per iteration) ----
    _save_figure("training_losses.png", 'Training Losses', 'Loss',
                 [(training_losses, sourcecolor, 'Training Total Losses'),
                  (coral_losses, targetcolor, 'Training Coral Losses')],
                 legend=True)
    _save_figure("source_training_classification_losses.png", 'Source Training Classification Losses', 'Loss',
                 [(gen_classification_losses, sourcecolor, None)])
    _save_figure("source_training_accs.png", 'Source Training Accs', 'Accuracy',
                 [(gen_accs, sourcecolor, None)])
    _save_figure("source_training_f1s.png", 'Source Training F1s', 'F1 Score',
                 [(gen_f1s, sourcecolor, None)])

    # ---- Validation curves (one point every n_validation iterations) ----
    _save_figure("validation_accs.png", 'Validation Accuracies (Source vs. Target Domain)', 'Accuracy',
                 [(gen_val_accs, sourcecolor, 'Source'),
                  (real_val_accs, targetcolor, 'Target')],
                 x=validation_x_axis, legend=True)
    _save_figure("validation_f1s.png", 'Validation F1 Scores (Source vs. Target Domain)', 'F1 Score',
                 [(gen_val_f1s, sourcecolor, 'Source'),
                  (real_val_f1s, targetcolor, 'Target')],
                 x=validation_x_axis, legend=True)
    _save_figure("validation_losses.png", 'Validation Losses', 'Loss',
                 [(val_losses, sourcecolor, 'Validation Total Losses'),
                  (coral_val_losses, targetcolor, 'Validation CORAL Losses')],
                 x=validation_x_axis, legend=True)
    _save_figure("classification_losses.png", 'Validation Classification Losses (Source vs Target Domain)', 'Loss',
                 [(gen_val_classification_losses, sourcecolor, 'Source'),
                  (real_val_classification_losses, targetcolor, 'Target')],
                 x=validation_x_axis, legend=True)

    # ---- Target accuracy on the full validation set (one point per epoch, no legend) ----
    _save_figure("target_val_accs_full.png", 'Target validation accuracy (Using the full dataset)', 'Accuracy',
                 [(real_val_accs_full, targetcolor, 'Target')],
                 xlabel='Epoch')

In [56]:
# Plot the CORAL metrics stored in each JSON file of the tuning results folder
folder_path = "../../Models/CORAL/HP tuning results/"
for file_name in os.listdir(folder_path):
    # Anything that is not a JSON file is ignored
    if not file_name.endswith(".json"):
        continue

    # Parse the file; it is closed again as soon as its content is in memory
    file_path = os.path.join(folder_path, file_name)
    with open(file_path, "r") as json_file:
        json_data = json.load(json_file)

    hyperparameters = json_data["hyperparameters"]
    metrics_data = json_data["metrics"]
    plot_metrics_CORAL(metrics_data, hyperparameters, sourcecolor="saddlebrown", targetcolor="orange")

### BASE

In [57]:
def plot_metrics_BASE_Source(metrics_data, hyperparameters, sourcecolor="blue", targetcolor="orange"):
    """
    Save the training and validation curves of one source-trained base model run as PNG files.

    The figures are written to
    "../BASE Source/Plots/gamma_<gamma>_dropout_<dropout_rate>",
    which is created when it does not exist yet.

    Args:
        metrics_data: mapping from metric name (e.g. "gen_training_losses/iteration")
            to the sequence of values recorded for that metric.
        hyperparameters: mapping providing "gamma", "dropout_rate" (used for the
            folder name) and "n_validation" (spacing, in iterations, between two
            validation points).
        sourcecolor: matplotlib colour of the source ("gen_*") curves.
        targetcolor: matplotlib colour of the target ("real_*") curves.

    Raises:
        KeyError: if a required metric or hyperparameter is missing.
    """
    folder_path = f"../BASE Source/Plots/gamma_{hyperparameters['gamma']}_dropout_{hyperparameters['dropout_rate']}"
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(folder_path, exist_ok=True)

    def _save_figure(file_name, title, ylabel, curves, xlabel='Iteration', x=None, legend=False):
        """Draw `curves` (values, colour, label) on a new figure, save it in `folder_path`, close it."""
        plt.figure(figsize=(10, 6))
        for values, color, label in curves:
            if x is None:
                plt.plot(values, color=color, label=label)
            else:
                plt.plot(x, values, color=color, label=label)
        if legend:
            plt.legend()
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(f"{folder_path}/{file_name}", dpi=500)
        # Close the figure so that looping over many runs does not pile up open figures
        plt.close()

    # Extract every metric up front so that a missing key fails before any file is written
    gen_training_losses = metrics_data["gen_training_losses/iteration"]
    gen_training_accs = metrics_data["gen_training_accs/iteration"]
    gen_training_f1s = metrics_data["gen_training_f1s/iteration"]
    gen_validation_accs = metrics_data["gen_validation_accs/n_validation"]
    real_validation_accs = metrics_data["real_validation_accs/n_validation"]
    gen_validation_f1s = metrics_data["gen_validation_f1s/n_validation"]
    real_validation_f1s = metrics_data["real_validation_f1s/n_validation"]
    real_validation_accs_full = metrics_data["real_validation_accs_full/epoch"]

    # Validation points are n_validation iterations apart: 0, n_validation, 2 * n_validation, ...
    n_validation = hyperparameters["n_validation"]
    validation_x_axis = np.arange(0, len(gen_validation_accs) * n_validation, n_validation)

    # ---- Training curves (one point per iteration) ----
    _save_figure("source_training_losses.png", 'Source Training Losses', 'Loss',
                 [(gen_training_losses, sourcecolor, None)])
    # File name kept as "gen_training_accs.png" so existing references to it still work
    _save_figure("gen_training_accs.png", 'Source Training Accuracies', 'Accuracy',
                 [(gen_training_accs, sourcecolor, None)])
    _save_figure("source_training_f1s.png", 'Source Training F1 Scores', 'F1 Score',
                 [(gen_training_f1s, sourcecolor, None)])

    # ---- Validation curves (source vs. target, one point every n_validation iterations) ----
    _save_figure("validation_accs.png", 'Validation Accuracies (Source vs. Target Domain)', 'Accuracy',
                 [(gen_validation_accs, sourcecolor, 'Source'),
                  (real_validation_accs, targetcolor, 'Target')],
                 x=validation_x_axis, legend=True)
    _save_figure("validation_f1s.png", 'Validation F1 Scores (Source vs. Target Domain)', 'F1 Score',
                 [(gen_validation_f1s, sourcecolor, 'Source'),
                  (real_validation_f1s, targetcolor, 'Target')],
                 x=validation_x_axis, legend=True)

    # ---- Target accuracy on the full validation set (one point per epoch, no legend) ----
    _save_figure("target_validation_accs_full.png", 'Target validation accuracy (Using the full dataset)', 'Accuracy',
                 [(real_validation_accs_full, targetcolor, 'Target')],
                 xlabel='Epoch')

In [58]:
# Plot the metrics of the base model trained on the source domain,
# one set of figures per JSON file of the tuning results folder
folder_path = "../../Models/BASE/TRAIN_GEN/HP tuning results/"
for file_name in os.listdir(folder_path):
    # Anything that is not a JSON file is ignored
    if not file_name.endswith(".json"):
        continue

    # Parse the file; it is closed again as soon as its content is in memory
    file_path = os.path.join(folder_path, file_name)
    with open(file_path, "r") as json_file:
        json_data = json.load(json_file)

    hyperparameters = json_data["hyperparameters"]
    metrics_data = json_data["metrics"]
    plot_metrics_BASE_Source(metrics_data, hyperparameters, sourcecolor="saddlebrown", targetcolor="orange")

In [59]:
def plot_metrics_BASE_Target(metrics_data, hyperparameters, targetcolor="orange"):
    """
    Save the training and validation curves of one target-trained base model run as PNG files.

    The figures are written to
    "../BASE Target/Plots/gamma_<gamma>_dropout_<dropout_rate>",
    which is created when it does not exist yet.

    Args:
        metrics_data: mapping from metric name (e.g. "real_training_losses/iteration")
            to the sequence of values recorded for that metric.
        hyperparameters: mapping providing "gamma", "dropout_rate" (used for the
            folder name) and "n_validation" (spacing, in iterations, between two
            validation points).
        targetcolor: matplotlib colour used for every curve.

    Raises:
        KeyError: if a required metric or hyperparameter is missing.
    """
    folder_path = f"../BASE Target/Plots/gamma_{hyperparameters['gamma']}_dropout_{hyperparameters['dropout_rate']}"
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(folder_path, exist_ok=True)

    def _save_figure(file_name, title, ylabel, values, xlabel='Iteration', x=None, label=None):
        """Draw one curve on a new figure, save it in `folder_path`, close it."""
        plt.figure(figsize=(10, 6))
        if x is None:
            plt.plot(values, label=label, color=targetcolor)
        else:
            plt.plot(x, values, label=label, color=targetcolor)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(f"{folder_path}/{file_name}", dpi=500)
        # Close the figure so that looping over many runs does not pile up open figures
        plt.close()

    # Extract every metric up front so that a missing key fails before any file is written
    real_training_losses = metrics_data["real_training_losses/iteration"]
    real_training_accs = metrics_data["real_training_accs/iteration"]
    real_training_f1s = metrics_data["real_training_f1s/iteration"]
    real_validation_losses = metrics_data["real_validation_losses/n_validation"]
    real_validation_accs = metrics_data["real_validation_accs/n_validation"]
    real_validation_f1s = metrics_data["real_validation_f1s/n_validation"]
    real_validation_accs_full = metrics_data["real_validation_accs_full/epoch"]

    # Validation points are n_validation iterations apart: 0, n_validation, 2 * n_validation, ...
    n_validation = hyperparameters["n_validation"]
    validation_x_axis = np.arange(0, len(real_validation_losses) * n_validation, n_validation)

    # ---- Training curves (one point per iteration) ----
    _save_figure("target_training_losses.png", 'Target Training Losses', 'Loss', real_training_losses)
    _save_figure("target_training_accs.png", 'Target Training Accuracies', 'Accuracy', real_training_accs)
    _save_figure("target_training_f1s.png", 'Target Training F1 Scores', 'F1 Score', real_training_f1s)

    # ---- Validation curves (one point every n_validation iterations) ----
    _save_figure("target_validation_losses.png", 'Target Validation Losses', 'Loss',
                 real_validation_losses, x=validation_x_axis)
    # File name kept as "real_validation_accs.png" so existing references to it still work
    _save_figure("real_validation_accs.png", 'Target Validation Accuracies', 'Accuracy',
                 real_validation_accs, x=validation_x_axis)
    _save_figure("target_validation_f1s.png", 'Target Validation F1 Scores', 'F1 Score',
                 real_validation_f1s, x=validation_x_axis)

    # ---- Target accuracy on the full validation set (one point per epoch) ----
    _save_figure("target_validation_accs_full.png", 'Target validation accuracy (Using the full dataset)', 'Accuracy',
                 real_validation_accs_full, xlabel='Epoch', label='Target')

In [60]:
# Plot the metrics of the base model trained on the target domain,
# one set of figures per JSON file of the tuning results folder
folder_path = "../../Models/BASE/TRAIN_VAL_REAL/HP tuning results/"
for file_name in os.listdir(folder_path):
    # Anything that is not a JSON file is ignored
    if not file_name.endswith(".json"):
        continue

    # Parse the file; it is closed again as soon as its content is in memory
    file_path = os.path.join(folder_path, file_name)
    with open(file_path, "r") as json_file:
        json_data = json.load(json_file)

    hyperparameters = json_data["hyperparameters"]
    metrics_data = json_data["metrics"]
    plot_metrics_BASE_Target(metrics_data, hyperparameters, targetcolor="orange")

# Model Classes

We now define the classes for each model. The base model and CORAL model have one class, while the DANN has two classes: one for the feature extractor and one for the classifier.

### DANN

In [61]:
class Classifier(nn.Module):
    """
        Classification head of the DANN model: maps a flattened feature vector of
        size `input_size` to `num_classes` raw scores through two hidden
        fully connected layers (1024 and 512 units) with ReLU activations.
    """

    def __init__(self, input_size=4608, num_classes=13):
        super().__init__()
        # Layers are created in forward order; they end up under the `layer`
        # attribute, which determines the names of the state_dict entries.
        stacked_layers = [
            nn.Linear(input_size, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        ]
        self.layer = nn.Sequential(*stacked_layers)

    def forward(self, h):
        # Returns the raw class scores (no softmax is applied here)
        return self.layer(h)
    
class FeatureExtractor(nn.Module):
    """
        Feature extractor of the DANN model: the `features` (convolutional) part
        of a VGG16 loaded with the 'VGG16_Weights.IMAGENET1K_V1' weights. Its
        output is flattened to one vector per sample.
    """
    def __init__(self):
        super().__init__()

        # Only the convolutional stack of VGG16 is kept
        vgg16 = torchvision.models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1')
        self.conv = vgg16.features

        # Freeze the parameters of modules 0 to 24 (the end of the slice is exclusive)
        for frozen_param in self.conv[:25].parameters():
            frozen_param.requires_grad_(False)

    def forward(self, x):
        # Keep the batch dimension and flatten everything after it
        return torch.flatten(self.conv(x), 1)

### CORAL

In [62]:
class CoralModel(nn.Module):
    """
        CORAL model: the `features` part of a VGG16 loaded with the
        'VGG16_Weights.IMAGENET1K_V1' weights, followed by a three-layer fully
        connected classifier. The forward pass returns both the flattened
        features and the class scores.
    """

    def __init__(self, num_classes=13):
        super().__init__()

        # Convolutional backbone first, classification head second
        vgg16 = torchvision.models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1')
        self.features = vgg16.features

        head_layers = [
            nn.Linear(4608, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # Freeze the parameters of backbone modules 0 to 24 (the end of the slice is exclusive)
        for frozen_param in self.features[:25].parameters():
            frozen_param.requires_grad_(False)

    def forward(self, x):
        # Flattened backbone output, returned alongside the class scores
        h = torch.flatten(self.features(x), 1)
        return h, self.classifier(h)

### BASE

In [63]:
# Base model: frozen VGG16 backbone with a trainable classification head
class BaseModel(nn.Module):
    """
        Base model: the `features` part of a VGG16 loaded with the
        'VGG16_Weights.IMAGENET1K_V1' weights (all of its parameters frozen),
        followed by a three-layer fully connected classifier with dropout
        after each hidden layer.
    """

    def __init__(self, num_classes=13, dropout_rate=0.5):
        super().__init__()

        # Convolutional backbone first, classification head second
        vgg16 = torchvision.models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1')
        self.features = vgg16.features

        head_layers = [
            nn.Linear(4608, 1024),
            nn.ReLU(True),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # No gradient is computed for any backbone parameter
        for frozen_param in self.features.parameters():
            frozen_param.requires_grad_(False)

    def forward(self, x):
        # Backbone -> flatten (batch dimension kept) -> classification head
        flat_features = torch.flatten(self.features(x), 1)
        return self.classifier(flat_features)


We now load the models we saved as the best ones for each model type (DANN, CORAL, and Base Model).

In [64]:
# Locations of the best checkpoints saved for each model type
DANN_path = "../../Models/DANN/"
DANN_C_path = DANN_path + "best_DANN_C_model.ckpt"
DANN_F_path = DANN_path + "best_DANN_F_model.ckpt"
CORAL_path = "../../Models/CORAL/best_CORAL_model.ckpt"
BASE_Source_path = "../../Models/BASE/TRAIN_GEN/best_BASE_TRAIN_GEN_model.ckpt"
BASE_Target_path = "../../Models/BASE/TRAIN_VAL_REAL/best_BASE_TRAIN_VAL_REAL_model.ckpt"


def _load_eval_model(model, checkpoint_path):
    """
    Move `model` to DEVICE, load the weights stored at `checkpoint_path` into it
    and return it in evaluation mode.

    Args:
        model: freshly instantiated nn.Module whose architecture matches the checkpoint.
        checkpoint_path: path of the file holding the model's state dictionary.

    Returns:
        The same module, on DEVICE, with the checkpoint weights and `eval()` applied.
    """
    model = model.to(DEVICE)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints that come from a trusted source.
    # DEVICE is already a torch.device, so it is passed to map_location directly.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    return model.eval()


# DANN: classifier head and feature extractor are stored in two separate checkpoints
DANN_C = _load_eval_model(Classifier(), DANN_C_path)
DANN_F = _load_eval_model(FeatureExtractor(), DANN_F_path)

# CORAL model
coral_model = _load_eval_model(CoralModel(), CORAL_path)

# Base model trained on the source dataset
base_model_source = _load_eval_model(BaseModel(), BASE_Source_path)

# Base model trained on the target dataset
base_model_target = _load_eval_model(BaseModel(), BASE_Target_path)

# The statements above are assignments, so no model architecture is echoed by
# the cell; the blank print is kept so the cell's output is unchanged.
print("")




# Metric Computation

### Plotting and metric collection functions

We will compare the performance of the best models using the testing accuracy, f1-score, confusion matrix and average AUPRC metrics. 

* Specifically, for the testing accuracy, we will use the balanced accuracy score, which is similar to regular classification accuracy, but it takes into account the frequency of each class. The balanced accuracy score will be computed on both the real-life dataset and the generated dataset; however, our ultimate goal is to predict labels for real-life images.

* As a reminder, the f1-score is a performance metric that provides a balance between the precision and recall of the model. The f1-score ranges from 0 to 1, with a higher score indicating a better performance of the model. The equation for the f1-score is: $$ F_1 = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}} $$ where $\text{precision}$ is the number of true positives (correctly predicted positive samples) divided by the total number of predicted positive samples, and $\text{recall}$ is the number of true positives divided by the total number of actual positive samples. In other words, the f1-score takes into account both the model's ability to correctly identify positive samples (recall) and its tendency to not mislabel negative samples as positive (precision). 

* The confusion matrix compares the predicted class labels with the true class labels and counts the number of true positives (TP), true negatives (TN), false positives (FP), and false negatives (FN) for each class. It can help identify which classes the model is having difficulty classifying correctly, and which classes it can predict with ease.

* Finally, the AUPRC measures the overall quality of the model's predictions, taking into account both precision and recall across all possible classification thresholds. A high average AUPRC indicates that the model has a good balance of precision and recall across all possible classification thresholds, and is able to accurately distinguish between positive and negative cases. Conversely, a low average AUPRC indicates poor performance, and suggests that the model is making many incorrect predictions or missing many true positive cases.








We first define the balanced accuracy function as follows:

In [65]:
# Define the balanced accuracy function
# Multiclass accuracy metric over 13 classes, moved to DEVICE so that it lives
# on the same device as the rest of the evaluation.
# NOTE(review): average="weighted" weights each class's score by its support,
# whereas "balanced accuracy" usually means the unweighted mean over classes
# (average="macro") — confirm which averaging is intended before relying on
# the "balanced" wording.
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=13, average="weighted").to(DEVICE)

To compare the final testing accuracies of each model (DANN, CORAL, and Base model), we evaluate their performance on both the real-life dataset and the generated dataset. Additionally, we analyze the metrics with and without oversampling. For this purpose, we use a function that takes the model(s), batch size, dataset and oversampling boolean as input to compute the testing accuracy, f1-score, confusion matrix and average AUPRC. 

It's important to note that the most important metrics we will use for comparison are the ones computed on the real-life dataset without oversampling (i.e., with the natural distribution of chess boards).

In [66]:
# Define the testing metrics evaluation function
def evaluate(model, batch_size, dataset, model2=None, oversampling=True, coral=False):
    """Compute accuracy, F1 score, confusion matrix and average AUPRC on a test split.

    Args:
        model: Network applied to each batch of images. It is switched to
            evaluation mode here.
        batch_size: Batch size forwarded to `get_loader`.
        dataset: Dataset identifier forwarded to `CustomDataset`, together
            with the "test" split name.
        model2: Optional second network applied to the output of `model`
            (for DANN: feature extractor first, then classifier). It is also
            switched to evaluation mode when given.
        oversampling: Forwarded as `balance` to both `CustomDataset` and
            `get_loader`.
        coral: If True, `model(images)` is unpacked as a pair and only its
            second element is used as the class scores.

    Returns:
        Tuple `(final_accuracy, f1, cm, average_auprc)`:
        - final_accuracy: value of the module-level `accuracy` metric computed
          over the whole split (a tensor, so `.item()` is available).
        - f1: support-weighted F1 score as returned by `f1_score`.
        - cm: confusion matrix normalised over the true labels, always of
          shape (13, 13).
        - average_auprc: mean of the one-vs-rest AUPRC over the classes that
          occur in the ground truth.

    Raises:
        ValueError: If the loader yields no batch.

    Side effects:
        Resets and updates the module-level `accuracy` metric.
    """
    # Local import, as in the original cell (sklearn.preprocessing is not imported at the top).
    from sklearn.preprocessing import label_binarize

    n_classes = 13

    # We use our custom dataset and loader defined in the "Datasets" folder.
    test_dataset = CustomDataset(dataset, "test", balance=oversampling)
    loader = get_loader(test_dataset, batch_size=batch_size, balance=oversampling)

    # Make sure every network involved is in evaluation mode
    model.eval()
    if model2 is not None:
        model2.eval()

    # Start from a clean metric state so earlier evaluations cannot leak in
    accuracy.reset()

    true_batches = []   # Ground truth labels, one array per batch
    pred_batches = []   # Predicted labels, one array per batch
    score_batches = []  # Raw class scores, one array per batch

    # Remove grad
    with torch.no_grad():
        for images, labels in loader:
            # Move the data to the device
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            # Forward pass (the CORAL model returns a pair; the scores come second)
            if coral:
                _, y_pred_prob = model(images)
            else:
                y_pred_prob = model(images)

            # If there is a second model
            # (for DANN we first go through the Feature extractor and then the Classifier)
            if model2 is not None:
                y_pred_prob = model2(y_pred_prob)

            y_pred = torch.argmax(y_pred_prob, dim=1)

            # Accumulate instead of averaging per-batch values, so that a
            # smaller last batch does not get the same weight as a full one.
            accuracy.update(y_pred, labels)

            true_batches.append(labels.cpu().numpy())
            pred_batches.append(y_pred.cpu().numpy())
            score_batches.append(y_pred_prob.float().cpu().numpy())

    if not true_batches:
        raise ValueError("The test loader produced no batches; nothing to evaluate.")

    # Accuracy over the whole split
    final_accuracy = accuracy.compute()

    y_true = np.concatenate(true_batches)
    y_hat = np.concatenate(pred_batches)
    # Assumes the network output is (n_samples, n_classes) -- TODO confirm
    y_score = np.concatenate(score_batches)

    class_ids = np.arange(n_classes)

    # Compute the f1 score
    f1 = f1_score(y_true, y_hat, average='weighted')

    # Compute the confusion matrix; passing `labels` keeps it 13x13 even when
    # a class never appears in this split or is never predicted.
    cm = confusion_matrix(y_true, y_hat, labels=class_ids, normalize='true')

    # One-vs-rest ground truth with a fixed column per class. Fitting a
    # binarizer separately on the labels and on the predictions would produce
    # misaligned columns whenever the two sets of observed classes differ.
    y_true_binary = label_binarize(y_true, classes=class_ids)

    # Per-class area under the precision-recall curve, computed from the
    # non-thresholded class scores so that the curve spans all thresholds.
    auprc = {}
    for i in range(n_classes):
        if not y_true_binary[:, i].any():
            # No positive sample for this class: its AUPRC is undefined, skip it
            continue
        precision, recall, _ = precision_recall_curve(y_true_binary[:, i], y_score[:, i])
        auprc[i] = auc(recall, precision)

    average_auprc = np.mean(list(auprc.values())) if auprc else float("nan")

    # Return all metrics computed
    return final_accuracy, f1, cm, average_auprc


We then define a function "get_metrics" to obtain all the metrics we defined above for all models, depending on the dataset and on whether the data is oversampled or not.

In [67]:
# Fix the batch size to 32 for all test-set evaluations
# (read as a global by get_metrics and forwarded to evaluate).
batch_size = 32

In [68]:
def get_metrics(dataset, oversampling):
    """Get the testing accuracy, f1 score, average AUPRC and confusion matrix
    for the DANN, CORAL and both Base models.

    Args:
        dataset: The dataset to evaluate the metrics on ("Real Life" or "Generated").
        oversampling: Whether oversampling must be performed.

    Returns:
        Tuple `(accuracies, f1s, auprcs, confusion_matrices)`. The first three
        are dictionaries keyed by "DANN", "CORAL", "BASE Source" and
        "BASE Target" (suitable for saving as JSON). `confusion_matrices` is a
        numpy array stacking one matrix per model, in that same order.
    """
    # (name, network, extra keyword arguments for evaluate), in output order
    runs = (
        ("DANN", DANN_F, {"model2": DANN_C}),
        ("CORAL", coral_model, {"coral": True}),
        ("BASE Source", base_model_source, {}),
        ("BASE Target", base_model_target, {}),
    )

    # Dictionaries to store the scalar results in; these get saved as JSON files
    accuracies = {}
    f1s = {}
    auprcs = {}

    # The confusion matrices are not saved as JSON; they are collected here
    # and later rendered with matplotlib
    matrices = []

    for name, network, extra in runs:
        acc, f1, conf, auprc = evaluate(network, batch_size, dataset, oversampling=oversampling, **extra)
        # float() accepts a 0-dim tensor, a NumPy scalar and a plain Python
        # float alike; f1_score's return type depends on the scikit-learn version.
        accuracies[name] = float(acc)
        f1s[name] = float(f1)
        auprcs[name] = auprc
        matrices.append(conf)

    # Shape (number of models, n_classes, n_classes)
    return accuracies, f1s, auprcs, np.stack(matrices)

We also specify a function to save the confusion matrices obtained. This function takes as input the matrix to save, the model and dataset it corresponds to and if the data was oversampled or not.

In [69]:
def save_confusion_matrix(confusion_matrix, model, dataset, oversampling):
    """Save a confusion matrix as an annotated heatmap.

    The figure is written to
    "../<model>/Confusion Matrices/<dataset> Oversample.png" when
    `oversampling` is true and to
    "../<model>/Confusion Matrices/<dataset> No Oversample.png" otherwise.
    The target folder is created if it does not exist yet.

    Args:
        confusion_matrix: 13x13 array indexed as [true class, predicted class];
            each cell is multiplied by 100 for the printed annotation.
        model: Model name, used in the title and in the output folder.
        dataset: Dataset name, used in the title and in the file name.
        oversampling: Whether the data was oversampled; selects the title
            wording and the file name.
    """
    # Define class names
    class_names = ['Empty Square', 'White Pawn', 'White Knight', 'White Bishop', 'White Rook', 'White Queen', 'White King', 
                'Black Pawn', 'Black Knight', 'Black Bishop', 'Black Rook', 'Black Queen', 'Black King']
    n_classes = len(class_names)

    # Define figure and axis size
    fig, ax = plt.subplots(figsize=(8, 8))

    try:
        # Create heatmap and its colorbar
        heatmap = ax.imshow(confusion_matrix, cmap='Greens')
        fig.colorbar(heatmap, ax=ax)

        # One tick per class; x labels are drawn vertically, y labels horizontally
        ticks = np.arange(n_classes)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(class_names, fontsize=10, rotation=90, ha='center')
        ax.set_yticklabels(class_names, fontsize=10, rotation=0, ha='right')

        # Loop over data dimensions and create text annotations (values shown in %)
        for i in range(n_classes):
            for j in range(n_classes):
                ax.text(j, i, '{:.1f}'.format(confusion_matrix[i, j]*100),
                        ha="center", va="center", color="black", fontsize=8)

        # Set title and axis labels
        sampling = "with" if oversampling else "without"
        ax.set_title(f"Confusion Matrix on {dataset} {sampling} oversampling (in %) - {model}", fontsize=12)
        ax.set_xlabel("Predicted label", fontsize=10)
        ax.set_ylabel("True label", fontsize=10)

        # Make sure the output folder exists, otherwise savefig fails
        out_dir = f"../{model}/Confusion Matrices"
        os.makedirs(out_dir, exist_ok=True)

        # Save figure
        suffix = "Oversample" if oversampling else "No Oversample"
        fig.savefig(f"{out_dir}/{dataset} {suffix}.png", dpi=500)
    finally:
        # Close this specific figure, even if drawing or saving failed
        plt.close(fig)

### Metrics computation

This section focuses on computing the metrics that we will either save in JSON format or plot later (confusion matrices).

In [70]:
# Result containers: one dictionary per scalar metric
accuracies, f1s, auprcs = {}, {}, {}

In [76]:
# Evaluate all models on the real life dataset with oversampling,
# then record the scalar metrics in the dictionaries created above
(
    accuracies_real_life_oversample,
    f1s_real_life_oversample,
    auprcs_real_life_oversample,
    confusion_matrices_real_life_oversample,
) = get_metrics("Real Life", True)

accuracies["Real Life Oversample"] = accuracies_real_life_oversample
f1s["Real Life Oversample"] = f1s_real_life_oversample
auprcs["Real Life Oversample"] = auprcs_real_life_oversample

# Save one confusion matrix heatmap per model (index 0 to 3, in this order)
for model_index, model_name in enumerate(("DANN", "CORAL", "BASE Source", "BASE Target")):
    save_confusion_matrix(confusion_matrices_real_life_oversample[model_index], model_name, "Real Life", True)

In [77]:
# Evaluate all models on the real life dataset without oversampling,
# then record the scalar metrics in the dictionaries created above
(
    accuracies_real_life_no_oversample,
    f1s_real_life_no_oversample,
    auprcs_real_life_no_oversample,
    confusion_matrices_real_life_no_oversample,
) = get_metrics("Real Life", False)

accuracies["Real Life No Oversample"] = accuracies_real_life_no_oversample
f1s["Real Life No Oversample"] = f1s_real_life_no_oversample
auprcs["Real Life No Oversample"] = auprcs_real_life_no_oversample

# Save one confusion matrix heatmap per model (index 0 to 3, in this order)
for model_index, model_name in enumerate(("DANN", "CORAL", "BASE Source", "BASE Target")):
    save_confusion_matrix(confusion_matrices_real_life_no_oversample[model_index], model_name, "Real Life", False)


In [78]:
# Evaluate all models on the generated dataset with oversampling,
# then record the scalar metrics in the dictionaries created above
(
    accuracies_generated_oversample,
    f1s_generated_oversample,
    auprcs_generated_oversample,
    confusion_matrices_generated_oversample,
) = get_metrics("Generated", True)

accuracies["Generated Oversample"] = accuracies_generated_oversample
f1s["Generated Oversample"] = f1s_generated_oversample
auprcs["Generated Oversample"] = auprcs_generated_oversample

# Save one confusion matrix heatmap per model (index 0 to 3, in this order)
for model_index, model_name in enumerate(("DANN", "CORAL", "BASE Source", "BASE Target")):
    save_confusion_matrix(confusion_matrices_generated_oversample[model_index], model_name, "Generated", True)

In [79]:
# Evaluate all models on the generated dataset without oversampling,
# then record the scalar metrics in the dictionaries created above
(
    accuracies_generated_no_oversample,
    f1s_generated_no_oversample,
    auprcs_generated_no_oversample,
    confusion_matrices_generated_no_oversample,
) = get_metrics("Generated", False)

accuracies["Generated No Oversample"] = accuracies_generated_no_oversample
f1s["Generated No Oversample"] = f1s_generated_no_oversample
auprcs["Generated No Oversample"] = auprcs_generated_no_oversample

# Save one confusion matrix heatmap per model (index 0 to 3, in this order)
for model_index, model_name in enumerate(("DANN", "CORAL", "BASE Source", "BASE Target")):
    save_confusion_matrix(confusion_matrices_generated_no_oversample[model_index], model_name, "Generated", False)

In [75]:
# Write each result dictionary to its own JSON file
for json_path, results in (
    ('../Testing Accuracies.json', accuracies),
    ('../Testing F1 Scores.json', f1s),
    ('../Testing AUPRC.json', auprcs),
):
    with open(json_path, 'w') as fp:
        json.dump(results, fp)