In [1]:
import torch

from torch.utils.data import random_split, Dataset, DataLoader

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve

from model import build_transformer

# Factory calls without an explicit device (torch.empty, torch.ones, torch.arange, ...)
# allocate on the GPU from here on; every later cell relies on this global setting.
DEFAULT_DEVICE = "cuda"
torch.set_default_device(DEFAULT_DEVICE)

In [None]:
# Architecture hyper-parameters. They must describe the same network as the checkpoint
# loaded below, otherwise load_state_dict (strict by default) rejects the weights.
# NOTE(review): the origin of 8000 + 1 and 900 - 3 + 1 is not visible here — presumably
# vocabulary size + padding and sequence length derived from preprocessing; confirm.
model_config = dict(
    dropout=0.1,
    source_vocab_size=8000 + 1,
    target_vocab_size=292,
    context_length=900 - 3 + 1,
    decoder_block_count=6,
    decoder_self_attention_head_count=8,
    decoder_self_attention_abstraction_coef=0.15,
    decoder_feed_forward_abstraction_coef=4,
    dim=256,
    epsilon=1e-9,
)
model = build_transformer(**model_config)

checkpoint_path = "weights/tr_model_10"
checkpoint = torch.load(checkpoint_path, weights_only=True)
model.load_state_dict(checkpoint["state"])

In [None]:
X = torch.load('X.pt', weights_only=True).int().to("cuda")
L = torch.load('L.pt', weights_only=True).int().to("cuda")
Y = torch.load('Y.pt', weights_only=True).half().to("cuda")

# The dataset built later pairs X[i], L[i] and Y[i], so all three must have the same length.
if not len(X) == len(L) == len(Y):
    raise ValueError(f"X, L and Y lengths differ: {len(X)}, {len(L)}, {len(Y)}")

train_ratio = 0.9
train_size = int(len(X) * train_ratio)
test_size = len(X) - train_size

split_seed = 42


def seeded_split(data):
    """Split `data` into (train, test) Subsets of sizes (train_size, test_size).

    A fresh generator seeded with `split_seed` is created on every call, so every
    tensor passed in receives the *same* permutation. Sharing one generator across
    calls advances its state between them and gives X, L and Y different
    permutations, silently pairing X_train[i] with the L/Y of another sample.
    """
    generator = torch.Generator(device="cuda").manual_seed(split_seed)
    return random_split(data, [train_size, test_size], generator)


X_train, X_test = seeded_split(X)
L_train, L_test = seeded_split(L)
Y_train, Y_test = seeded_split(Y)

# All three splits must select exactly the same rows.
assert X_train.indices == L_train.indices == Y_train.indices
assert X_test.indices == L_test.indices == Y_test.indices

print(len(X_train), len(X_test))
print(len(L_train), len(L_test))
print(len(Y_train), len(Y_test))

In [4]:
class CustomDataset(Dataset):
    """Map-style dataset zipping three equally indexed sequences.

    Item ``idx`` is the tuple ``(X[idx], L[idx], Y[idx])``. The length is taken
    from ``X`` alone; ``L`` and ``Y`` are expected to be at least as long.
    """

    def __init__(self, X, L, Y):
        self.X, self.L, self.Y = X, L, Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        columns = (self.X, self.L, self.Y)
        return tuple(column[idx] for column in columns)


# NOTE(review): predictions, thresholds and metrics below are all computed on the
# training split (X_train / L_train / Y_train); X_test would give an unbiased estimate.
dataset = CustomDataset(X_train, L_train, Y_train)
# shuffle=False keeps batches in dataset order, so row i of the predictions written and
# saved by the inference cell corresponds to dataset[i], and runs are reproducible.
# The CUDA generator is kept — presumably required because the default device is "cuda"
# (the loader draws its base seed from this generator); confirm before removing it.
loader = DataLoader(dataset, batch_size=64, shuffle=False, generator=torch.Generator(device='cuda'))

In [None]:
num_samples = len(dataset)
num_classes = 292
context_length = 900 - 3 + 1


def length_masks(lengths, context_length):
	"""Build float attention masks for a batch of sequence lengths.

	Returns a contiguous tensor of shape (batch, 1, context_length, context_length):
	for sample b, every row holds 1.0 at key positions j < lengths[b] - 1 and 0.0
	elsewhere. This equals indexing a table built with `masks[i, :, :, i:] = 0` by
	`lengths - 1`, but is built per batch instead of allocating a context_length**3
	float32 table (about 2.9 GB for context_length = 898).

	NOTE(review): a sample of length l may attend to only l - 1 positions; confirm
	this off-by-one matches how the model was trained.

	Raises:
		ValueError: if a length is outside [1, context_length] (a table lookup with
			index l - 1 = -1 would silently wrap to the last row).
	"""
	lengths = lengths.reshape(-1)
	if lengths.min() < 1 or lengths.max() > context_length:
		raise ValueError(f"sequence lengths must lie in [1, {context_length}]")
	positions = torch.arange(context_length, device=lengths.device)
	visible = positions < (lengths - 1).reshape(-1, 1, 1, 1)  # (batch, 1, 1, context_length)
	return visible.float().expand(-1, 1, context_length, -1).contiguous()


