In [25]:
#%%
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm import trange
import random

import os

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, precision_recall_curve, roc_curve, auc

#%%
seed = 12
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

torch.backends.cudnn.deterministic=True

#%%
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using device: ", device)

using device:  cuda:0


In [None]:
#%%
Input_BASE = "tc-hard/embeddings/few-shot/"
EMEBEDS_BASE = "tc-hard/embeddings/few-shot/"
RESULTS_BASE = "tc-hard/reproduce/results/"
DATA_BASE = f"tc-hard/dataset/few_shot_split/pep+cdr3b/"

# %%
def embed_norm(embeddings):
    
    from sklearn.preprocessing import MinMaxScaler
    scaler = MinMaxScaler()
    embeddings = scaler.fit_transform(embeddings)
    
    return embeddings

# %%
def load_embeddings(path, train_df, test_df, split_id = 0):
    tcrb_seq_train = np.load(os.path.join(EMEBEDS_BASE, args.pretrain_name, args.neg_generate_mode, f"train-{split_id}.tcrb.npy"))
    # peptide_seq_train = np.load(os.path.join(EMEBEDS_BASE, args.pretrain_name, args.neg_generate_mode, f"train-{split_id}.peptide.npy"))
    label_seq_train = np.load(os.path.join(EMEBEDS_BASE, args.pretrain_name, args.neg_generate_mode, f"train-{split_id}.label.npy"))
    
    tcrb_seq_test = np.load(os.path.join(EMEBEDS_BASE, args.pretrain_name, args.neg_generate_mode, f"test-{split_id}.tcrb.npy"))
    # peptide_seq_test = np.load(os.path.join(EMEBEDS_BASE, args.pretrain_name, args.neg_generate_mode, f"test-{split_id}.peptide.npy"))
    label_seq_test = np.load(os.path.join(EMEBEDS_BASE, args.pretrain_name, args.neg_generate_mode, f"test-{split_id}.label.npy"))


    peptide_seq_train = np.load(os.path.join(EMEBEDS_BASE, args.pretrain_name, args.neg_generate_mode, f"train-{split_id}.peptide.npy"))
    peptide_seq_test = np.load(os.path.join(EMEBEDS_BASE, args.pretrain_name, args.neg_generate_mode, f"test-{split_id}.peptide.npy"))

    peptide_uniq_train = list(train_df["peptide"].unique())
    peptide_uniq_test = list(test_df["peptide"].unique())
    train_df["peptide_embed"] = train_df["peptide"].apply(lambda x: peptide_seq_train[peptide_uniq_train.index(x)])
    test_df["peptide_embed"] = test_df["peptide"].apply(lambda x: peptide_seq_test[peptide_uniq_test.index(x)])

    peptide_seq_train = train_df["peptide_embed"].values
    peptide_seq_train = np.stack(peptide_seq_train, axis = 0)
    
    peptide_seq_test = test_df["peptide_embed"].values
    peptide_seq_test = np.stack(peptide_seq_test, axis = 0)

    tcrb_seq_train = embed_norm(tcrb_seq_train)
    peptide_seq_train = embed_norm(peptide_seq_train)
    tcrb_seq_test = embed_norm(tcrb_seq_test)
    peptide_seq_test = embed_norm(peptide_seq_test)

    return tcrb_seq_train, peptide_seq_train, label_seq_train, tcrb_seq_test, peptide_seq_test, label_seq_test


In [None]:
# %%
# train data 
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, tcrb_seq, peptide_seq, label_seq) -> None:
        super().__init__()
        # self.tcralpha_seq = torch.from_numpy(np.load(EMEBEDS_BASE + "tcra.npy"))
        # self.tcrbeta_seq = torch.from_numpy(np.load(EMEBEDS_BASE + "tcrb.npy"))
        # if mode == "train":
        self.tcrbeta_seq = torch.from_numpy(tcrb_seq)
        self.peptide_seq = torch.from_numpy(peptide_seq)
        self.label_seq = torch.from_numpy(label_seq).reshape([-1, 1])

        # self.mhc_seq = torch.from_numpy(np.load(EMEBEDS_BASE + "mhc.npy"))
        # self.peptide_seq = torch.from_numpy(np.load(EMEBEDS_BASE + "peptide.npy"))
        # self.label_seq = torch.from_numpy(np.load(EMEBEDS_BASE + "labels.npy")).reshape([-1, 1])
        self.label_seq = self.label_seq.to(dtype = torch.float)
        
    def __len__(self):
        return self.label_seq.shape[0]

    def __getitem__(self, index):
        # return self.tcralpha_seq[index], self.tcrbeta_seq[index], self.mhc_seq[index], self.peptide_seq[index], self.label_seq[index]
        return self.tcrbeta_seq[index], self.peptide_seq[index], self.label_seq[index]

