In [1]:
import os
import time
from tqdm import tqdm

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn.functional import sigmoid
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR
from torchmetrics.classification import MultilabelF1Score
from torchmetrics.classification import MultilabelAccuracy

from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, balanced_accuracy_score
from scipy.sparse import csr_matrix

from transformers import BertModel, BertTokenizer

plt.style.use('seaborn')
torch.cuda.get_device_name(torch.cuda.device)

"""
1. Test different loss functions (ambrose, better weights)
2. Test different models (like Temporal CNN, bigger linear model (keeping track of hyperparameters)
    https://unit8.com/resources/temporal-convolutional-networks-and-forecasting/
3. Implement CAFA-Evaluator for better metrics
4. Use more GOs in predictions
5. Read Kaggle notebooks online to gain intuition
6. Use new data!
7. Using description of each GO for making predictions rather than considering them as labels
8. Implement winning Kaggle models
9. Debug current code
10. Use taxonomy (one-hot encoded, embedded)

*** Add more information from the Kaggle + Article stuff to the powerpoint
*** Create a schema of what model we want to create
"""

# esm2_t33_650M_UR50D

  plt.style.use('seaborn')


'\n1. Test different loss functions (ambrose, better weights)\n2. Test different models (like Temporal CNN, bigger linear model (keeping track of hyperparameters)\n    https://unit8.com/resources/temporal-convolutional-networks-and-forecasting/\n3. Implement CAFA-Evaluator for better metrics\n4. Use more GOs in predictions\n5. Read Kaggle notebooks online to gain intuition\n6. Use new data!\n7. Using description of each GO for making predictions rather than considering them as labels\n8. Implement winning Kaggle models\n9. Debug current code\n10. Use taxonomy (one-hot encoded, embedded)\n\n*** Add more information from the Kaggle + Article stuff to the powerpoint\n*** Create a schema of what model we want to create\n'

In [2]:
MAIN_DIR = "data"
WORK_DIR = "working"
DATA_DIR = MAIN_DIR + "/cafa-5-protein-function-prediction"
PROTBERT_DIR = MAIN_DIR + "/protbert-embeddings-for-cafa5"
ESM2_DIR = MAIN_DIR + "/cafa-5-esm-2-embeddings-numpy"

for dirname, _, filenames in os.walk(MAIN_DIR):
    for filename in filenames:
        print(os.path.join(dirname, filename))

data\cafa-5-esm-2-embeddings-numpy\go_weights_10000.pt
data\cafa-5-esm-2-embeddings-numpy\test_embeddings.npy
data\cafa-5-esm-2-embeddings-numpy\test_ids.npy
data\cafa-5-esm-2-embeddings-numpy\train_embeddings.npy
data\cafa-5-esm-2-embeddings-numpy\train_ids.npy
data\cafa-5-esm-2-embeddings-numpy\train_targets_top10000.pkl
data\cafa-5-esm-2-embeddings-numpy\train_targets_top5000.pkl
data\cafa-5-protein-function-prediction\IA.txt
data\cafa-5-protein-function-prediction\sample_submission.tsv
data\cafa-5-protein-function-prediction\Test (Targets)\testsuperset-taxon-list.tsv
data\cafa-5-protein-function-prediction\Test (Targets)\testsuperset.fasta
data\cafa-5-protein-function-prediction\Train\go-basic.obo
data\cafa-5-protein-function-prediction\Train\train_sequences.fasta
data\cafa-5-protein-function-prediction\Train\train_taxonomy.tsv
data\cafa-5-protein-function-prediction\Train\train_terms.tsv
data\protbert-embeddings-for-cafa5\test_embeddings.npy
data\protbert-embeddings-for-cafa5\test

In [3]:
# load a sample submission for Kaggle competition
submission = pd.read_csv(f'{DATA_DIR}/sample_submission.tsv', sep='\t', header=None)
submission.columns = ["ProteinID", "GO_ID", "Probability"]
submission.head(10)

