In [None]:
import os
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# matplotlib >= 3.6 renamed its bundled seaborn style to 'seaborn-v0_8' (later
# releases dropped the old 'seaborn' name), so prefer the new name when present.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

import torch
from torch.utils.data import Dataset
from torch import nn
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MultilabelF1Score
from torchmetrics.classification import MultilabelAccuracy

from transformers import BertModel, BertTokenizer

# Name of the GPU in use. `torch.cuda.device` is a class, not a device, and querying
# CUDA on a CPU-only kernel raises, so ask for device 0 only when CUDA is available.
DEVICE_NAME = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU (CUDA not available)"
print(f"Accelerator: {DEVICE_NAME}")

"""
1. Test different loss functions (ambrose, better weights)
2. Test different models (e.g. Temporal CNN, bigger linear model), keeping track of hyperparameters
    https://unit8.com/resources/temporal-convolutional-networks-and-forecasting/
3. Implement CAFA-Evaluator for better metrics
4. Use more GOs in predictions
5. Read Kaggle notebooks online to gain intuition
6. Use new data!
7. Using description of each GO for making predictions rather than considering them as labels
"""

In [None]:
MAIN_DIR = "data/"
WORK_DIR = "working/"
DATA_DIR = f"{MAIN_DIR}cafa-5-protein-function-prediction"
PROTBERT_DIR = f"{MAIN_DIR}protbert-embeddings-for-cafa5"

# Print every file found under the data root so the available inputs are visible.
for root, _subdirs, files in os.walk(MAIN_DIR):
    for name in files:
        print(os.path.join(root, name))

In [None]:
# Peek at the Kaggle sample submission to see the expected output layout.
submission = (
    pd.read_csv(f'{DATA_DIR}/sample_submission.tsv', sep='\t', header=None)
    .set_axis(["ProteinID", "GO_ID", "Probability"], axis=1)
)
submission.head(n=10)

In [None]:
# define important configurations of the code
class config:
    """Global experiment settings: input file paths, number of target GO terms,
    training hyper-parameters and the torch device to run on."""

    train_sequences_path = DATA_DIR  + "/Train/train_sequences.fasta"
    train_labels_path = DATA_DIR + "/Train/train_terms.tsv"
    test_sequences_path = DATA_DIR + "/Test (Targets)/testsuperset.fasta"

    num_labels = 500  # number of most frequent GO terms used as targets
    n_epochs = 10
    batch_size = 128
    lr = 0.002

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torch.cuda.get_device_name() raises for a CPU device, so only query it on CUDA.
    print(f'Device: {device} - {torch.cuda.get_device_name(device) if device.type == "cuda" else "CPU"}')