#%%
# import argparse

# parser = argparse.ArgumentParser()
# parser.add_argument("--split_id", type = int, default = 0)
# parser.add_argument("--epochs", type = int, default = 100)
# parser.add_argument("--save", type = bool, default = True)
# parser.add_argument("--lr", type = float, default = 0.02)
# parser.add_argument("--pretrain_name", type = str, default = "moleformer")
# parser.add_argument("--neg_generate_mode", type = str, default = "only-sampled-negs")


# args = parser.parse_args()

from dotmap import DotMap

args = DotMap(
    {
        "split_id": 0,
        "epochs": 20,
#         # "save": True,
#         # "lr": 0.02,
        "pretrain_name": "moleformer",
        "neg_generate_mode": "only-neg-assays"
    }
)

In [28]:
class MLP(nn.Module):
    def __init__(self, input_features, output_features, hidden_features, num_layers):
        super(MLP, self).__init__()
        self.layers = []
        self.layers.append(nn.Linear(input_features, hidden_features))
        self.layers.append(nn.BatchNorm1d(hidden_features))
        self.layers.append(nn.Dropout(0.5))
        self.layers.append(nn.LeakyReLU(0.1))
        for i in range(num_layers - 2):
            self.layers.append(nn.Linear(hidden_features, hidden_features))
            self.layers.append(nn.BatchNorm1d(hidden_features))
            self.layers.append(nn.Dropout(0.5))
            self.layers.append(nn.LeakyReLU(0.1))
        self.layers.append(nn.Linear(hidden_features, output_features))
        self.layers.append(nn.Sigmoid())
        self.layers = nn.ModuleList(self.layers)
    
    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # if i == 0:
            #     print(x)
        return x


In [None]:
# %%
loss_fn = nn.BCELoss()  # binary cross entropy
save = False

In [30]:
def train(model, train_loader, epochs = 100):

    loss_list = []
    acc_list = []

    model.train()

    for epoch in trange(epochs):
        loss_batch = []
        acc_batch = []
        for i, batch in enumerate(train_loader):
            # alpha, beta, mhc, peptide, label = batch
            beta, peptide, label = batch
            
            inputs = torch.cat([beta, peptide], dim = 1)
            inputs = inputs.to(device = device)
            label = label.to(device = device)

            output = model(inputs)
            loss = loss_fn(output, label)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pos_ratio = output.round().sum() / len(output)
            acc = (output.round() == label).float().mean()

            # if i % 50 == 0:
            #     print(f"epoch {epoch} batch {i} loss: {loss.item()} acc: {acc.item()} pos_ratio: {pos_ratio.item()}")
            
            loss_batch.append(loss.item())
            acc_batch.append(acc.item())
            # print(f"batch : {i} finished")
        loss_list.append(np.round(np.mean(loss_batch), 3))
        acc_list.append(np.round(np.mean(acc_batch), 3))


        # print(f"epoch {epoch} finished")

    return loss_list, acc_list, model

# %%
# use sklearn to calculate auc-roc, with y_true and y_score
def calculate_auc(y_true, y_score):
    auc_roc = roc_auc_score(y_true, y_score)
    return auc_roc


In [31]:

# %%
metrics = [
    'AUROC',
    'Accuracy',
    'Recall',
    'Precision',
    'F1 score',
    'AUPR'
]

def pr_auc(y_true, y_prob):
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    pr_auc = auc(recall, precision)
    return pr_auc

def get_scores(y_true, y_prob, y_pred):
    """
    Compute a df with all classification metrics and respective scores.
    """
    
    scores = [
        roc_auc_score(y_true, y_prob),
        accuracy_score(y_true, y_pred),
        recall_score(y_true, y_pred),
        precision_score(y_true, y_pred),
        f1_score(y_true, y_pred),
        pr_auc(y_true, y_prob)
    ]
    
    df = pd.DataFrame(data={'score': scores, 'metrics': metrics})
    return df

In [32]:

# %%
def evaluate(model, test_loader):
    loss_batch = []
    acc_batch = []
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            # alpha, beta, mhc, peptide, label = batch
            beta, peptide, label = batch
            print("length of test data: ", len(label))
            
            inputs = torch.cat([beta, peptide], dim = 1)
            inputs = inputs.to(device = device)
            label = label.to(device = device)

            output = model(inputs)
            loss = loss_fn(output, label)

            pos_ratio = output.round().sum() / len(output)
            
            acc = (output.round() == label).float().mean()

            # auc_roc = calculate_auc(label.cpu().numpy(), output.cpu().numpy())
            scores_df = get_scores(label.cpu().numpy(), output.cpu().numpy(), output.round().cpu().numpy())
            
            if i % 50 == 0:
                print(f"batch {i} loss: {loss.item()} acc: {acc.item()} pos_ratio: {pos_ratio}")
            
            loss_batch.append(np.round(np.mean(loss.item()), 3))
            acc_batch.append(np.round(np.mean(acc.item()), 3))
            # print(f"batch : {i} finished")
    return loss_batch, acc_batch, scores_df

