1. load best model for moleformer, smiles-bert, chemberta;
2. load [5, 10, 15, 20] dataset;
3. run experiments on different setting, and obtain the auc-roc value for each setting;
4. draw the figure of the auc-roc.

In [21]:
import numpy as np
import pandas as pd 

import torch 
import torch.nn as nn

from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import auc, precision_score, recall_score, f1_score, roc_auc_score, accuracy_score, precision_recall_curve

import os

In [None]:

best_model_path = f"vibtcr/data/result/NAbest"
base_path = 'tc-hard/dataset/few_shot_split/pep+cdr3b'
embed_base_path = 'tc-hard/embeddings/few-shot/moleformer' 

DATA_PATH = "tc-hard/dataset/few_shot_split/pep+cdr3b"

DICT_PATH = "tc-hard/meta_data"

negative_generate_mode = "only-neg-assays"

model_name = "moleformer"

In [23]:
#%%
def make_df(df_path):
    df = pd.read_csv(df_path)

    # map_keys = {
    # 'cdr3.beta': 'tcrb',
    # 'antigen.epitope': 'peptide',
    # "label": "label"
    # }
    # df = df.rename(columns={c: map_keys[c] for c in df.columns})

    df['tcrb'] = df['tcrb'].str.replace('O','X')
    df['peptide'] = df['peptide'].str.replace('O','X')

    return df

In [24]:
import pickle

with open(os.path.join(DICT_PATH, model_name, negative_generate_mode, "peptide_dict.pkl"), 'rb') as f:
    peptide_embed_dict = pickle.load(f)

with open(os.path.join(DICT_PATH, model_name, negative_generate_mode, "tcrb_dict.pkl"), 'rb') as f:
    tcrb_embed_dict = pickle.load(f)


In [25]:
def get_embeddings(train_df, validation_df, test_df):
    tcrb_seq_train = np.vstack(train_df['tcrb'].apply(lambda x: tcrb_embed_dict[x]).values)
    tcrb_seq_validation = np.vstack(validation_df['tcrb'].apply(lambda x: tcrb_embed_dict[x]).values)
    tcrb_seq_test = np.vstack(test_df['tcrb'].apply(lambda x: tcrb_embed_dict[x]).values)

    peptide_seq_train = np.vstack(train_df['peptide'].apply(lambda x: peptide_embed_dict[x]).values)
    peptide_seq_validation = np.vstack(validation_df['peptide'].apply(lambda x: peptide_embed_dict[x]).values)
    peptide_seq_test = np.vstack(test_df['peptide'].apply(lambda x: peptide_embed_dict[x]).values)

    label_seq_train = train_df['label'].values
    label_seq_validation = validation_df['label'].values
    label_seq_test = test_df['label'].values

    # X_train = np.column_stack((tcrb_seq_train, peptide_seq_train))
    X_train = np.column_stack((peptide_seq_train, tcrb_seq_train))
    y_train = label_seq_train

    # X_validation = np.column_stack((tcrb_seq_validation, peptide_seq_validation))
    X_validation = np.column_stack((peptide_seq_validation, tcrb_seq_validation))
    y_validation = label_seq_validation

    # X_test = np.column_stack((tcrb_seq_test, peptide_seq_test))
    X_test = np.column_stack((peptide_seq_test, tcrb_seq_test))
    y_test = label_seq_test

    return X_train, y_train, X_validation, y_validation, X_test, y_test

