In [6]:
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import average_precision_score
from tqdm import tqdm

DATA_DIR = "data"
MODEL_SAVE_PATH = os.path.join(DATA_DIR, "esmc_bilstm_best.pth")

BATCH_SIZE = 32
LEARNING_RATE = 1e-3
EPOCHS = 20
HIDDEN_DIM = 32
DROPOUT = 0.1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ASMEmbeddingsDataset(Dataset):
    """In-memory dataset of embedding tensors with binary labels.

    For a given ``split`` the constructor looks under ``root_dir`` for the
    folders ``positive_<split>`` (label 1.0) and ``negative_<split>``
    (label 0.0) and loads every ``*.pt`` shard inside them. Each shard is
    read with ``torch.load`` and must expose ``.items()``; the values are
    taken as embeddings and the keys are ignored.

    Attributes are plain Python lists kept in lockstep:
    ``data[i]`` is a float32 tensor and ``labels[i]`` its float label.
    """

    def __init__(self, root_dir, split='train'):
        """Load all shards of ``split`` found under ``root_dir``.

        Missing folders and unreadable shards are reported on stdout and
        skipped; they never raise.
        """
        self.data = []
        self.labels = []
        categories = {f'positive_{split}': 1.0, f'negative_{split}': 0.0}

        for folder_name, label_val in categories.items():
            dir_path = os.path.join(root_dir, folder_name)
            if not os.path.exists(dir_path):
                print(f"Missing folder: {dir_path}")
                continue

            # glob yields files in arbitrary, filesystem-dependent order;
            # sorting keeps sample indices reproducible across runs/machines.
            shard_files = sorted(glob.glob(os.path.join(dir_path, "*.pt")))
            for p in tqdm(shard_files, desc=folder_name):
                try:
                    shard_content = torch.load(p, map_location='cpu')
                    # Convert the whole shard before touching self.data so a
                    # failure half-way through does not leave a partial shard.
                    shard_embeddings = [
                        embedding.float() for embedding in shard_content.values()
                    ]
                except Exception as e:
                    # Deliberate best-effort: report the bad shard and go on.
                    print(f"Error during loading file {p}: {e}")
                    continue

                self.data.extend(shard_embeddings)
                self.labels.extend([label_val] * len(shard_embeddings))

    def __len__(self):
        """Return the number of loaded embeddings."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(embedding, label)`` with the label as a float32 scalar tensor."""
        return self.data[idx], torch.tensor(self.labels[idx], dtype=torch.float32)

def collate_fn(batch):
    """Collate ``(embedding, label)`` pairs into one padded batch.

    Embeddings are zero-padded along their first dimension to the longest
    one in the batch and stacked batch-first; labels are stacked into a
    single tensor in the same order.

    Returns:
        Tuple ``(padded_embeddings, labels)``.
    """
    sequences, targets = zip(*batch)
    return (
        pad_sequence(sequences, batch_first=True, padding_value=0.0),
        torch.stack(targets),
    )