#%%
def make_df(df_path):
    df = pd.read_csv(df_path)

    # map_keys = {
    # 'cdr3.beta': 'tcrb',
    # 'antigen.epitope': 'peptide',
    # "label": "label"
    # }
    # df = df.rename(columns={c: map_keys[c] for c in df.columns})

    df['tcrb'] = df['tcrb'].str.replace('O','X')
    df['peptide'] = df['peptide'].str.replace('O','X')

    return df

In [None]:
split_id = 0
num_few_shot = 5

# train_df_path = os.path.join(DATA_BASE, "train", args.neg_generate_mode, f"{num_few_shot}-train-{split_id}.csv")
# test_df_path = os.path.join(DATA_BASE, "test", args.neg_generate_mode, f"{num_few_shot}-test-{split_id}.csv")
DATA_BASE = f"tc-hard/dataset/new_split/pep+cdr3b/"

train_df_path = os.path.join(DATA_BASE, "train", args.neg_generate_mode, f"train-{split_id}.csv")
test_df_path = os.path.join(DATA_BASE, "test", args.neg_generate_mode, f"test-{split_id}.csv")

train_df = make_df(train_df_path)
test_df = make_df(test_df_path)

In [60]:
train_df.shape, test_df.shape

((165980, 3), (40480, 3))

In [61]:
all_df = pd.concat([train_df, test_df], axis = 0)
all_df.shape

(206460, 3)

In [58]:
all_df["peptide"].unique().shape

(836,)

In [21]:
from rdkit import Chem
from rdkit.Chem import AllChem
def amino_acid_to_smiles(sequence):
    molecule = Chem.MolFromSequence(sequence)
    smiles = Chem.MolToSmiles(molecule)
    return smiles

In [22]:
all_df["peptide_smiles"] = all_df["peptide"].apply(amino_acid_to_smiles)

In [23]:
unique_tcrb = all_df["tcrb"]
lengths = np.vectorize(len)(unique_tcrb)
mean_length = np.mean(lengths)
print(f"mean length of tcrb: {mean_length}")

unique_tcrb = all_df["peptide"]
lengths = np.vectorize(len)(unique_tcrb)
mean_length = np.mean(lengths)
print(f"mean length of peptide: {mean_length}")

unique_tcrb = all_df["peptide_smiles"]
lengths = np.vectorize(len)(unique_tcrb)
mean_length = np.mean(lengths)
print(f"mean length of peptide: {mean_length}")

mean length of tcrb: 14.079306335981398
mean length of peptide: 9.745989149389652
mean length of peptide: 185.40658302654523


In [53]:
#%%
results_df = []

load_path = os.path.join(EMEBEDS_BASE, args.pretrain_name, args.neg_generate_mode)

num_few_shot = 5

for i in trange(5):
    split_id = i

    train_df_path = os.path.join(DATA_BASE, "train", args.neg_generate_mode, f"{num_few_shot}-train-{split_id}.csv")
    test_df_path = os.path.join(DATA_BASE, "test", args.neg_generate_mode, f"{num_few_shot}-test-{split_id}.csv")

    train_df = make_df(train_df_path)
    test_df = make_df(test_df_path)

    tcrb_seq_train, peptide_seq_train, label_seq_train, tcrb_seq_test, peptide_seq_test, label_seq_test = load_embeddings(load_path, train_df, test_df, split_id)

    train_data = CustomDataset(tcrb_seq_train, peptide_seq_train, label_seq_train)
    test_data = CustomDataset(tcrb_seq_test, peptide_seq_test, label_seq_test)

    train_loader = DataLoader(train_data, batch_size = 128, shuffle = True, drop_last = True)
    test_loader = DataLoader(test_data, batch_size = 128000, shuffle = False, drop_last = False)

    embed_size_tcr = tcrb_seq_train.shape[1]
    embed_size_peptide = peptide_seq_train.shape[1]

    # input: beta + peptide
    model = MLP(input_features = embed_size_tcr + embed_size_peptide, output_features = 1, hidden_features = 100, num_layers = 4)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr = 0.02)

    loss_list, acc_list, model = train(model, train_loader, args.epochs)
    loss_test, acc_test, scores_df = evaluate(model, test_loader)

    print(f"train loss: {loss_list}, train acc: {acc_list}")
    print(f"test loss: {min(loss_test)}, test acc: {max(acc_test)}")

    scores_df["experiment"] = split_id
    scores_df.to_csv(os.path.join(RESULTS_BASE, args.pretrain_name + "_" + args.neg_generate_mode + f"_{split_id}.csv"), index = False)

    results_df.append(scores_df)