In [None]:
# # ______________________ GET PROT BERT EMBEDDINGS WITH HUGGING FACE __________________________________
#
# # PROT BERT LOADING :
# tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
# model = BertModel.from_pretrained("Rostlab/prot_bert").to(config.device)
#
# def get_bert_embedding(
#     sequence : str,
#     len_seq_limit : int
# ):
#     """
#     Function to collect last hidden state embedding vector from pre-trained ProtBERT Model
#
#     INPUTS:
#     - sequence (str) : protein sequence (ex : AAABBB) from fasta file
#     - len_seq_limit (int) : maximum sequence length (i.e. number of letters) for truncation
#
#     OUTPUTS:
#     - output_hidden : last hidden state embedding vector for input sequence of length 1024
#     """
#     sequence_w_spaces = ' '.join(list(sequence))
#     encoded_input = tokenizer(
#         sequence_w_spaces,
#         truncation=True,
#         max_length=len_seq_limit,
#         padding='max_length',
#         return_tensors='pt').to(config.device)
#     output = model(**encoded_input)
#     output_hidden = output['last_hidden_state'][:,0][0].detach().cpu().numpy()
#     assert len(output_hidden)==1024
#     return output_hidden
#
# ### COLLECTING FOR TRAIN SAMPLES :
# print("Loading train set ProtBERT Embeddings...")
# fasta_train = SeqIO.parse(config.train_sequences_path, "fasta")
#
# print("Total Nb of Elements : ", len(list(fasta_train)))
# fasta_train = SeqIO.parse(config.train_sequences_path, "fasta")
#
# ids_list = []
# embed_vects_list = []
# t0 = time.time()
# checkpoint = 0
#
# for item in tqdm(fasta_train):
#     ids_list.append(item.id)
#     embed_vects_list.append(
#         get_bert_embedding(sequence = item.seq, len_seq_limit = 1200))
#     checkpoint+=1
#
#     if checkpoint>=100:
#         df_res = pd.DataFrame(data={"id" : ids_list, "embed_vect" : embed_vects_list})
#         np.save('/kaggle/working/train_ids.npy',np.array(ids_list))
#         np.save('/kaggle/working/train_embeddings.npy',np.array(embed_vects_list))
#         checkpoint=0
#
# np.save('/kaggle/working/train_ids.npy',np.array(ids_list))
# np.save('/kaggle/working/train_embeddings.npy',np.array(embed_vects_list))
# print('Total Elapsed Time:',time.time()-t0)
#
# ### COLLECTING FOR TEST SAMPLES :
# print("Loading test set ProtBERT Embeddings...")
# fasta_test = SeqIO.parse(config.test_sequences_path, "fasta")
# print("Total Nb of Elements : ", len(list(fasta_test)))
# fasta_test = SeqIO.parse(config.test_sequences_path, "fasta")
# ids_list = []
# embed_vects_list = []
# t0 = time.time()
# checkpoint=0
# for item in tqdm(fasta_test):
#     ids_list.append(item.id)
#     embed_vects_list.append(
#         get_bert_embedding(sequence = item.seq, len_seq_limit = 1200))
#     checkpoint+=1
#     if checkpoint>=100:
#         np.save('/kaggle/working/test_ids.npy',np.array(ids_list))
#         np.save('/kaggle/working/test_embeddings.npy',np.array(embed_vects_list))
#         checkpoint=0
#
# np.save('/kaggle/working/test_ids.npy',np.array(ids_list))
# np.save('/kaggle/working/test_embeddings.npy',np.array(embed_vects_list))
# print('Total Elapsed Time:',time.time()-t0)

In [None]:
##### SCRIPT FOR LABELS (TARGETS) VECTORS COLLECTING #####
# Build one multi-hot target vector per training protein over the
# `config.num_labels` most frequent GO terms and cache it for ProteinSequenceDataset.

print(f"GENERATE TARGETS FOR ENTRY IDS ({config.num_labels} MOST COMMON GO TERMS)")
ids = np.load(f"{PROTBERT_DIR}/train_ids.npy")
labels = pd.read_csv(config.train_labels_path, sep = "\t")

# GO terms ranked by number of annotation rows (most frequent first); the first
# `num_labels` of them define the column order of every target vector.
top_terms = labels.groupby("term")["EntryID"].count().sort_values(ascending=False)
labels_names = top_terms[:config.num_labels].index.values
train_labels_sub = labels[(labels.term.isin(labels_names)) & (labels.EntryID.isin(ids))]
id_labels = train_labels_sub.groupby('EntryID')['term'].apply(list).to_dict()

go_terms_map = {label: i for i, label in enumerate(labels_names)}
# np.zeros, not np.empty: cells that are never set to 1 must be real zeros,
# not whatever happened to be in the freshly allocated memory.
labels_matrix = np.zeros((len(ids), len(labels_names)))

for index, entry_id in tqdm(enumerate(ids), total=len(ids)):
    # Proteins with no annotation among the top terms are absent from `id_labels`;
    # they keep an all-zero row instead of raising KeyError.
    term_columns = [go_terms_map[go] for go in id_labels.get(entry_id, [])]
    if term_columns:
        labels_matrix[index, term_columns] = 1

# One 1-D row per protein so each vector fits in a single DataFrame cell.
labels_list = list(labels_matrix)

labels_df = pd.DataFrame(data={"EntryID":ids, "labels_vect":labels_list})
os.makedirs(WORK_DIR, exist_ok=True)
labels_df.to_pickle(f"{WORK_DIR}/train_targets_top"+str(config.num_labels)+".pkl")
print("GENERATION FINISHED!")
labels_df.head(5)

In [None]:
# Tab-separated IA table: GO term identifier and its weight, no header row.
GO_weight_dataset = pd.read_csv(
    f'{DATA_DIR}/IA.txt',
    sep='\t',
    header=None,
    names=['GO', 'weight'],
)
GO_weight_dataset