In [26]:
# load model
class MLP(nn.Module):
    def __init__(self, input_size, output_size, hidden_sizes=[512, 512, 512, 256, 256, 256], dropout=0.2):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes
        self.dropout = dropout
        
        layers = []
        layers.append(nn.Linear(input_size, hidden_sizes[0]))
        layers.append(nn.BatchNorm1d(hidden_sizes[0]))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(dropout))
        
        for i in range(len(hidden_sizes) - 1):
            layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
            layers.append(nn.BatchNorm1d(hidden_sizes[i+1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        
        layers.append(nn.Linear(hidden_sizes[-1], output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [27]:
# model = MLP(input_size = 748 + 500, output_size = 2, hidden_sizes = [32], dropout = 0.3)
# model.load_state_dict(torch.load(os.path.join(model_path, "best_mol-esm_0.pth")))

In [28]:
# train_df_path = os.path.join(DATA_PATH, "train", negative_generate_mode, f"train-{dataset_index}.csv")
# validation_df_path = os.path.join(DATA_PATH, "validation", negative_generate_mode, f"validation-{dataset_index}.csv")
# test_df_path = os.path.join(DATA_PATH, "test", negative_generate_mode, f"test-{dataset_index}.csv")



In [29]:
def evaluate_and_save(model, dataset_index, best_model_path):
    model = MLP(input_size = 480 + 768, output_size = 2, hidden_sizes = [32], dropout = 0.3)
    model.load_state_dict(torch.load(os.path.join(best_model_path, f"best_mol-esm_{dataset_index}.pth")))
    model.to("cuda:0")
    model.eval()
    # with torch.no_grad():
    #     test_probabilities = []
        # y_true_test = []
        # test_running_loss = 0.0
        # for test_inputs, test_targets in test_loader:
        #     test_inputs, test_targets = test_inputs.to(device), test_targets.to(device)
        #     test_outputs = model(test_inputs.float())
        #     test_loss = criterion(test_outputs, test_targets)
        #     test_running_loss += test_loss.item() * test_inputs.size(0)
        #     test_probabilities.extend(torch.softmax(test_outputs, dim=1)[:, 1].cpu().numpy())
            # y_true_test.extend(test_targets.cpu().numpy())

    y_pred = model(torch.from_numpy(X_test).to("cuda:0"))
    test_probabilities = torch.softmax(y_pred, dim=1)[:, 1].detach().cpu().numpy()
    y_true_test = y_test
    # test_loss = test_running_loss / len(test_loader.dataset)    
    test_auc = roc_auc_score(y_true_test, test_probabilities)    
    test_predictions = [1 if prob > 0.5 else 0 for prob in test_probabilities]
    precision, recall, _ = precision_recall_curve(y_true_test, test_probabilities)
    
    metrics = {
            'AUROC': test_auc,
            'Accuracy': accuracy_score(y_true_test, test_predictions),
            'Recall': recall_score(y_true_test, test_predictions),
            'Precision': precision_score(y_true_test, test_predictions),
            'F1 score': f1_score(y_true_test, test_predictions),
            'AUPR': auc(recall, precision),
        }

    result_df = pd.DataFrame({
            'score': list(metrics.values()),
            'metrics': list(metrics.keys()),
            'experiment': dataset_index
        })
    
    print(f"\nBest Model Performance and Evaluation of dataset{dataset_index}:")
    for metric, score in metrics.items():
        print(f"{metric}: {score*100:.4f}%")

    return result_df

In [30]:
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau

criterion = nn.CrossEntropyLoss()
batch_size = 32

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [31]:
def train_model(model, train_loader, val_loader, num_epochs, criterion, optimizer, scheduler, device, dataset_index, output_model_path):
    patience = 20
    counter = 0
    best_val_roc_auc = 0.0
    best_model_path = output_model_path + f"/best_mol-esm_{dataset_index}.pth"
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs.float())
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)  
        model.eval()
        val_running_loss = 0.0
        y_true_val = []
        val_probabilities = [] 
        with torch.no_grad():
            for val_inputs, val_targets in val_loader:
                val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
                val_outputs = model(val_inputs.float())
                val_loss = criterion(val_outputs, val_targets)
                val_running_loss += val_loss.item() * val_inputs.size(0)       
                val_probabilities.extend(torch.softmax(val_outputs, dim=1)[:, 1].cpu().numpy())
                y_true_val.extend(val_targets.cpu().numpy())  
            val_loss = val_running_loss / len(val_loader.dataset)
            val_auc = roc_auc_score(y_true_val, val_probabilities)
            precision, recall, _ = precision_recall_curve(y_true_val, val_probabilities)
            pr_auc = auc(recall, precision)
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}, Validation AUC: {val_auc:.4f}, PR-AUC: {pr_auc:.4f}')

        if val_auc > best_val_roc_auc:
            best_val_roc_auc = val_auc
            torch.save(model.state_dict(), best_model_path)
            print("Saved best model")
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f'Early stopping at epoch {epoch}.')
                break        
        scheduler.step(val_auc)       
    print("Training complete. Best Val ROC-AUC: {:.4f}".format(best_val_roc_auc))  
        
    return best_model_path

In [None]:
for num_few_shot in [20, 100, 150, 200, 300]:
    output_model_path = f"analysis/few-shot/{num_few_shot}"

    for dataset_index in range(5):
        if model_name == "moleformer":
            model = MLP(input_size = 480 + 768, output_size = 2, hidden_sizes = [32], dropout = 0.3)
            model.load_state_dict(torch.load(os.path.join(best_model_path, f"best_mol-esm_{dataset_index}.pth")))
        elif model_name == "smiles-bert":
            model = MLP(input_size = 480 + 768, output_size = 2, hidden_sizes = [32], dropout = 0.3)
            model.load_state_dict(torch.load("tc-hard/embeddings/SMILES_BERT/only-neg-assays"))
        elif model_name == "ChemBERTa":
            model = MLP(input_size = 480 + 384, output_size = 2, hidden_sizes = [32], dropout = 0.3)
            model.load_state_dict(torch.load(f"best_mol-esm_{dataset_index}.pth"))
        model = model.to("cuda:0")


        train_df_path = os.path.join(DATA_PATH, "train", negative_generate_mode, f"{num_few_shot}-train-{dataset_index}.csv")
        validation_df_path = os.path.join(DATA_PATH, "validation", negative_generate_mode, f"{num_few_shot}-validation-{dataset_index}.csv")
        test_df_path = os.path.join(DATA_PATH, "test", negative_generate_mode, f"{num_few_shot}-test-{dataset_index}.csv")

        train_df = make_df(train_df_path)
        validation_df = make_df(validation_df_path)
        test_df = make_df(test_df_path)

        X_train, y_train, X_validation, y_validation, X_test, y_test = get_embeddings(train_df, validation_df, test_df)
            
        train_dataset = TensorDataset(torch.from_numpy(X_train), torch.tensor(y_train))
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
        val_dataset = TensorDataset(torch.from_numpy(X_validation), torch.tensor(y_validation))
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        test_dataset = TensorDataset(torch.from_numpy(X_test), torch.tensor(y_test))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)    
        
        optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.001)  
        scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.1, verbose=True)
        
        train_model(model, train_loader, val_loader, 10, criterion, optimizer, scheduler, device, dataset_index, output_model_path)