preds = torch.empty(num_samples, num_classes, dtype=torch.float)
real = torch.empty(num_samples, num_classes, dtype=torch.float)

with torch.no_grad():
	model.eval()

	count = 0
	batch_iterator = tqdm(loader, desc="Processing batches")
	for x, l, y in batch_iterator:
		m = length_masks(l, context_length)
		pred = model(x, mask=m)

		# Rows are filled in loader order.
		preds[count:count + len(x)] = pred
		real[count:count + len(x)] = y

		count += len(x)

# torch.empty leaves memory uninitialised, so every row must have been overwritten.
assert count == num_samples, f"filled {count} of {num_samples} rows"

In [6]:
# Persist the per-class prediction tensor filled by the inference cell.
predictions_path = "P.pt"
torch.save(preds, predictions_path)

In [None]:
def to_numpy(array):
    """Return `array` as a NumPy array, copying torch tensors to host memory first.

    Accepting arrays that are already NumPy makes this cell safe to re-run.
    """
    if torch.is_tensor(array):
        return array.cpu().numpy()
    return np.asarray(array)


preds = to_numpy(preds)
real = to_numpy(real).astype(np.int32)

meilleurs_seuils = []  # best decision threshold per class, in class order
f1_scores = []         # F1 score reached at that threshold

# For each class, pick the threshold on its predicted scores that maximises F1.
# NOTE(review): for a class with no positive label, F1 is presumably 0 at every threshold,
# so argmax returns index 0 (the lowest score, i.e. "predict everything positive"); confirm.
for c in range(num_classes):
    labels = real[:, c]
    scores = preds[:, c]
    precisions, recalls, thresholds = precision_recall_curve(labels, scores)
    # precision/recall carry one extra final point with no matching threshold; drop it
    # so index k of all three arrays refers to the same threshold.
    p, r = precisions[:-1], recalls[:-1]
    f1 = 2 * (p * r) / (p + r + 1e-8)  # epsilon avoids 0/0 where p = r = 0
    best = np.argmax(f1)
    meilleurs_seuils.append(thresholds[best])
    f1_scores.append(f1[best])

print(meilleurs_seuils[0:10])
print(f1_scores[0:10])

In [13]:
# One decision threshold per class, as an array indexed by class id.
optimal_thresholds = np.asarray(meilleurs_seuils)

In [14]:
# Persist the per-class thresholds as a torch tensor (created on the default device).
thresholds_tensor = torch.tensor(optimal_thresholds)
torch.save(thresholds_tensor, "T.pt")

In [None]:
# Per-class metrics at the tuned thresholds:
#   precision = TP / (TP + FP)
#   recall    = TP / (TP + FN)
#   accuracy  = (TP + TN) / (TP + TN + FP + FN)
# A ratio with a zero denominator (e.g. recall of a class with no positive label) is
# undefined: it is stored as NaN and skipped by the nan-aware summaries below, instead
# of turning every mean/max/min into NaN.


def safe_ratio(numerator, denominator):
	"""Element-wise numerator / denominator as floats, NaN where the denominator is 0."""
	numerator = numerator.astype(float)
	return np.divide(numerator, denominator, out=np.full_like(numerator, np.nan), where=denominator != 0)


# Boolean (sample, class) matrices; the per-class thresholds broadcast over columns.
predicted = preds >= optimal_thresholds
actual = real.astype(bool)

TP = np.sum(predicted & actual, axis=0)
FP = np.sum(predicted & ~actual, axis=0)
FN = np.sum(~predicted & actual, axis=0)
TN = np.sum(~predicted & ~actual, axis=0)

precision = safe_ratio(TP, TP + FP)
recall = safe_ratio(TP, TP + FN)
accuracy = safe_ratio(TP + TN, TP + TN + FP + FN)

undefined = np.isnan(precision) | np.isnan(recall)
print("Classes with undefined precision or recall:", int(undefined.sum()))

print("Average precision:", np.nanmean(precision))
print("Average recall:", np.nanmean(recall))
print("Average accuracy:", np.nanmean(accuracy))

print("Maximal precision:", np.nanmax(precision))
print("Maximal recall:", np.nanmax(recall))
print("Maximal accuracy:", np.nanmax(accuracy))

print("Minimal precision:", np.nanmin(precision))
print("Minimal recall:", np.nanmin(recall))
print("Minimal accuracy:", np.nanmin(accuracy))

In [None]:
# For class `n`, overlay the distribution of its predicted scores on samples where the
# class is present (positive) and absent (negative), plus its tuned decision threshold.
# The class index is used consistently for the sample selection, the score column and
# the threshold.

n = 42

class_scores = preds[:, n]
positive = real[:, n] == 1
negative = real[:, n] == 0

fig, ax = plt.subplots()
ax.hist(class_scores[positive], bins=100, alpha=0.5, label='Positive', density=True)
ax.hist(class_scores[negative], bins=100, alpha=0.5, label='Negative', density=True)
ax.axvline(x=optimal_thresholds[n], color='r', linestyle='dashed', linewidth=2, label='Threshold')
ax.set(title=f"Class {n}: predicted score distribution", xlabel="Predicted score", ylabel="Density")
ax.legend()
plt.show()