Unnamed: 0,ProteinID,GO_ID,Probability
0,A0A0A0MRZ7,GO:0000001,0.123
1,A0A0A0MRZ7,GO:0000002,0.123
2,A0A0A0MRZ8,GO:0000001,0.123
3,A0A0A0MRZ8,GO:0000002,0.123
4,A0A0A0MRZ9,GO:0000001,0.123
5,A0A0A0MRZ9,GO:0000002,0.123
6,A0A0A0MS00,GO:0000001,0.123
7,A0A0A0MS00,GO:0000002,0.123
8,A0A0A0MS01,GO:0000001,0.123
9,A0A0A0MS01,GO:0000002,0.123


In [4]:
# define important configurations of the code
class config:
    train_sequences_path = DATA_DIR  + "/Train/train_sequences.fasta"
    train_labels_path = DATA_DIR + "/Train/train_terms.tsv"
    test_sequences_path = DATA_DIR + "/Test (Targets)/testsuperset.fasta"

    num_labels = 10000
    n_epochs = 10
    batch_size = 128
    lr = 0.001
    gamma = 0.7

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Device: {device} - {torch.cuda.get_device_name(device)}')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Device: cuda - NVIDIA GeForce GTX 1650 Ti


In [5]:
# # ______________________ GET PROT BERT EMBEDDINGS WITH HUGGING FACE __________________________________
#
# # PROT BERT LOADING :
# tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
# model = BertModel.from_pretrained("Rostlab/prot_bert").to(config.device)
#
# def get_bert_embedding(
#     sequence : str,
#     len_seq_limit : int
# ):
#     """
#     Function to collect last hidden state embedding vector from pre-trained ProtBERT Model
#
#     INPUTS:
#     - sequence (str) : protein sequence (ex : AAABBB) from fasta file
#     - len_seq_limit (int) : maximum sequence lenght (i.e nb of letters) for truncation
#
#     OUTPUTS:
#     - output_hidden : last hidden state embedding vector for input sequence of length 1024
#     """
#     sequence_w_spaces = ' '.join(list(sequence))
#     encoded_input = tokenizer(
#         sequence_w_spaces,
#         truncation=True,
#         max_length=len_seq_limit,
#         padding='max_length',
#         return_tensors='pt').to(config.device)
#     output = model(**encoded_input)
#     output_hidden = output['last_hidden_state'][:,0][0].detach().cpu().numpy()
#     assert len(output_hidden)==1024
#     return output_hidden
#
# ### COLLECTING FOR TRAIN SAMPLES :
# print("Loading train set ProtBERT Embeddings...")
# fasta_train = SeqIO.parse(config.train_sequences_path, "fasta")
#
# print("Total Nb of Elements : ", len(list(fasta_train)))
# fasta_train = SeqIO.parse(config.train_sequences_path, "fasta")
#
# ids_list = []
# embed_vects_list = []
# t0 = time.time()
# checkpoint = 0
#
# for item in tqdm(fasta_train):
#     ids_list.append(item.id)
#     embed_vects_list.append(
#         get_bert_embedding(sequence = item.seq, len_seq_limit = 1200))
#     checkpoint+=1
#
#     if checkpoint>=100:
#         df_res = pd.DataFrame(data={"id" : ids_list, "embed_vect" : embed_vects_list})
#         np.save('/kaggle/working/train_ids.npy',np.array(ids_list))
#         np.save('/kaggle/working/train_embeddings.npy',np.array(embed_vects_list))
#         checkpoint=0
#
# np.save('/kaggle/working/train_ids.npy',np.array(ids_list))
# np.save('/kaggle/working/train_embeddings.npy',np.array(embed_vects_list))
# print('Total Elapsed Time:',time.time()-t0)
#
# ### COLLECTING FOR TEST SAMPLES :
# print("Loading test set ProtBERT Embeddings...")
# fasta_test = SeqIO.parse(config.test_sequences_path, "fasta")
# print("Total Nb of Elements : ", len(list(fasta_test)))
# fasta_test = SeqIO.parse(config.test_sequences_path, "fasta")
# ids_list = []
# embed_vects_list = []
# t0 = time.time()
# checkpoint=0
# for item in tqdm(fasta_test):
#     ids_list.append(item.id)
#     embed_vects_list.append(
#         get_bert_embedding(sequence = item.seq, len_seq_limit = 1200))
#     checkpoint+=1
#     if checkpoint>=100:
#         np.save('/kaggle/working/test_ids.npy',np.array(ids_list))
#         np.save('/kaggle/working/test_embeddings.npy',np.array(embed_vects_list))
#         checkpoint=0
#
# np.save('/kaggle/working/test_ids.npy',np.array(ids_list))
# np.save('/kaggle/working/test_embeddings.npy',np.array(embed_vects_list))
# print('Total Elasped Time:',time.time()-t0)