results_df = pd.concat(results_df)
results_df.to_csv(os.path.join(RESULTS_BASE, args.pretrain_name + "_" + args.neg_generate_mode + "_summary.csv"), index = False)

#%%


100%|██████████| 20/20 [01:15<00:00,  3.77s/it]
 20%|██        | 1/5 [01:17<05:09, 77.41s/it]

length of test data:  40480
batch 0 loss: 0.7718985676765442 acc: 0.6464179754257202 pos_ratio: 0.6662796139717102
train loss: [0.286, 0.253, 0.242, 0.233, 0.227, 0.219, 0.219, 0.217, 0.214, 0.213, 0.209, 0.209, 0.207, 0.207, 0.205, 0.205, 0.203, 0.201, 0.199, 0.199], train acc: [0.896, 0.914, 0.921, 0.926, 0.928, 0.93, 0.932, 0.932, 0.934, 0.935, 0.935, 0.936, 0.937, 0.938, 0.938, 0.938, 0.939, 0.939, 0.94, 0.94]
test loss: 0.772, test acc: 0.646


100%|██████████| 20/20 [01:17<00:00,  3.88s/it]
 40%|████      | 2/5 [02:38<03:58, 79.47s/it]

length of test data:  43293
batch 0 loss: 0.6834385395050049 acc: 0.6663202047348022 pos_ratio: 0.7066500186920166
train loss: [0.286, 0.256, 0.24, 0.232, 0.227, 0.225, 0.22, 0.216, 0.215, 0.213, 0.211, 0.208, 0.21, 0.207, 0.205, 0.205, 0.203, 0.204, 0.202, 0.201], train acc: [0.897, 0.913, 0.921, 0.926, 0.928, 0.929, 0.931, 0.932, 0.934, 0.935, 0.934, 0.936, 0.936, 0.937, 0.938, 0.938, 0.938, 0.938, 0.939, 0.939]
test loss: 0.683, test acc: 0.666


100%|██████████| 20/20 [01:15<00:00,  3.79s/it]
 60%|██████    | 3/5 [03:56<02:37, 78.77s/it]

length of test data:  40550
batch 0 loss: 0.9582415819168091 acc: 0.6410604119300842 pos_ratio: 0.6241676807403564
train loss: [0.293, 0.259, 0.248, 0.239, 0.233, 0.229, 0.227, 0.224, 0.22, 0.219, 0.215, 0.217, 0.214, 0.213, 0.211, 0.211, 0.209, 0.208, 0.206, 0.206], train acc: [0.893, 0.908, 0.915, 0.92, 0.922, 0.924, 0.926, 0.927, 0.928, 0.929, 0.931, 0.93, 0.932, 0.932, 0.933, 0.933, 0.934, 0.934, 0.934, 0.935]
test loss: 0.958, test acc: 0.641


100%|██████████| 20/20 [01:16<00:00,  3.84s/it]
 80%|████████  | 4/5 [05:15<01:18, 78.84s/it]

length of test data:  40446
batch 0 loss: 0.7067746520042419 acc: 0.7022697329521179 pos_ratio: 0.7477377653121948
train loss: [0.303, 0.262, 0.245, 0.239, 0.232, 0.227, 0.227, 0.223, 0.219, 0.218, 0.216, 0.214, 0.214, 0.211, 0.21, 0.209, 0.207, 0.205, 0.205, 0.205], train acc: [0.888, 0.91, 0.918, 0.921, 0.924, 0.927, 0.928, 0.929, 0.931, 0.931, 0.933, 0.933, 0.934, 0.935, 0.935, 0.936, 0.936, 0.936, 0.936, 0.937]
test loss: 0.707, test acc: 0.702


100%|██████████| 20/20 [01:15<00:00,  3.78s/it]
100%|██████████| 5/5 [06:32<00:00, 78.59s/it]

length of test data:  40411
batch 0 loss: 0.9571189284324646 acc: 0.600678026676178 pos_ratio: 0.5997129678726196
train loss: [0.301, 0.261, 0.248, 0.237, 0.232, 0.227, 0.229, 0.222, 0.22, 0.217, 0.215, 0.214, 0.213, 0.21, 0.21, 0.208, 0.209, 0.208, 0.207, 0.205], train acc: [0.892, 0.911, 0.918, 0.923, 0.926, 0.928, 0.929, 0.93, 0.931, 0.932, 0.934, 0.934, 0.934, 0.936, 0.936, 0.936, 0.937, 0.937, 0.937, 0.938]
test loss: 0.957, test acc: 0.601