In [None]:
# load GO_weights (IA data) as a tensor to feed into the loss function.
# Index the IA table once and look every selected GO term up by key, instead of
# scanning the whole table once per term. If a term is listed more than once the
# first row wins, like the original per-term scan.
go_weight_lookup = (
    GO_weight_dataset
    .drop_duplicates(subset="GO", keep="first")
    .set_index("GO")["weight"]
)
missing_terms = [go for go in labels_names if go not in go_weight_lookup.index]
if missing_terms:
    raise KeyError(
        f"No IA weight found for {len(missing_terms)} GO term(s), e.g. {missing_terms[:5]}")

# Same order as `labels_names`, i.e. aligned with the target-vector columns.
GO_weights = torch.tensor(go_weight_lookup.loc[labels_names].to_numpy(), dtype=torch.float32)
GO_weights

In [None]:
# Sub-directory (under MAIN_DIR) holding the .npy dumps of each embedding source.
embeds_map = dict(
    T5="t5embeds",
    ProtBERT="protbert-embeddings-for-cafa5",
    ESM2="cafa-5-esm-2-embeddings-numpy",
)

# Size of a single embedding vector for each source.
embeds_dim = dict(T5=1024, ProtBERT=1024, ESM2=1280)

In [None]:
class ProteinSequenceDataset(Dataset):
    """
    Custom dataset to store embeddings of different sources.
    It could be used to get training or test dataset.

    :param datatype: "train" -> items are ``(embedding, targets)`` float32 tensors;
        "test" -> items are ``(embedding, EntryID)``
    :param embeddings_source: one of "ProtBERT", "ESM2" or "T5"
    :raises ValueError: if ``embeddings_source`` is not a supported source
    """

    # File-name suffix of the embedding array for each source: the T5 dump names it
    # "<datatype>_embeds.npy", the other sources "<datatype>_embeddings.npy".
    _EMBED_FILE_SUFFIX = {
        "ProtBERT": "_embeddings.npy",
        "ESM2": "_embeddings.npy",
        "T5": "_embeds.npy",
    }

    def __init__(self, datatype, embeddings_source):
        # `super(ProteinSequenceDataset).__init__()` would initialise an unbound
        # `super` object and never reach Dataset.__init__.
        super().__init__()
        self.datatype = datatype

        if embeddings_source not in self._EMBED_FILE_SUFFIX:
            raise ValueError(
                f"Unknown embeddings_source {embeddings_source!r}; "
                f"expected one of {sorted(self._EMBED_FILE_SUFFIX)}")

        source_dir = os.path.join(MAIN_DIR, embeds_map[embeddings_source])
        embeds = np.load(os.path.join(source_dir, datatype + self._EMBED_FILE_SUFFIX[embeddings_source]))
        ids = np.load(os.path.join(source_dir, datatype + "_ids.npy"))

        # One row per protein; iterating the 2-D embedding array yields its rows.
        self.df = pd.DataFrame(data={"EntryID": ids, "embed": list(embeds)})

        if datatype == "train":
            # Targets cached by the label-generation cell. merge() defaults to an inner
            # join, so only proteins with both an embedding and a target are kept.
            # NOTE: read_pickle can execute code - only load files this notebook wrote.
            df_labels = pd.read_pickle(
                os.path.join(WORK_DIR, f"train_targets_top{config.num_labels}.pkl"))
            self.df = self.df.merge(df_labels, on="EntryID")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Column-first access avoids materialising a whole mixed-dtype row per item.
        embed = torch.tensor(self.df["embed"].iloc[index], dtype=torch.float32)

        if self.datatype == "train":
            targets = torch.tensor(self.df["labels_vect"].iloc[index], dtype=torch.float32)
            return embed, targets

        if self.datatype == "test":
            return embed, self.df["EntryID"].iloc[index]


# Build the T5-embedding training set and preview its first rows.
dataset = ProteinSequenceDataset("train", "T5")
dataset.df.head(n=10)

In [None]:
# Inspect the first training example. Distinct names keep the global `labels`
# (the GO annotation DataFrame) from being silently rebound to a tensor.
first_embedding, first_targets = dataset[0]
print("COMPONENTS FOR FIRST PROTEIN:  ")
print("EMBEDDINGS VECTOR: \n ", first_embedding, "\n")
print("TARGETS LABELS VECTOR: \n ", first_targets, "\n")