In [14]:
##### SCRIPT FOR LABELS (TARGETS) VECTORS COLLECTING #####

print(f"GENERATE TARGETS FOR ENTRY IDS ({config.num_labels} MOST COMMON GO TERMS)")
ids = np.load(f"{ESM2_DIR}/train_ids.npy")
labels = pd.read_csv(config.train_labels_path, sep = "\t")

top_terms = labels.groupby("term")["EntryID"].count().sort_values(ascending=False)
labels_names = top_terms[:config.num_labels].index.values
train_labels_sub = labels[(labels.term.isin(labels_names)) & (labels.EntryID.isin(ids))]
id_labels = train_labels_sub.groupby('EntryID')['term'].apply(list).to_dict()

go_terms_map = {label: i for i, label in enumerate(labels_names)}
labels_matrix = np.empty((len(ids), len(labels_names)))

for index, id in tqdm(enumerate(ids)):
    id_gos_list = id_labels[id]
    temp = [go_terms_map[go] for go in labels_names if go in id_gos_list]
    labels_matrix[index, temp] = 1

labels_list = []
for l in range(labels_matrix.shape[0]):
    labels_list.append(labels_matrix[l, :])

labels_df = pd.DataFrame(data={"EntryID":ids, "labels_vect":labels_list})
labels_df.to_pickle(f"{ESM2_DIR}/train_targets_top{config.num_labels}.pkl")
print("GENERATION FINISHED!")
labels_df

GENERATE TARGETS FOR ENTRY IDS (10000 MOST COMMON GO TERMS)


142246it [15:39, 151.47it/s]


GENERATION FINISHED!


Unnamed: 0,EntryID,labels_vect
0,Q9ZSA8,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
1,P25353,"[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
2,A0A2R8YCW8,"[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
3,G3V5N8,"[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, ..."
4,A0A140LFN4,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."


In [None]:
# Save and labels as a sparse matrix
labels_list_sparse = csr_matrix(labels_list)
labels_list_sparse.to_pickle(f"{ESM2_DIR}/train_targets_top{config.num_labels}_sparse.pkl")
labels_list_sparse = pd.read_pickle(f"{ESM2_DIR}/train_targets_top{config.num_labels}_sparse.pkl")

In [17]:
# load GO_weights (IA data) as a tensor to feed into the loss function

GO_weight_dataset = pd.read_table(f'{DATA_DIR}/IA.txt', header=None, names=['GO', 'weight'])
GO_weights = []
for each_label in labels_names:
    GO_weights.append(GO_weight_dataset.loc[GO_weight_dataset['GO'] == each_label]['weight'].values[0])

GO_weights = torch.tensor(GO_weights, dtype=torch.float32)
torch.save(GO_weights, f"{ESM2_DIR}/go_weights_{config.num_labels}.pt")
GO_weights

tensor([0.0000, 0.0000, 0.0255,  ..., 0.0000, 3.6413, 2.1203])

In [6]:
# IF you already have saved the labels and go_weights
labels_df = pd.read_pickle(f"{ESM2_DIR}/train_targets_top{config.num_labels}.pkl")
# GO_weights = torch.load(f"{ESM2_DIR}/go_weights_{config.num_labels}.pt")
# print(f'Labels shape: {labels_df.shape}, GO Weights shape: {GO_weights.shape}')
labels_df

Unnamed: 0,EntryID,labels_vect
0,Q9ZSA8,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
1,P25353,"[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
2,A0A2R8YCW8,"[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
3,G3V5N8,"[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, ..."
4,A0A140LFN4,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
...,...,...
142241,O81299,"[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, ..."
142242,Q55AH8,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, ..."
142243,Q80VK8,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
142244,Q8IS12,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [7]:
# Directories for the different embedding vectors :
embeds_map = {
    "T5" : "t5embeds",
    "ProtBERT" : "protbert-embeddings-for-cafa5",
    "ESM2" : "cafa-5-esm-2-embeddings-numpy"
}

# Length of the different embedding vectors :
embeds_dim = {
    "T5" : 1024,
    "ProtBERT" : 1024,
    "ESM2" : 1280
}

