In [1]:
import torch

from torch.utils.data import random_split, Dataset, DataLoader

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

from model import build_transformer

# Make CUDA the default device for tensors/modules created without an explicit
# `device=` argument in the rest of the notebook (requires a CUDA-capable GPU).
torch.set_default_device("cuda")

In [None]:
# Rebuild the network with explicit hyperparameters.
# NOTE(review): these presumably have to match the configuration the checkpoint
# below was trained with, otherwise load_state_dict will fail — confirm.
model = build_transformer(
    dropout=0.1,
    source_vocab_size=8000 + 1,
    target_vocab_size=292,
    context_length=900 - 3 + 1,
    encoder_block_count=6,
    encoder_self_attention_head_count=8,
    encoder_self_attention_abstraction_coef=0.15,
    encoder_feed_forward_abstraction_coef=4,
    dim=256,
    epsilon=1e-9,
)

# Restore the saved weights; the checkpoint is a dict whose "state" entry holds
# the model's state_dict.
checkpoint = torch.load("weights/tr_model_10", weights_only=True)
state_dict = checkpoint["state"]
model.load_state_dict(state_dict)

In [None]:
# Load the three parallel tensors and move them to the GPU.
# NOTE(review): presumably row i of X, L and Y describes the same sample — confirm.
X = torch.load('X.pt', weights_only=True).int().to("cuda")
L = torch.load('L.pt', weights_only=True).int().to("cuda")
Y = torch.load('Y.pt', weights_only=True).half().to("cuda")

if not (len(X) == len(L) == len(Y)):
    raise ValueError(
        f"X, L and Y must have the same length, got {len(X)}, {len(L)} and {len(Y)}"
    )

train_ratio = 0.9
train_size = int(len(X) * train_ratio)
test_size = len(X) - train_size

# Split X once, then reuse *the same indices* for L and Y.
# Calling random_split three times with one generator does NOT give the same
# split three times: every call draws a fresh permutation and advances the
# generator state, so X_test[i], L_test[i] and Y_test[i] would come from
# different samples and inputs would be evaluated against the wrong targets.
generator = torch.Generator(device="cuda").manual_seed(42)
X_train, X_test = random_split(X, [train_size, test_size], generator)
L_train, L_test = (torch.utils.data.Subset(L, part.indices) for part in (X_train, X_test))
Y_train, Y_test = (torch.utils.data.Subset(Y, part.indices) for part in (X_train, X_test))

print(len(X_train), len(X_test))
print(len(L_train), len(L_test))
print(len(Y_train), len(Y_test))

In [7]:
# Decision thresholds that testing_loop compares against sigmoid probabilities.
# They are kept as floats: casting to int would truncate any fractional
# threshold (e.g. 0.35 -> 0), making `prob > threshold` true for every sample.
# NOTE(review): assumes T.pt stores probability thresholds in [0, 1] — confirm.
thresholds = torch.load('T.pt', weights_only=True).float().to("cuda")