In [None]:
class MultiLayerPerceptron(nn.Module):
    """
    Baseline MLP model to make predictions using CLS token embeddings.

    Architecture: Linear(input_dim, hidden_dim1) -> ReLU -> Linear(hidden_dim1, hidden_dim2)
    -> ReLU -> Linear(hidden_dim2, num_classes). ``forward`` returns raw logits (no sigmoid).

    :param input_dim: size of the input embedding vector
    :param num_classes: number of output labels
    :param hidden_dim1: width of the first hidden layer (default 1000)
    :param hidden_dim2: width of the second hidden layer (default 800)
    """

    def __init__(self, input_dim, num_classes, hidden_dim1=1000, hidden_dim2=800):
        super(MultiLayerPerceptron, self).__init__()

        # Exactly one first layer: building a throw-away Linear(input_dim, input_dim)
        # under the same attribute name only wasted initialisation and RNG draws.
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim1)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_dim1, hidden_dim2)
        self.activation2 = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(hidden_dim2, num_classes)

    def forward(self, x):
        x = self.activation1(self.linear1(x))
        x = self.activation2(self.linear2(x))
        return self.linear3(x)

In [None]:
class CNN1D(nn.Module):
    """
    Baseline CNN-1D model to make predictions using CLS token embeddings.

    The embedding vector is treated as a 1-channel signal of length ``input_dim``:
    two [Conv1d(k=3, padding=1) -> ReLU -> MaxPool1d(2)] stages, then two linear
    layers. ``forward`` returns raw logits.

    :param input_dim: length of the input embedding vector (need not be divisible by 4)
    :param num_classes: number of output labels
    """

    def __init__(self, input_dim, num_classes):
        super(CNN1D, self).__init__()
        # (batch_size, 1, input_dim)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 3, input_dim)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 3, input_dim // 2)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 8, input_dim // 2)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 8, input_dim // 4)
        # MaxPool1d floors odd lengths, so the flattened size must use integer division;
        # int(8 * input_dim / 4) overshoots whenever input_dim % 4 != 0.
        pooled_len = input_dim // 4
        self.fc1 = nn.Linear(in_features=8 * pooled_len, out_features=1024)       # 1024 is better
        self.fc2 = nn.Linear(in_features=1024, out_features=num_classes)          # 1024 is better

    def forward(self, x):
        # (batch_size, input_dim) -> (batch_size, 1, input_dim)
        x = x.unsqueeze(1)
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
def _run_phase(model, dataloader, loss_fn, metric, optimizer=None, show_progress=False):
    """Run one pass over ``dataloader``; return ``(mean batch loss, F1 over the whole pass)``.

    With an ``optimizer`` the model is trained (backward + optimizer step per batch);
    without one it is evaluated with gradients disabled, so no autograd graph is kept.
    """
    training = optimizer is not None
    model.train(training)
    metric.reset()  # metric state accumulates across update() calls
    losses = []
    batches = tqdm(dataloader) if show_progress else dataloader

    with torch.set_grad_enabled(training):
        for embed, targets in batches:
            embed, targets = embed.to(config.device), targets.to(config.device)
            preds = model(embed)

            loss = loss_fn(preds, targets)
            losses.append(loss.item())
            metric.update(preds.detach(), targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    # compute() gives the F1 over every sample of the pass, not a mean of per-batch F1s.
    return np.mean(losses), metric.compute().item()


def train_model(embeddings_source, model_type, train_size=0.9, seed=None):
    """
    Custom function to train the baseline model on dataset
    :param embeddings_source: define the type of embedding ("T5", "ProtBERT" or "ESM2")
    :param model_type: define the type of model ("linear" or "conv")
    :param train_size: define the training portion ratio
    :param seed: optional seed for the train/validation split (None keeps it random)
    :return: (model, losses_history, scores_history), the histories being
        {"train": [...], "val": [...]} lists with one value per epoch
    :raises ValueError: for an unknown ``model_type``
    """
    model_classes = {"linear": MultiLayerPerceptron, "conv": CNN1D}
    if model_type not in model_classes:
        raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(model_classes)}")

    train_dataset = ProteinSequenceDataset(datatype="train", embeddings_source=embeddings_source)
    n_train = int(len(train_dataset) * train_size)
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_set, val_set = random_split(
        train_dataset, lengths=[n_train, len(train_dataset) - n_train], **split_kwargs)

    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    # Evaluation order is irrelevant, so validation batches are not shuffled.
    val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=config.batch_size, shuffle=False)

    model = model_classes[model_type](input_dim=embeds_dim[embeddings_source], num_classes=config.num_labels)
    model.to(config.device)

    # define configurations of the model
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    # Cut the learning rate by 10x when the validation loss plateaus (patience=1).
    scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=1)

    # Multilabel prediction task. GO_weights could be passed as `weight=` here after
    # moving them with `GO_weights = GO_weights.to(config.device)`.
    MultiLabelLoss = torch.nn.BCEWithLogitsLoss()
    f1_score = MultilabelF1Score(num_labels=config.num_labels).to(config.device)
    n_epochs = config.n_epochs

    print("BEGIN TRAINING...")
    train_loss_history, val_loss_history = [], []
    train_f1score_history, val_f1score_history = [], []

    for epoch in range(n_epochs):
        print("EPOCH ", epoch+1)

        ## TRAIN PHASE :
        avg_loss, avg_score = _run_phase(
            model, train_dataloader, MultiLabelLoss, f1_score, optimizer=optimizer, show_progress=True)
        print("Running Average TRAIN Loss : ", avg_loss)
        print("Epoch TRAIN F1-Score : ", avg_score)
        train_loss_history.append(avg_loss)
        train_f1score_history.append(avg_score)

        ## VALIDATION PHASE :
        avg_loss, avg_score = _run_phase(model, val_dataloader, MultiLabelLoss, f1_score)
        print("Running Average VAL Loss : ", avg_loss)
        print("Epoch VAL F1-Score : ", avg_score)
        val_loss_history.append(avg_loss)
        val_f1score_history.append(avg_score)

        scheduler.step(avg_loss)
        print("\n")

    print("TRAINING FINISHED _____")
    if train_f1score_history:
        print("FINAL TRAINING SCORE : ", train_f1score_history[-1])
        print("FINAL VALIDATION SCORE : ", val_f1score_history[-1])

    losses_history = {"train" : train_loss_history, "val" : val_loss_history}
    scores_history = {"train" : train_f1score_history, "val" : val_f1score_history}

    return model, losses_history, scores_history