In [8]:
class ProteinSequenceDataset(Dataset):
    """
    Custom dataset to store embeddings of different sources
    It could be used to get training or test dataset
    """

    def __init__(self, datatype, embeddings_source):
        super(ProteinSequenceDataset).__init__()
        self.datatype = datatype

        if embeddings_source in ["ProtBERT", "ESM2"]:
            embeds = np.load(f"{MAIN_DIR}/"+embeds_map[embeddings_source]+"/"+datatype+"_embeddings.npy")
            ids = np.load(f"{MAIN_DIR}/"+embeds_map[embeddings_source]+"/"+datatype+"_ids.npy")

        if embeddings_source == "T5":
            embeds = np.load(f"{MAIN_DIR}/"+embeds_map[embeddings_source]+"/"+datatype+"_embeds.npy")
            ids = np.load(f"{MAIN_DIR}/"+embeds_map[embeddings_source]+"/"+datatype+"_ids.npy")

        embeds_list = []
        for l in range(embeds.shape[0]):
            embeds_list.append(embeds[l,:])
        self.df = pd.DataFrame(data={"EntryID": ids, "embed" : embeds_list})

        if datatype=="train":
            df_labels = pd.read_pickle(
                f"{MAIN_DIR}/{embeds_map[embeddings_source]}/train_targets_top{config.num_labels}.pkl")
            self.df = self.df.merge(df_labels, on="EntryID")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        embed = torch.tensor(self.df.iloc[index]["embed"], dtype=torch.float32)

        if self.datatype=="train":
            targets = torch.tensor(self.df.iloc[index]["labels_vect"], dtype=torch.float32)
            return embed, targets

        if self.datatype=="test":
            id = self.df.iloc[index]["EntryID"]
            return embed, id


dataset = ProteinSequenceDataset(datatype="train", embeddings_source="ESM2")
dataset.df.head(10)

Unnamed: 0,EntryID,embed,labels_vect
0,Q9ZSA8,"[-0.092291094, -0.066283956, -0.01226195, 0.04...","[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
1,P25353,"[0.011624349, -0.030317612, -0.0058019715, 0.0...","[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
2,A0A2R8YCW8,"[0.02737274, -0.041047025, -0.029205365, 0.028...","[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
3,G3V5N8,"[0.033766113, -0.07888931, -0.05974137, 0.0456...","[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, ..."
4,A0A140LFN4,"[0.0119482, -0.002107593, -0.084922194, 0.0687...","[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
5,B8ZZU6,"[0.020136692, -0.13927373, -0.04045331, 0.0626...","[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, ..."
6,Q01850,"[0.026013047, -0.08974387, -0.020240448, 0.132...","[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, ..."
7,P11076,"[0.005006821, -0.03306428, -0.037696116, 0.059...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
8,Q9VJ64,"[0.029283574, -0.010495956, -0.013060905, 0.02...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
9,Q7YSJ4,"[0.02852764, -0.049777817, -0.056415968, 0.141...","[1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [9]:
embeddings, labels = dataset.__getitem__(0)
print("COMPONENTS FOR FIRST PROTEIN:  ")
print("EMBEDDINGS VECTOR: \n ", embeddings, "\n")
print("TARGETS LABELS VECTOR: \n ", labels, "\n")

COMPONENTS FOR FIRST PROTEIN:  
EMBEDDINGS VECTOR: 
  tensor([-0.0923, -0.0663, -0.0123,  ..., -0.1607,  0.0159,  0.0017]) 

TARGETS LABELS VECTOR: 
  tensor([0., 1., 0.,  ..., 0., 0., 0.]) 



