In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import gdown
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
scaler = torch.amp.GradScaler('cuda')
from sklearn.metrics import accuracy_score, roc_curve, precision_recall_fscore_support, roc_auc_score, f1_score, precision_score, recall_score

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): none of the visible cells move the model or tensors to `device`,
# so evaluation below actually runs on CPU regardless of this setting.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)

In [None]:
from pathlib import Path

# Google Drive file IDs for the artefacts this notebook evaluates.
# NOTE(review): "file_id" is a placeholder, not a real Drive ID — replace it
# with the ID of the test parquet file before running.
DRIVE_FILES = {
    "test_dissimilarity_dataset_vit.parquet": "file_id",
    "pairnet_model.pth": "1HvTVqpN5PNSdDZkroQPp-8Yjg-tNRVdt",
}

for output_path, file_id in DRIVE_FILES.items():
    # Skip files already on disk so "Restart & Run All" does not re-download them.
    if Path(output_path).exists():
        print(f"{output_path} already exists; skipping download")
        continue
    downloaded = gdown.download(
        f"https://drive.google.com/uc?id={file_id}", output_path, quiet=False
    )
    # gdown may signal failure by returning None (older versions) rather than
    # raising; fail loudly here instead of at the later read.
    if downloaded is None:
        raise RuntimeError(f"Download of {output_path!r} (Drive id {file_id!r}) failed")

In [None]:
class PairNet(nn.Module):
    """Small MLP over the six pairwise distance features.

    Layers: 6 -> 4 -> 2 -> 1, with ReLU after each hidden layer and no
    activation on the output, so ``forward`` returns raw logits.
    """

    def __init__(self):
        super().__init__()
        # These attribute names become the state_dict keys ("fc1.weight", ...)
        # that saved checkpoints are loaded into; do not rename them.
        self.fc1 = nn.Linear(6, 4)
        self.fc2 = nn.Linear(4, 2)
        self.out = nn.Linear(2, 1)
        # NOTE(review): declared but never applied in forward(). Dropout holds
        # no parameters, so it does not appear in the state_dict.
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Map features ``x`` (last dimension 6) to logits (last dimension 1)."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        logits = self.out(hidden)
        return logits

In [None]:
# Rebuild the network and restore the trained weights.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# host; the model is evaluated on CPU in this cell (it is never moved to `device`).
model = PairNet()
# NOTE(review): weights_only=False unpickles arbitrary Python objects and can
# execute code embedded in the file — only load checkpoints from a trusted
# source. Consider weights_only=True if the checkpoint holds only tensors and
# plain Python values.
checkpoint = torch.load("pairnet_model.pth", map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode (disables the Dropout module)

# Decision threshold saved alongside the weights.
best_threshold = checkpoint["best_threshold"]

# Feature order must match the order PairNet was trained with (assumed to be
# this list — confirm against the training notebook).
FEATURE_COLUMNS = ["d_1N_1F", "d_2N_2F", "d_1N_2N", "d_1N_2F", "d_1F_2N", "d_1F_2F"]

df_test = pd.read_parquet("test_dissimilarity_dataset_vit.parquet")
X_test = df_test[FEATURE_COLUMNS].to_numpy()
y_test = df_test["label"].to_numpy()
# The binary metrics computed later assume 0/1 labels.
assert np.isin(y_test, (0, 1)).all(), "expected binary 0/1 labels in 'label'"

X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

with torch.no_grad():
    logits = model(X_test_tensor)
    # Probabilities and labels as flat 1-D arrays for sklearn. Thresholding is
    # done once, in the metrics code below, with a single (>=) convention.
    test_probs = torch.sigmoid(logits).numpy().flatten()
    test_true = y_test_tensor.numpy().flatten()

def compute_far_frr(y_true, y_pred):
    """Compute false-acceptance and false-rejection rates of binary predictions.

    Label 1 is the positive ("accept") class and 0 the negative class, consistent
    with sklearn's default ``pos_label=1``.

    Args:
        y_true: array-like of ground-truth labels in {0, 1}.
        y_pred: array-like of predicted labels in {0, 1}, same length as ``y_true``.

    Returns:
        Tuple ``(far, frr)`` of floats:
          far = FP / (FP + TN)  -- negatives wrongly accepted,
          frr = FN / (FN + TP)  -- positives wrongly rejected.
        A rate whose denominator is zero is reported as 0.0.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` differ in length.
    """
    truth = np.asarray(y_true).ravel()
    pred = np.asarray(y_pred).ravel()
    if truth.shape != pred.shape:
        raise ValueError(
            f"y_true and y_pred have inconsistent lengths: {truth.size} vs {pred.size}"
        )

    # Count the four cells explicitly. (The previous confusion_matrix call used
    # labels=[1, 0], whose ravel() order is TP, FN, FP, TN; unpacking it as
    # tn, fp, fn, tp swapped FAR and FRR.)
    positive, negative = truth == 1, truth == 0
    accepted, rejected = pred == 1, pred == 0
    tp = int(np.count_nonzero(positive & accepted))
    fn = int(np.count_nonzero(positive & rejected))
    fp = int(np.count_nonzero(negative & accepted))
    tn = int(np.count_nonzero(negative & rejected))

    far = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    frr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    return far, frr

# Binarise with the validation-selected threshold loaded from the checkpoint
# (`best_threshold`); a score exactly at the threshold counts as positive.
threshold = best_threshold
test_preds = (test_probs >= threshold).astype(int)

# Threshold-dependent metrics.
acc = accuracy_score(test_true, test_preds)
# zero_division=0: report 0 instead of raising an UndefinedMetricWarning when
# the model predicts no positives.
prec, rec, f1, _ = precision_recall_fscore_support(
    test_true, test_preds, average="binary", zero_division=0
)
far, frr = compute_far_frr(test_true, test_preds)

# Threshold-free metrics from the raw probabilities.
auc_roc = roc_auc_score(test_true, test_probs)
fpr, tpr, thresholds = roc_curve(test_true, test_probs)
fnr = 1 - tpr
# Equal Error Rate: the ROC point where FPR and FNR are closest; averaging the
# two is the usual discrete approximation.
eer_threshold_idx = np.argmin(np.abs(fpr - fnr))
eer = (fpr[eer_threshold_idx] + fnr[eer_threshold_idx]) / 2
eer_threshold = thresholds[eer_threshold_idx]

print(f"Threshold          : {threshold}")
print(f"ROC-AUC            : {auc_roc:.4f}")
print(f"Accuracy           : {acc:.4f}")
print(f"Precision          : {prec:.4f}")
print(f"Recall             : {rec:.4f}")
print(f"F1-score           : {f1:.4f}")
print(f"EER                : {eer:.4f}")
print(f"EER threshold      : {eer_threshold:.4f}")
print(f"FAR                : {far:.4f}")
print(f"FRR                : {frr:.4f}")
print()