In [None]:
# Train the baseline MLP on the T5 embeddings.
t5_model, t5_losses, t5_scores = train_model("T5", "linear")

In [None]:
# Sanity check: logits of the trained T5 model for the first training protein.
# Inference only, so no autograd graph is recorded.
with torch.no_grad():
    t5_first_logits = t5_model(dataset[0][0].unsqueeze(0).to(config.device))
t5_first_logits

In [None]:
def predict(embeddings_source, model=None, batch_size=None):
    """
    Custom function to make inference using the model
    :param embeddings_source: define the type of embedding ("T5", "ProtBERT" or "ESM2")
    :param model: trained model to use; by default the notebook-level
        ``t5_model`` / ``protbert_model`` / ``esm2_model`` matching ``embeddings_source``
    :param batch_size: inference batch size (defaults to ``config.batch_size``); it
        should only affect speed, not the predictions (beyond floating-point noise)
    :return: DataFrame with columns "Id", "GO term", "Confidence": one row per
        (test protein, GO term) pair, grouped by protein in test-set order
    :raises ValueError: for an unknown ``embeddings_source`` when no model is given
    """
    if model is None:
        if embeddings_source == "T5":
            model = t5_model
        elif embeddings_source == "ProtBERT":
            model = protbert_model
        elif embeddings_source == "ESM2":
            model = esm2_model
        else:
            raise ValueError(f"Unknown embeddings_source {embeddings_source!r}")
    if batch_size is None:
        batch_size = config.batch_size

    test_dataset = ProteinSequenceDataset(datatype="test", embeddings_source=embeddings_source)
    # shuffle=False keeps output rows in test-set order.
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Set model on evaluation mode
    model.eval()

    # Recompute the GO-term order exactly as the target-generation cell does, so that
    # model output column j corresponds to labels_names[j].
    labels = pd.read_csv(config.train_labels_path, sep = "\t")
    top_terms = labels.groupby("term")["EntryID"].count().sort_values(ascending=False)
    labels_names = top_terms[:config.num_labels].index.values
    print("GENERATE PREDICTION FOR TEST SET...")

    n_rows = len(test_dataset) * config.num_labels
    ids_ = np.empty(shape=(n_rows,), dtype=object)
    go_terms_ = np.empty(shape=(n_rows,), dtype=object)
    confs_ = np.empty(shape=(n_rows,), dtype=np.float32)

    start = 0
    with torch.no_grad():
        for embed, batch_ids in tqdm(test_dataloader):
            probs = torch.sigmoid(model(embed.to(config.device))).cpu().numpy()
            n_proteins = probs.shape[0]
            stop = start + n_proteins * config.num_labels
            # Row-major flattening: all GO terms of the batch's 1st protein, then the 2nd, ...
            confs_[start:stop] = probs.reshape(-1)
            ids_[start:stop] = np.repeat(np.asarray(batch_ids, dtype=object), config.num_labels)
            go_terms_[start:stop] = np.tile(labels_names, n_proteins)
            start = stop

    submission_df = pd.DataFrame(data={"Id" : ids_, "GO term" : go_terms_, "Confidence" : confs_})
    print("PREDICTIONS DONE")
    return submission_df