In [12]:
class ResidualNetwork(nn.Module):
    """
    A deep Residual Network module with step by step predictions.
    """

    def __init__(self, input_dim, step_dim):
        super(ResidualNetwork, self).__init__()

        self.forward_linear1 = torch.nn.Linear(input_dim, input_dim)
        self.batchnorm1 = nn.BatchNorm1d(input_dim)
        self.activation1 = torch.nn.ReLU()
        self.dropout1 = nn.Dropout()

        self.pred_linear1 = torch.nn.Linear(input_dim, step_dim[0])
        self.sigmoid1 = torch.nn.Sigmoid()

        self.forward_linear2 = torch.nn.Linear(input_dim + step_dim[0], input_dim + step_dim[0])
        self.batchnorm2 = nn.BatchNorm1d(input_dim + step_dim[0])
        self.activation2 = torch.nn.ReLU()
        self.dropout2 = nn.Dropout()

        self.pred_linear2 = torch.nn.Linear(input_dim + step_dim[0], step_dim[1])
        self.sigmoid2 = torch.nn.Sigmoid()

        self.forward_linear3 = torch.nn.Linear(input_dim + step_dim[0] + step_dim[1], input_dim + step_dim[0] + step_dim[1])
        self.batchnorm3 = nn.BatchNorm1d(input_dim + step_dim[0] + step_dim[1])
        self.activation3 = torch.nn.ReLU()
        self.dropout3 = nn.Dropout()

        self.pred_linear3 = torch.nn.Linear(input_dim + step_dim[0] + step_dim[1], step_dim[2])
        self.sigmoid3 = torch.nn.Sigmoid()

        self.forward_linear4 = torch.nn.Linear(input_dim + step_dim[0] + step_dim[1] + step_dim[2], input_dim + step_dim[0] + step_dim[1] + step_dim[2])
        self.batchnorm4 = nn.BatchNorm1d(input_dim + step_dim[0] + step_dim[1] + step_dim[2])
        self.activation4 = torch.nn.ReLU()
        self.dropout4 = nn.Dropout()

        self.pred_linear4 = torch.nn.Linear(input_dim + step_dim[0] + step_dim[1] + step_dim[2], step_dim[3])
        self.sigmoid4 = torch.nn.Sigmoid()


    def forward(self, x):
        x = self.forward_linear1(x)
        x = self.batchnorm1(x)
        x = self.activation11(x)
        x = self.dropout1(x)

        y1 = self.sigmoid1(self.pred_linear1(x))

        x = torch.cat([x, y1], dim=1)
        x = self.forward_linear2(x)
        x = self.batchnorm2(x)
        x = self.activation2(x)
        x = self.dropout2(x)

        y2 = self.sigmoid2(self.pred_linear2(x))

        x = torch.cat([x, y2], dim=1)
        x = self.forward_linear3(x)
        x = self.batchnorm3(x)
        x = self.activation3(x)
        x = self.dropout3(x)

        y3 = self.sigmoid3(self.pred_linear3(x))

        x = torch.cat([x, y3], dim=1)
        x = self.forward_linear3(x)
        x = self.batchnorm3(x)
        x = self.activation3(x)
        x = self.dropout3(x)

        y4 = self.sigmoid4(self.pred_linear4(x))

        y = torch.cat([y1, y2, y3, y4], dim=1)

        return y

In [32]:
class MultiLayerPerceptron(nn.Module):
    """
    Adjusted MLP model with 6 linear layers
    """

    def __init__(self, input_dim, num_classes):
        super(MultiLayerPerceptron, self).__init__()

        self.linear1 = torch.nn.Linear(input_dim, 1280)
        self.activation1 = torch.nn.ReLU()
        self.batchnorm1 = nn.BatchNorm1d(1280)
        self.dropout1 = nn.Dropout()

        self.linear2 = torch.nn.Linear(1280, 1800)
        self.activation2 = torch.nn.ReLU()
        self.batchnorm2 = nn.BatchNorm1d(1800)
        self.dropout2 = nn.Dropout()

        self.linear3 = torch.nn.Linear(1800, 2560)
        self.activation3 = torch.nn.ReLU()
        self.batchnorm3 = nn.BatchNorm1d(2560)
        self.dropout3 = nn.Dropout()

        self.linear4 = torch.nn.Linear(2560, 3200)
        self.activation4 = torch.nn.ReLU()
        self.batchnorm4 = nn.BatchNorm1d(3200)
        self.dropout4 = nn.Dropout()

        self.linear5 = torch.nn.Linear(3200, 4200)
        self.activation5 = torch.nn.ReLU()
        self.batchnorm5 = nn.BatchNorm1d(4200)
        self.dropout5 = nn.Dropout()

        self.linear6 = torch.nn.Linear(4200, num_classes)

    def forward(self, x):
        x = self.linear1(x)
        x = self.batchnorm1(x)
        x = self.activation1(x)
        x = self.dropout1(x)

        x = self.linear2(x)
        x = self.batchnorm2(x)
        x = self.activation2(x)
        x = self.dropout2(x)

        x = self.linear3(x)
        x = self.batchnorm3(x)
        x = self.activation3(x)
        x = self.dropout3(x)

        x = self.linear4(x)
        x = self.batchnorm4(x)
        x = self.activation4(x)
        x = self.dropout4(x)

        x = self.linear5(x)
        x = self.batchnorm5(x)
        x = self.activation5(x)
        x = self.dropout5(x)

        x = self.linear6(x)
        return x