In [8]:
class CustomDataset(Dataset):
    """Map-style dataset pairing two equally long indexable collections.

    Item ``idx`` is the tuple ``(X[idx], Y[idx])``.

    Args:
        X: Indexable, sized collection of inputs.
        Y: Indexable, sized collection of targets, same length as ``X``.

    Raises:
        ValueError: If ``X`` and ``Y`` do not have the same length.
    """

    def __init__(self, X, Y):
        # Fail early: a length mismatch would otherwise surface as a late
        # IndexError, or silently ignore the tail of Y.
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(Y)}"
            )
        self.X = X
        self.Y = Y

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the pair ``(X[idx], Y[idx])``."""
        return self.X[idx], self.Y[idx]


# Evaluation data: held-out inputs paired with their targets.
dataset = CustomDataset(X_test, Y_test)
# No shuffling for evaluation: the confusion matrix does not depend on sample
# order, and a fixed order makes the per-batch losses (and their average)
# reproducible from one run to the next.
# The CUDA generator is kept because the notebook's default device is CUDA.
loader = DataLoader(dataset, batch_size=64, shuffle=False, generator=torch.Generator(device='cuda'))

In [11]:
def testing_loop(model, data, thresholds, num_classes=292):
    """Evaluate a multi-label classifier and build one 2x2 confusion matrix per class.

    The model is switched to eval mode and run without gradient tracking on
    every batch of ``data``. Logits are turned into probabilities with a
    sigmoid and a class is predicted positive when its probability is strictly
    greater than the corresponding threshold.

    Args:
        model: Module called as ``model(x, mask=None)`` that returns logits
            with one column per class.
        data: Iterable yielding ``(x, y)`` batches; ``y`` holds 0/1 targets with
            the same layout as the logits. Target values other than 0 and 1
            are not counted in any cell.
        thresholds: Value(s) broadcastable against the probabilities.
        num_classes: Number of leading class columns to score.

    Returns:
        An int32 tensor of shape ``(num_classes, 2, 2)`` where, for class ``i``,
        ``[i, 0, 0]`` = TN, ``[i, 0, 1]`` = FP, ``[i, 1, 0]`` = FN and
        ``[i, 1, 1]`` = TP.

    Raises:
        IndexError: If the model output has fewer than ``num_classes`` columns.

    Side effects:
        Prints the average of the per-batch BCE-with-logits losses (``nan``
        when ``data`` yields no batch). This is a mean of batch means, so a
        smaller final batch weighs as much as a full one.
    """
    # BCE on raw logits, so no separate sigmoid is needed for the loss.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    loss_sum = 0.0
    batch_count = 0

    # One confusion matrix per class.
    confusion_matrix = torch.zeros((num_classes, 2, 2), dtype=torch.int32)

    model.eval()
    with torch.no_grad():
        batch_iterator = tqdm(data, desc="Processing batches")

        for x, y in batch_iterator:
            pred = model(x, mask=None)

            # Loss for this batch (read back from the device only once).
            batch_loss = loss_fn(pred, y).item()
            batch_iterator.set_postfix({"loss": f"{batch_loss:6.3f}"})
            loss_sum += batch_loss
            batch_count += 1

            # Probabilities -> boolean predictions, restricted to the scored classes.
            predicted_pos = (torch.sigmoid(pred) > thresholds)[:, :num_classes]
            if predicted_pos.shape[1] < num_classes:
                raise IndexError(
                    f"model output has {predicted_pos.shape[1]} class columns, "
                    f"expected at least {num_classes}"
                )
            predicted_neg = ~predicted_pos
            actual_pos = (y == 1)[:, :num_classes]
            actual_neg = (y == 0)[:, :num_classes]

            # Count all classes at once (sum over the batch dimension) instead
            # of looping over classes in Python, which issued several tiny
            # reductions and scalar updates per class and per batch.
            true_neg = (predicted_neg & actual_neg).sum(dim=0)
            false_pos = (predicted_pos & actual_neg).sum(dim=0)
            false_neg = (predicted_neg & actual_pos).sum(dim=0)
            true_pos = (predicted_pos & actual_pos).sum(dim=0)

            # Arrange as (num_classes, 2, 2): rows = actual, columns = predicted.
            batch_counts = torch.stack(
                (
                    torch.stack((true_neg, false_pos), dim=-1),
                    torch.stack((false_neg, true_pos), dim=-1),
                ),
                dim=1,
            )
            confusion_matrix += batch_counts.to(
                device=confusion_matrix.device, dtype=confusion_matrix.dtype
            )

    # Avoid a ZeroDivisionError when the loader is empty.
    mean_loss = loss_sum / batch_count if batch_count else float("nan")
    print(f"Loss: {mean_loss}")
    return confusion_matrix

In [None]:
# Run the evaluation over `loader`; `results` is the per-class confusion matrix
# returned by testing_loop.
results = testing_loop(model, loader, thresholds)

In [None]:
def plot_confusion_stats(confusion_matrix_total, num_classes=292):
    """Print summary statistics and draw per-class bar charts from confusion matrices.

    For each class three ratios are derived from its 2x2 confusion matrix:

    * TP / (TP + FN)  -- share of actual positives that were found,
    * TP / (TP + FP)  -- share of positive predictions that were right,
    * (TP + TN) / (TP + TN + FP + FN)  -- share of correct decisions.

    A ratio whose denominator is zero is reported as 0.0. The mean, median,
    minimum and maximum of each ratio over the classes are printed (as
    percentages), then one bar chart per ratio is shown (y axis in [0, 1]).

    Args:
        confusion_matrix_total: Array-like or torch tensor of shape
            ``(>= num_classes, 2, 2)`` laid out as ``[i, 0, 0]`` = TN,
            ``[i, 0, 1]`` = FP, ``[i, 1, 0]`` = FN, ``[i, 1, 1]`` = TP.
        num_classes: Number of leading classes to report.

    Returns:
        None.

    Raises:
        IndexError: If fewer than ``num_classes`` matrices are provided.
    """

    def _safe_ratio(numerator, denominator):
        # Element-wise division that yields 0.0 wherever the denominator is 0.
        ratio = np.zeros_like(denominator, dtype=np.float64)
        np.divide(numerator, denominator, out=ratio, where=denominator > 0)
        return ratio

    # Bring the counts to the host in a single transfer instead of calling
    # .item() on every cell (each call synchronizes with the device).
    matrix = confusion_matrix_total
    if hasattr(matrix, "detach"):
        matrix = matrix.detach().cpu().numpy()
    counts = np.asarray(matrix, dtype=np.float64)
    if counts.shape[0] < num_classes:
        raise IndexError(
            f"confusion matrix covers {counts.shape[0]} classes, expected at least {num_classes}"
        )
    counts = counts[:num_classes]

    true_neg, false_pos = counts[:, 0, 0], counts[:, 0, 1]
    false_neg, true_pos = counts[:, 1, 0], counts[:, 1, 1]

    detection_accuracy = _safe_ratio(true_pos, true_pos + false_neg)
    identification_accuracy = _safe_ratio(true_pos, true_pos + false_pos)
    overall_accuracy = _safe_ratio(true_pos + true_neg, counts.sum(axis=(1, 2)))

    # Summary statistics: one line per (statistic, metric) pair, grouped by statistic.
    metric_phrases = (
        ("d'identifications réussies", detection_accuracy),
        ("de bonnes identifications", identification_accuracy),
        ("de bonne décision globale", overall_accuracy),
    )
    statistics = (
        ("Pourcentage moyen", np.mean),
        ("Pourcentage médian", np.median),
        ("Pire classe pour le pourcentage", np.min),
        ("Meilleure classe pour le pourcentage", np.max),
    )
    for prefix, reduce_fn in statistics:
        for phrase, values in metric_phrases:
            print(f"{prefix} {phrase}: {reduce_fn(values) * 100:.2f}%")

    # One bar chart per metric, one bar per class.
    classes = np.arange(num_classes)
    charts = (
        (detection_accuracy, 'b', 'Pourcentage d\'identifications réussies (VP / (VP + FN))'),
        (identification_accuracy, 'g', 'Pourcentage de bonnes identifications (VP / (VP + FP))'),
        (overall_accuracy, 'r', 'Pourcentage de bonne décision ((VP + VN) / (VP + VN + FP + FN))'),
    )
    for values, color, title in charts:
        plt.figure(figsize=(10, 6))
        plt.bar(classes, values, color=color)
        plt.title(title)
        plt.xlabel('Classes')
        plt.ylabel('Pourcentage')
        plt.ylim(0, 1)  # ratios are plotted on a 0-1 scale
        plt.show()
    
# Print the summary statistics and show the per-class charts for the evaluation results.
plot_confusion_stats(results)