In [None]:
submission_df = predict("T5")
submission_df.head(50)

In [None]:
class Linear_Lightning(pl.LightningModule):
    """
    In progress, used to train the MLP model on multiple GPUs using PyTorchLightning.

    NOTE(review): `pl` is never imported in this notebook; the imports cell needs
    `import pytorch_lightning as pl` (or `import lightning.pytorch as pl`) for this to run.

    :param input_dim: embedding size fed to the MLP
    :param num_classes: number of output labels
    :param train_size: fraction of the training set used for training (rest is validation)
    :param embeddings_source: which precomputed embeddings to load (default "T5")
    :param hparams: optional ``batch_size`` and ``lr`` overrides (defaults come from ``config``)
    """

    def __init__(self, input_dim, num_classes, train_size, embeddings_source="T5", **hparams):
        super(Linear_Lightning, self).__init__()

        # Lightning moves the module to the right device itself, so no `.to(...)` here.
        self.model = MultiLayerPerceptron(input_dim=input_dim, num_classes=num_classes)

        train_dataset = ProteinSequenceDataset(datatype="train", embeddings_source=embeddings_source)
        n_train = int(len(train_dataset) * train_size)
        self.train_set, self.val_set = random_split(
            train_dataset, lengths=[n_train, len(train_dataset) - n_train])

        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.batch_size = hparams.get("batch_size", config.batch_size)
        self.lr = hparams.get("lr", config.lr)

        # Separate metric objects per stage so training and validation state never mix.
        self.f1_score = MultilabelF1Score(num_labels=num_classes)
        self.accuracy = MultilabelAccuracy(num_labels=num_classes)
        self.val_f1_score = MultilabelF1Score(num_labels=num_classes)
        self.val_accuracy = MultilabelAccuracy(num_labels=num_classes)


    def forward(self, x):
        return self.model(x)


    def training_step(self, batch, batch_idx):
        embed, targets = batch
        preds = self(embed)
        loss = self.loss_fn(preds, targets)
        f1_score = self.f1_score(preds, targets)
        acc_score = self.accuracy(preds, targets)

        logs = {"train_loss" : loss, "f1_score" : f1_score, "accuracy_score" : acc_score}
        self.log_dict(
            logs,
            on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss


    def validation_step(self, batch, batch_idx):
        embed, targets = batch
        preds = self(embed)
        loss = self.loss_fn(preds, targets)
        f1_score = self.val_f1_score(preds, targets)
        acc_score = self.val_accuracy(preds, targets)

        # Log here: current Lightning versions do not call `validation_end`, so
        # validation metrics would otherwise never be recorded.
        self.log_dict(
            {"val_loss": loss, "val_f1_score": f1_score, "val_accuracy_score": acc_score},
            on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return {"val_loss": loss, "f1_score": f1_score, "accuracy_score": acc_score}


    def validation_end(self, outputs):
        """Legacy hook: average the per-batch validation losses in ``outputs``."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        logs = {"val_loss" : avg_loss}
        return {"avg_val_loss": avg_loss, "log": logs}


    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)


    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


    def train_dataloader(self):
        # Shuffle training batches each epoch, matching the plain-PyTorch training loop.
        return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

In [None]:
### IN PROGRESS - SCRIPT TO TRAIN THE MODEL USING PyTorchLightning ###
# NOTE(review): `Trainer` is not imported anywhere in this notebook; add
# `from pytorch_lightning import Trainer` to the imports cell before running.
embeddings_source = "T5"  # embedding set to train on; must be defined before use below

# No `logger=` argument: no logger object exists in this notebook, so Lightning's
# default logger is used.
trainer = Trainer(
    max_epochs=config.n_epochs,
    limit_train_batches=5000)

model = Linear_Lightning(
    input_dim=embeds_dim[embeddings_source],
    num_classes=config.num_labels,
    train_size=0.8,
    embeddings_source=embeddings_source,
)

trainer.fit(model)