Epoch [1/10], Validation Loss: 0.7311, Validation AUC: 0.9595, PR-AUC: 0.9739
Saved best model
Epoch [2/10], Validation Loss: 0.8589, Validation AUC: 0.9717, PR-AUC: 0.9852
Saved best model
Epoch [3/10], Validation Loss: 0.8440, Validation AUC: 0.9696, PR-AUC: 0.9820
Epoch [4/10], Validation Loss: 1.1644, Validation AUC: 0.9595, PR-AUC: 0.9791
Epoch [5/10], Validation Loss: 0.9477, Validation AUC: 0.9531, PR-AUC: 0.9752
Epoch [6/10], Validation Loss: 1.0631, Validation AUC: 0.9642, PR-AUC: 0.9820
Epoch [7/10], Validation Loss: 1.5956, Validation AUC: 0.9586, PR-AUC: 0.9809
Epoch [8/10], Validation Loss: 1.3291, Validation AUC: 0.9605, PR-AUC: 0.9808
Epoch [9/10], Validation Loss: 0.9895, Validation AUC: 0.9578, PR-AUC: 0.9781
Epoch [10/10], Validation Loss: 1.4836, Validation AUC: 0.9488, PR-AUC: 0.9760
Training complete. Best Val ROC-AUC: 0.9717
Epoch [1/10], Validation Loss: 1.4017, Validation AUC: 0.9511, PR-AUC: 0.9805
Saved best model
Epoch [2/10], Validation Loss: 1.4619, Validat

In [None]:
result_df = pd.DataFrame({
    "score": [],
    "metrics": [],
    "experiment": [] 
})
l = []
for num_few_shot in [20, 100, 150, 200, 300]:
    output_model_path = f"analysis/few-shot/{num_few_shot}"
    for dataset_index in range(5):
        l.append(evaluate_and_save(model, dataset_index, output_model_path))

    result_df = pd.concat(l)
    # print(result_df)
    result_df.to_csv(f"few-shot-result/{num_few_shot}_result.csv")
    
    stats = result_df.groupby('metrics')['score'].agg(['mean', 'std']).reset_index()
    stats.to_csv(f"few-shot-result/{num_few_shot}_stats.csv")


Best Model Performance and Evaluation of dataset0:
AUROC: 98.2542%
Accuracy: 99.8827%
Recall: 99.9923%
Precision: 99.8903%
F1 score: 99.9413%
AUPR: 99.9981%

Best Model Performance and Evaluation of dataset1:
AUROC: 98.1221%
Accuracy: 99.8903%
Recall: 100.0000%
Precision: 99.8903%
F1 score: 99.9451%
AUPR: 99.9976%

Best Model Performance and Evaluation of dataset2:
AUROC: 90.7713%
Accuracy: 99.8852%
Recall: 99.9949%
Precision: 99.8903%
F1 score: 99.9426%
AUPR: 99.9893%

Best Model Performance and Evaluation of dataset3:
AUROC: 98.1205%
Accuracy: 99.8903%
Recall: 100.0000%
Precision: 99.8903%
F1 score: 99.9451%
AUPR: 99.9979%

Best Model Performance and Evaluation of dataset4:
AUROC: 93.0345%
Accuracy: 99.8291%
Recall: 99.9387%
Precision: 99.8903%
F1 score: 99.9145%
AUPR: 99.9920%

Best Model Performance and Evaluation of dataset0:
AUROC: 96.4357%
Accuracy: 99.8903%
Recall: 100.0000%
Precision: 99.8903%
F1 score: 99.9451%
AUPR: 99.9960%

Best Model Performance and Evaluation of dataset

In [77]:
# Group data by 'metrics' and calculate the mean and standard deviation for 'score'


Unnamed: 0,metrics,mean,std
0,AUPR,0.999916,8.294838e-05
1,AUROC,0.929772,0.06655524
2,Accuracy,0.99888,5.651287e-05
3,F1 score,0.99944,2.828822e-05
4,Precision,0.998934,6.022991e-08
5,Recall,0.999945,5.657316e-05


0: 0.901

20% seen

1: 0.924796

2: 0.948553

3: 0.947340

4: 0.944642

5: 0.907792 / 0.931580

10: 0.907764 / 0.933113

15: 0.907747 / 0.929772

20: 0.907721 / 0.956605

100: 0.907400 / 0.945554

150: 0.907151 / 0.952375

200: 0.906876 / 0.952898

300: 0.904357 / 0.956077