In [19]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, roc_auc_score

In [None]:
# Plot one ROC curve per saved prediction file and collect each file's
# partial AUC (FPR <= MAX_FPR) in `auc_o1_values`.
#
# Each `.pt` file is indexed as [0] -> ground-truth labels, [1] -> scores.
# NOTE(review): assumes both entries are 1-D tensors of equal length with
# binary labels -- confirm against the code that writes these files.

# Path to your directory containing the `.pt` files
dir_path = 'test_pred_wo_norm'

# Upper FPR bound used both for the partial AUC and for the highlighted
# (red) segment of each curve.
MAX_FPR = 0.1

# List all .pt files in the directory. Sorted because os.listdir() returns
# entries in arbitrary order, which would make the subplot order and the
# order of `auc_o1_values` differ from run to run.
file_names = sorted(f for f in os.listdir(dir_path) if f.endswith('.pt'))

n_files = len(file_names)
if n_files == 0:
    # Fail early with a clear message; plt.subplots() cannot build a
    # zero-row grid and would raise a much less helpful error.
    raise ValueError(f"No .pt files found in {dir_path!r}")

# Create subplots: one for each file. squeeze=False always returns a 2-D
# array of Axes, so the single-file case needs no special handling; the
# single column is then taken to get a 1-D sequence indexed by file.
fig, axes = plt.subplots(n_files, 1, figsize=(8, 6 * n_files), squeeze=False)
axes = axes[:, 0]

# Name (with the letter "o") is kept as-is because a later cell reads it.
auc_o1_values = []

# Loop through each .pt file and plot its ROC curve on separate plots
for i, filename in enumerate(file_names):
    # Load the tensor from the .pt file.
    # SECURITY: torch.load unpickles the file -- only load trusted files.
    # map_location='cpu' lets files saved from a GPU run load on any machine.
    file_path = os.path.join(dir_path, filename)
    tensor_data = torch.load(file_path, map_location='cpu')

    # Extract ground truth labels and logits. detach()/cpu() make .numpy()
    # safe for tensors that still require grad or live on another device.
    y_true = tensor_data[0].detach().cpu().numpy()    # True labels
    y_scores = tensor_data[1].detach().cpu().numpy()  # Logits

    # Compute FPR, TPR, and thresholds
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)

    # Compute AUC over the full FPR range
    roc_auc = auc(fpr, tpr)

    # Indices of ROC points inside the low-FPR region (used for the red overlay)
    indices = np.where(fpr <= MAX_FPR)

    # Partial AUC restricted to FPR <= MAX_FPR. With max_fpr set,
    # roc_auc_score returns the *standardized* partial AUC, so this value
    # is not the raw area under the red segment.
    auc_01 = roc_auc_score(y_true, y_scores, max_fpr=MAX_FPR)
    auc_o1_values.append(auc_01)

    # Plot ROC Curve (full curve in blue, low-FPR segment highlighted in red)
    ax = axes[i]
    ax.plot(fpr, tpr, color='blue', label=f'ROC Curve (AUC = {roc_auc:.2f})')
    ax.plot(fpr[indices], tpr[indices], color='red', label=f'AUC@{MAX_FPR} = {auc_01:.2f}', linewidth=2)

    # Plot the random classifier line (diagonal)
    ax.plot([0, 1], [0, 1], color='gray', linestyle='--')

    # Add plot details
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate (FPR)')
    ax.set_ylabel('True Positive Rate (TPR)')
    ax.set_title(f'ROC Curve for {filename}')
    ax.legend(loc='lower right')

# Adjust layout for the subplots
plt.tight_layout()

# Show the plot
# plt.show()
print('Done')

In [30]:
# Rank the collected per-file AUC@0.1 scores from highest to lowest and
# display them as the cell's output.
ranked_auc_01 = sorted(auc_o1_values, reverse=True)
ranked_auc_01

[np.float64(0.6737204772416725),
 np.float64(0.6713430950148429),
 np.float64(0.6706612920768655),
 np.float64(0.6699439788771142),
 np.float64(0.66942455930993),
 np.float64(0.6686191585007775),
 np.float64(0.6684097769794705),
 np.float64(0.668059988708685),
 np.float64(0.6675468807767053),
 np.float64(0.6675456049589485),
 np.float64(0.6671500030991558),
 np.float64(0.6658444012760609),
 np.float64(0.66468220565954),
 np.float64(0.6642367908965253),
 np.float64(0.6631645318863221),
 np.float64(0.6622639895623168),
 np.float64(0.6619299006395754),
 np.float64(0.6582282674844013),
 np.float64(0.6579397345531256),
 np.float64(0.6549907621417046)]

# AUC@0.1 values for the model using TCRlang (embedding dimension 480)

[np.float64(0.6683042738018465),
 np.float64(0.6668380618733383),
 np.float64(0.6659118519818539),
 np.float64(0.6658859431522504),
 np.float64(0.6655036516908618),
 np.float64(0.6652624853119795),
 np.float64(0.6651911355558772),
 np.float64(0.6641593177109503),
 np.float64(0.6633553181755469),
 np.float64(0.6628706073694876),
 np.float64(0.6605613004198885),
 np.float64(0.6605488772135663),
 np.float64(0.6599074709024423),
 np.float64(0.6592128670726143),
 np.float64(0.6581775356204888),
 np.float64(0.6563516187612273),
 np.float64(0.65600887160367),
 np.float64(0.6551784665030181),
 np.float64(0.6540362472917294)]

# AUC@0.1 values for the model using BLOSUM50 encodings (embedding dimension 20)

## Baseline - without sample weighting

[np.float64(0.7101537576934895),
 np.float64(0.7099087925029443),
 np.float64(0.709272574589124),
 np.float64(0.7048349248247264),
 np.float64(0.7027083727997101),
 np.float64(0.7020391136934334),
 np.float64(0.7017375713959785),
 np.float64(0.7016576228024789),
 np.float64(0.6993577680798362),
 np.float64(0.6989029447975592),
 np.float64(0.6970407018071376),
 np.float64(0.6955426045396471),
 np.float64(0.6955346679103585),
 np.float64(0.6943553063330273),
 np.float64(0.6929753339970508),
 np.float64(0.6918515680771008),
 np.float64(0.6918510234129736),
 np.float64(0.6917744483633941),
 np.float64(0.6915950212326408),
 np.float64(0.6915813814851173)]

## With sample weighting

 [np.float64(0.7116509137234577),
 np.float64(0.7067853191654011),
 np.float64(0.7064464394691559),
 np.float64(0.7064341112873894),
 np.float64(0.706083841163931),
 np.float64(0.7044213084767847),
 np.float64(0.7030464701933742),
 np.float64(0.7029717203178603),
 np.float64(0.7023464340943544),
 np.float64(0.7021310802737446),
 np.float64(0.7019421254275555),
 np.float64(0.6977761839936111),
 np.float64(0.6975282055094117),
 np.float64(0.6971194276102374),
 np.float64(0.6954749144898721),
 np.float64(0.6946872984888932),
 np.float64(0.6939193432754462),
 np.float64(0.6923446249312903),
 np.float64(0.6919774063445733),
 np.float64(0.6894434134451888)]
 