In [33]:
class CNN1D(nn.Module):
    """
    Baseline CNN-1D model to make predictions using CLS token embeddings
    """

    def __init__(self, input_dim, num_classes):
        super(CNN1D, self).__init__()
        # (batch_size, channels, embed_size)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 3, embed_size)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 3, embed_size/2 = 512)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 8, embed_size/2 = 512)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 8, embed_size/4 = 256)
        self.fc1 = nn.Linear(in_features=int(8 * input_dim/4), out_features=1024)       # 1024 is better
        self.fc2 = nn.Linear(in_features=1024, out_features=num_classes)                # 1024 is better

    def forward(self, x):
        x = x.reshape(x.shape[0], 1, x.shape[1])
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [13]:
def get_train_val_dataloader(embeddings_source, train_size=0.9):
    train_dataset = ProteinSequenceDataset(datatype="train", embeddings_source=embeddings_source)
    train_set, val_set = random_split(train_dataset,
                                      lengths=[int(len(train_dataset)*train_size), len(train_dataset)-int(len(train_dataset)*train_size)])

    train_dataloader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_set, batch_size=config.batch_size, shuffle=True)

    return train_dataloader, val_dataloader

In [14]:
def train_model(train_dataloader, val_dataloader, embeddings_source, model_type):
    """
    Custom function to train the baseline model on dataset
    :param val_dataloader: dataloader for validation data
    :param train_dataloader: dataloader for training data
    :param embeddings_source: define the type of embedding
    :param model_type: define the type of model
    """
    if model_type == 'residual':
        model = ResidualNetwork(input_dim=embeds_dim[embeddings_source], step_dim=[800, 1200, 1400, 1600])

    if model_type == "linear":
        model = MultiLayerPerceptron(input_dim=embeds_dim[embeddings_source], num_classes=config.num_labels)

    if model_type == "conv":
        model = CNN1D(input_dim=embeds_dim[embeddings_source], num_classes=config.num_labels)

    print(f'Model:\n{model}')
    model.to(config.device)

    # define configurations of the model
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = ExponentialLR(optimizer, gamma=config.gamma)
    # scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=1)

    # multilabel prediction task, GO_weights could be used as weight in the loss function and f1score
    GO_weights.to(config.device)
    MultiLabelLoss = torch.nn.BCEWithLogitsLoss(weight=GO_weights+1)
    # f1_score = MultilabelF1Score(num_labels=config.num_labels).to(config.device)
    n_epochs = config.n_epochs

    print("BEGIN TRAINING...")
    train_loss_history, val_loss_history = [], []
    train_f1score_history, val_f1score_history = [], []
    train_accuracy_history, val_accuracy_history = [], []

    for epoch in range(n_epochs):
        print("EPOCH ", epoch+1)

        ## TRAIN PHASE :
        model.train()
        losses, scores, accuracy = [], [], []

        for embed, targets in tqdm(train_dataloader):
            embed, targets = embed.to(config.device), targets.to(config.device)
            preds = model(embed)

            loss = MultiLabelLoss(preds, targets)
            preds = (sigmoid(preds).detach().cpu().numpy() > 0.5).astype(int)
            targets = targets.detach().cpu().numpy()
            
            score = f1_score(targets, preds, average='weighted', sample_weight=GO_weights)
            acc = np.mean([accuracy_score(targets[i], preds[i], sample_weight=GO_weights) for i in range(len(targets))])
            losses.append(loss.item())
            scores.append(score)
            accuracy.append(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        avg_accuracy = np.mean(accuracy)
        print("Running Average TRAIN Loss: ", avg_loss)
        print("Running Average TRAIN F1-Score: ", avg_score)
        print("Running Average TRAIN Accuracy: ", avg_accuracy)
        train_loss_history.append(avg_loss)
        train_f1score_history.append(avg_score)
        train_accuracy_history.append(avg_accuracy)

        ## VALIDATION PHASE :
        model.eval()
        losses, scores, accuracy = [], [], []

        for embed, targets in val_dataloader:
            embed, targets = embed.to(config.device), targets.to(config.device)
            preds = model(embed)

            loss = MultiLabelLoss(preds, targets)
            preds = (sigmoid(preds).detach().cpu().numpy() > 0.5).astype(int)
            targets = targets.detach().cpu().numpy()
            
            score = f1_score(targets, preds, average='weighted', sample_weight=GO_weights)
            acc = np.mean([accuracy_score(targets[i], preds[i], sample_weight=GO_weights) for i in range(len(targets))])
            losses.append(loss.item())
            scores.append(score)
            accuracy.append(acc)

        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        avg_accuracy = np.mean(accuracy)
        print("Running Average VAL Loss: ", avg_loss)
        print("Running Average VAL F1-Score: ", avg_score)
        print("Running Average VAL Accuracy: ", avg_accuracy)
        val_loss_history.append(avg_loss)
        val_f1score_history.append(avg_score)
        val_accuracy_history.append(avg_accuracy)

        scheduler.step(), print("\n")


    print("TRAINING FINISHED _____")
    print(f"FINAL TRAINING F1SCORE: {train_f1score_history[-1]},  ACCURACY: {train_accuracy_history[-1]}")
    print(f"FINAL VALIDATION F1SCORE: {val_f1score_history[-1]},  ACCURACY: {val_accuracy_history[-1]}")

    losses_history = {"train" : train_loss_history, "val" : val_loss_history}
    scores_history = {"train" : train_f1score_history, "val" : val_f1score_history}
    accuracy_history = {"train" : train_accuracy_history, "val" : val_accuracy_history}

    return model, losses_history, scores_history, accuracy_history

In [None]:
train_dataloader, val_dataloader = get_train_val_dataloader("ESM2")

esm2_model, esm2_losses, esm2_scores, esm2_accuracy = train_model(train_dataloader,
                                                                  val_dataloader,
                                                                  embeddings_source="ESM2",
                                                                  model_type="residual")

In [None]:
sample = next(iter(val_dataloader))[0][1].unsqueeze(0)
label = next(iter(val_dataloader))[1][1].unsqueeze(0)
print(f"Sample:\n{sample}\n\nLabel:\n{label}")

In [None]:
esm2_model.to('cpu')
esm2_model.eval()
output = esm2_model(sample)
output = torch.round(sigmoid(output))
output = output.detach().cpu().numpy()
accuracy_score(label[0], output[0])

In [None]:
from transformers import AutoTokenizer, EsmModel
tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t33_650M_UR50D')
embedding_model = EsmModel.from_pretrained('facebook/esm2_t33_650M_UR50D', add_cross_attention=False, is_decoder=False)
embedding_model.eval()
embedding_model

In [None]:
ids = tokenizer(['MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN', 'QQQQQ'], add_special_tokens=True, padding="longest")
input_ids = torch.tensor(ids['input_ids'])
attention_mask = torch.tensor(ids['attention_mask'])
output = embedding_model(input_ids=input_ids, attention_mask=attention_mask)

In [None]:
pred = esm2_model(output.last_hidden_state[0].to(device))
cls = esm2_model(output.last_hidden_state[0][0].to(device))
torch.nn.functional.sigmoid(cls) > 0.5

In [None]:
esm2_model(dataset[0][0].reshape(1, -1).to(config.device))

In [None]:
def predict(embeddings_source, data_source):
    """
    Custom function to make inference using the model
    :param embeddings_source: define the type of embedding
    """

    test_dataset = ProteinSequenceDataset(datatype="test", embeddings_source = embeddings_source)

    if data_source == "val":
        predict_dataloader = val_dataloader

    if data_source == 'test':
        predict_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    if embeddings_source == "T5":
        model = t5_model
    if embeddings_source == "ProtBERT":
        model = protbert_model
    if embeddings_source == "ESM2":
        model = esm2_model

    # Set model on evaluation mode
    model.eval()

    labels = pd.read_csv(config.train_labels_path, sep = "\t")
    top_terms = labels.groupby("term")["EntryID"].count().sort_values(ascending=False)
    labels_names = top_terms[:config.num_labels].index.values
    print("GENERATE PREDICTION FOR TEST SET...")

    ids_ = np.empty(shape=(len(predict_dataloader)*config.num_labels,), dtype=object)
    go_terms_ = np.empty(shape=(len(predict_dataloader)*config.num_labels,), dtype=object)
    confs_ = np.empty(shape=(len(predict_dataloader)*config.num_labels,), dtype=np.float32)

    for i, (embed, id) in tqdm(enumerate(predict_dataloader)):
        embed = embed.to(config.device)
        confs_[i*config.num_labels:(i+1)*config.num_labels] = sigmoid(model(embed)).squeeze().detach().cpu().numpy()
        ids_[i*config.num_labels:(i+1)*config.num_labels] = id[0]
        go_terms_[i*config.num_labels:(i+1)*config.num_labels] = labels_names

    submission_df = pd.DataFrame(data={"Id" : ids_, "GO term" : go_terms_, "Confidence" : confs_})
    print("PREDICTIONS DONE")
    return submission_df

In [None]:
submission_df = predict("T5", "val")
submission_df.to_tsv("/working/predictions_val.tsv")
submission_df.head(50)

In [None]:
### SCRIPT TO EVALUATE PREDICTIONS USING CAFA EVALUATOR ###

import cafaeval
from cafaeval.evaluation import cafa_eval

cafa_eval(f"{DATA_DIR}/Train/go-basic.obo", submission_df, f"{DATA_DIR}/Train/train_terms.tsv", ia=f"{DATA_DIR}/IA.txt")

In [None]:
### IN PROGRESS - SCRIPT TO TRAIN THE MODEL USING PyTorchLightning ###

class Linear_Lightning(pl.LightningModule):
    """
    In progress, used to train the MLP model on multiple GPUs using PyTorchLightning
    """

    def __init__(self, input_dim, num_classes, train_size, **hparams):
        super(Linear_Lightning, self).__init__()

        self.model = MultiLayerPerceptron(input_dim=embeds_dim[embeddings_source], num_classes=config.num_labels).to(config.device)

        train_dataset = ProteinSequenceDataset(datatype="train", embeddings_source = embeddings_source)
        self.train_set, self.val_set = random_split(train_dataset, lengths = [int(len(train_dataset)*train_size), len(train_dataset)-int(len(train_dataset)*train_size)])

        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.batch_size = batch_size
        self.lr = lr

        self.f1_score = MultilabelF1Score(num_labels=num_classes)
        self.accuracy = MultilabelAccuracy(num_labels=num_classes)


    def forward(self, x):
        return self.model(x)


    def training_step(self, batch, batch_idx):
        embed, targets = batch
        preds = self(embed)
        loss = self.loss_fn(preds, targets)
        f1_score = self.f1_score(preds, targets)
        acc_score = self.accuracy(preds, targets)

        logs = {"train_loss" : loss, "f1_score" : f1_score, "accuracy_score" : acc_score}
        self.log_dict(
            logs,
            on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return {"loss": loss, "log": logs}


    def validation_step(self, batch, batch_idx):
        embed, targets = batch
        preds = self(embed)
        loss= self.loss_fn(preds, targets)
        f1_score = self.f1_score(preds, targets)
        acc_score = self.accuracy(preds, targets)

        return {"val_loss": loss, "f1_score": f1_score, "accuracy_score": acc_score}


    def validation_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in ouputs]).mean()
        logs = {"val_loss" : avg_loss}
        self.log_dict(
            logs,
            on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return {"avg_val_loss": avg_loss, "log": logs}


    def val_dataloader(self):
        val_dataloader = torch.utils.data.DataLoader(self.val_set, batch_size=config.batch_size, shuffle=False,)
        return val_dataloader


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


    def train_dataloader(self):
        train_dataloader = torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
        return train_dataloader


trainer = Trainer(
    max_epochs=config.n_epochs,
    limit_train_batches=5000,
    logger=logger)


model = Linear_Lightning(
    input_dim=embeds_dim[embeddings_source],
    num_classes=config.num_labels,
    train_size=0.8)


trainer.fit(model)