class ASMDetectorLSTM(nn.Module):
    """Bidirectional-LSTM binary classifier over a sequence of embeddings.

    The last hidden states of the forward and backward directions are
    concatenated, passed through dropout and a linear layer, and squashed
    with a sigmoid, so the output is one probability-like score per sample.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        """Build the layers.

        Args:
            input_dim: Size of each per-step input vector.
            hidden_dim: Hidden size of each LSTM direction.
            dropout: Dropout probability applied before the linear head.
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)
        # Both directions are concatenated, hence twice the hidden size.
        self.fc = nn.Linear(2 * hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch-first input ``x`` and return one value per sample.

        NOTE(review): ``x`` is consumed as-is; if it is a zero-padded batch the
        padded steps are processed by the LSTM too -- confirm this is intended.
        """
        _, (hidden, _) = self.lstm(x)
        forward_last, backward_last = hidden[-2], hidden[-1]
        pooled = torch.cat((forward_last, backward_last), dim=1)
        logits = self.fc(self.dropout(pooled))
        return self.sigmoid(logits).squeeze(-1)

# Load every split eagerly; each dataset reads its shards from DATA_DIR.
train_dataset = ASMEmbeddingsDataset(DATA_DIR, split='train')
val_dataset = ASMEmbeddingsDataset(DATA_DIR, split='val')
test_dataset = ASMEmbeddingsDataset(DATA_DIR, split='test')

# Training (and everything defined below) only happens when training data exists.
if len(train_dataset) > 0:
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    # NOTE(review): unlike test_loader, val_loader is built even when val_dataset
    # is empty; the per-epoch validation below assumes it yields samples.
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
    # test_loader is None when no test samples were loaded.
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn) if len(test_dataset) > 0 else None
    # Feature size is taken from the last dimension of the first training embedding.
    INPUT_DIM = train_dataset[0][0].shape[-1]
    
    model = ASMDetectorLSTM(INPUT_DIM, HIDDEN_DIM, DROPOUT).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # The model already ends in a sigmoid, so plain BCELoss is applied to its output.
    criterion = nn.BCELoss()

    # Best validation average precision seen so far; drives checkpointing.
    best_val_ap = 0.0

    for epoch in range(EPOCHS):
        # ---- training pass ----
        model.train()
        train_loss = 0.0  # sum of per-batch losses for this epoch
        for X, y in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # ---- validation pass (no gradients, dropout disabled via eval()) ----
        model.eval()
        val_preds, val_targets = [], []
        with torch.no_grad():
            for X, y in val_loader:
                preds = model(X.to(DEVICE)).cpu().numpy()
                val_preds.extend(preds)
                val_targets.extend(y.numpy())
        
        val_ap = average_precision_score(val_targets, val_preds)
        # Reported loss is the mean over batches, not over samples.
        print(f"Epoch {epoch+1} | Loss: {train_loss/len(train_loader):.4f} | Val AP: {val_ap:.4f}")

        # Overwrite the checkpoint only on a strict improvement in validation AP.
        if val_ap > best_val_ap:
            best_val_ap = val_ap
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"--> Model saved (AP: {best_val_ap:.4f})")

positive_train: 100%|██████████| 3/3 [00:00<00:00, 11.26it/s]
negative_train: 100%|██████████| 9/9 [00:00<00:00, 10.50it/s]
positive_val: 100%|██████████| 2/2 [00:00<00:00, 12.59it/s]
negative_val: 100%|██████████| 2/2 [00:00<00:00, 12.99it/s]
positive_test: 100%|██████████| 1/1 [00:00<00:00, 44.38it/s]
negative_test: 100%|██████████| 2/2 [00:00<00:00, 12.63it/s]
Epoch 1: 100%|██████████| 169/169 [00:04<00:00, 38.04it/s]


Epoch 1 | Loss: 0.2095 | Val AP: 0.9897
--> Model saved (AP: 0.9897)


Epoch 2: 100%|██████████| 169/169 [00:04<00:00, 34.18it/s]


Epoch 2 | Loss: 0.0157 | Val AP: 0.9913
--> Model saved (AP: 0.9913)


Epoch 3: 100%|██████████| 169/169 [00:04<00:00, 40.46it/s]


Epoch 3 | Loss: 0.0065 | Val AP: 0.9887


Epoch 4: 100%|██████████| 169/169 [00:04<00:00, 40.28it/s]


Epoch 4 | Loss: 0.0028 | Val AP: 0.9906


Epoch 5: 100%|██████████| 169/169 [00:04<00:00, 37.42it/s]


Epoch 5 | Loss: 0.0016 | Val AP: 0.9911


Epoch 6: 100%|██████████| 169/169 [00:04<00:00, 35.90it/s]


Epoch 6 | Loss: 0.0010 | Val AP: 0.9924
--> Model saved (AP: 0.9924)


Epoch 7: 100%|██████████| 169/169 [00:05<00:00, 30.00it/s]


Epoch 7 | Loss: 0.0006 | Val AP: 0.9934
--> Model saved (AP: 0.9934)


Epoch 8: 100%|██████████| 169/169 [00:05<00:00, 33.16it/s]


Epoch 8 | Loss: 0.0005 | Val AP: 0.9920


Epoch 9: 100%|██████████| 169/169 [00:05<00:00, 29.91it/s]


Epoch 9 | Loss: 0.0004 | Val AP: 0.9915


Epoch 10: 100%|██████████| 169/169 [00:04<00:00, 34.29it/s]


Epoch 10 | Loss: 0.0003 | Val AP: 0.9914


Epoch 11: 100%|██████████| 169/169 [00:04<00:00, 35.10it/s]


Epoch 11 | Loss: 0.0002 | Val AP: 0.9925


Epoch 12: 100%|██████████| 169/169 [00:04<00:00, 35.53it/s]


Epoch 12 | Loss: 0.0002 | Val AP: 0.9920


Epoch 13: 100%|██████████| 169/169 [00:04<00:00, 35.28it/s]


Epoch 13 | Loss: 0.0002 | Val AP: 0.9925


Epoch 14: 100%|██████████| 169/169 [00:03<00:00, 45.15it/s]


Epoch 14 | Loss: 0.0001 | Val AP: 0.9917


Epoch 15: 100%|██████████| 169/169 [00:04<00:00, 40.57it/s]


Epoch 15 | Loss: 0.0001 | Val AP: 0.9930


Epoch 16: 100%|██████████| 169/169 [00:04<00:00, 33.93it/s]


Epoch 16 | Loss: 0.0001 | Val AP: 0.9929


Epoch 17: 100%|██████████| 169/169 [00:05<00:00, 28.17it/s]


Epoch 17 | Loss: 0.0001 | Val AP: 0.9933


Epoch 18: 100%|██████████| 169/169 [00:04<00:00, 37.07it/s]


Epoch 18 | Loss: 0.0001 | Val AP: 0.9908


Epoch 19: 100%|██████████| 169/169 [00:04<00:00, 38.62it/s]


Epoch 19 | Loss: 0.0004 | Val AP: 0.9944
--> Model saved (AP: 0.9944)


Epoch 20: 100%|██████████| 169/169 [00:04<00:00, 37.31it/s]


Epoch 20 | Loss: 0.0001 | Val AP: 0.9956
--> Model saved (AP: 0.9956)


In [7]:
import numpy as np

def evaluate_at_fixed_fpr(model, dataloader, target_fpr=1e-3):
    """Print the highest recall reachable while keeping FPR <= ``target_fpr``.

    Every batch of ``dataloader`` is scored with ``model`` (eval mode, no
    gradients). Decision thresholds are then swept from the highest score
    downwards, where a sample is predicted positive when
    ``score >= threshold``. Only thresholds that fall between *distinct*
    score values are considered, so the printed TP/FP counts are exactly what
    applying the printed threshold produces, even when scores are tied.

    Args:
        model: Module mapping a batch ``X`` to one score per sample. It is
            switched to eval mode and run on the device of its first parameter.
        dataloader: Iterable of ``(X, y)`` batches; ``y`` must be a CPU tensor
            of labels where 0 marks negatives and 1 marks positives.
        target_fpr: Maximum allowed false-positive rate, as a fraction.

    Returns:
        None. Results are printed. A short message is printed instead when
        the data is empty, lacks one of the classes, or no threshold meets
        the target.
    """
    model.eval()
    device = next(model.parameters()).device

    score_chunks = []
    target_chunks = []
    with torch.no_grad():
        for X, y in dataloader:
            # reshape(-1) tolerates models that emit (batch,) or (batch, 1).
            score_chunks.append(model(X.to(device)).cpu().numpy().reshape(-1))
            target_chunks.append(y.numpy().reshape(-1))

    if not score_chunks:
        print("Cannot evaluate: the dataloader produced no samples.")
        return

    all_scores = np.concatenate(score_chunks)
    all_targets = np.concatenate(target_chunks)

    # Stable descending sort so the result does not depend on how ties land.
    order = np.argsort(-all_scores, kind="stable")
    all_scores = all_scores[order]
    all_targets = all_targets[order]

    negatives = (all_targets == 0)
    positives = (all_targets == 1)

    n_neg = int(negatives.sum())
    n_pos = int(positives.sum())
    if n_neg == 0 or n_pos == 0:
        # FPR or recall would be 0/0 without both classes present.
        print(f"Cannot evaluate: both classes are required (positives: {n_pos}, negatives: {n_neg}).")
        return

    # A threshold cannot separate samples with identical scores, so the only
    # realisable cut points are the last index of each run of equal scores.
    cut_indices = np.r_[np.flatnonzero(np.diff(all_scores)), all_scores.size - 1]
    thresholds = all_scores[cut_indices]
    fps = np.cumsum(negatives)[cut_indices]
    tps = np.cumsum(positives)[cut_indices]

    calculated_fprs = fps / n_neg

    # FPR is non-decreasing along the sweep, so the last valid cut point is
    # the one with the highest recall that still meets the target.
    valid_indices = np.flatnonzero(calculated_fprs <= target_fpr)

    if len(valid_indices) == 0:
        print(f"Failed to achieve FPR <= {target_fpr}. The lowest possible FPR is {calculated_fprs[0]:.4f}")
        return

    cutoff_idx = valid_indices[-1]
    best_threshold = thresholds[cutoff_idx]

    recall = tps[cutoff_idx] / n_pos
    actual_fpr = calculated_fprs[cutoff_idx]

    print(f"Target: FPR < {target_fpr} ({target_fpr*100}%)\n")
    print(f"Threshold:    {best_threshold:.6f}")
    print(f"Recall (TPR): {recall*100:.2f}%")
    print(f"FPR:          {actual_fpr*100:.4f}%")
    print(f"TP / Pos:     {int(tps[cutoff_idx])} / {n_pos}")
    print(f"FP / Neg:     {int(fps[cutoff_idx])} / {n_neg}")

# Evaluate the best saved checkpoint on the test split, but only when the
# training cell has defined both the model and a non-empty test loader.
if 'model' in locals() and 'test_loader' in locals() and test_loader is not None:
    # map_location lets a checkpoint written on GPU be restored in a CPU-only
    # session (and vice versa) instead of failing while deserialising.
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
    evaluate_at_fixed_fpr(model, test_loader, target_fpr=0.001)
else:
    print("Skipping evaluation: test_loader not defined or empty.")

Target: FPR < 0.001 (0.1%)

Threshold:    0.001710
Recall (TPR): 80.11%
FPR:          0.0000%
TP / Pos:     145 / 181
FP / Neg:     0 / 874
