In [1]:
import os
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# The bare 'seaborn' style name is deprecated in recent Matplotlib releases
# (renamed 'seaborn-v0_8'); prefer the new name when it is available and fall
# back to the legacy name on older installations.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

import torch
from torch.utils.data import Dataset
from torch import nn
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MultilabelF1Score
from torchmetrics.classification import MultilabelAccuracy

from transformers import BertModel, BertTokenizer

# Show which accelerator is in use; only query CUDA when a GPU is actually
# present so the cell also runs on CPU-only machines.
torch.cuda.get_device_name() if torch.cuda.is_available() else 'No CUDA device available'

  plt.style.use('seaborn')


'NVIDIA GeForce GTX 1650 Ti'

In [2]:
# Root folders: MAIN_DIR holds the input datasets, WORK_DIR receives generated files.
MAIN_DIR = "data/"
WORK_DIR = "working/"
DATA_DIR = MAIN_DIR + "cafa-5-protein-function-prediction"
PROTBERT_DIR = MAIN_DIR + "protbert-embeddings-for-cafa5"

# List every file found under MAIN_DIR, one path per line.
for root, _subdirs, files in os.walk(MAIN_DIR):
    for fname in files:
        print(os.path.join(root, fname))

data/cafa-5-ems-2-embeddings-numpy\test_embeddings.npy
data/cafa-5-ems-2-embeddings-numpy\test_ids.npy
data/cafa-5-ems-2-embeddings-numpy\train_embeddings.npy
data/cafa-5-ems-2-embeddings-numpy\train_ids.npy
data/cafa-5-protein-function-prediction\IA.txt
data/cafa-5-protein-function-prediction\sample_submission.tsv
data/cafa-5-protein-function-prediction\Test (Targets)\testsuperset-taxon-list.tsv
data/cafa-5-protein-function-prediction\Test (Targets)\testsuperset.fasta
data/cafa-5-protein-function-prediction\Train\go-basic.obo
data/cafa-5-protein-function-prediction\Train\train_sequences.fasta
data/cafa-5-protein-function-prediction\Train\train_taxonomy.tsv
data/cafa-5-protein-function-prediction\Train\train_terms.tsv
data/protbert-embeddings-for-cafa5\test_embeddings.npy
data/protbert-embeddings-for-cafa5\test_ids.npy
data/protbert-embeddings-for-cafa5\train_embeddings.npy
data/protbert-embeddings-for-cafa5\train_ids.npy
data/t5embeds\test_embeds.npy
data/t5embeds\test_ids.npy
data/t5

In [3]:
# The sample submission has no header row, so column names are assigned by hand.
submission_path = f'{DATA_DIR}/sample_submission.tsv'
submission = pd.read_csv(submission_path, sep='\t', header=None)
submission.columns = ["ProteinID", "GO_ID", "Probability"]
submission.head(10)

Unnamed: 0,ProteinID,GO_ID,Probability
0,A0A0A0MRZ7,GO:0000001,0.123
1,A0A0A0MRZ7,GO:0000002,0.123
2,A0A0A0MRZ8,GO:0000001,0.123
3,A0A0A0MRZ8,GO:0000002,0.123
4,A0A0A0MRZ9,GO:0000001,0.123
5,A0A0A0MRZ9,GO:0000002,0.123
6,A0A0A0MS00,GO:0000001,0.123
7,A0A0A0MS00,GO:0000002,0.123
8,A0A0A0MS01,GO:0000001,0.123
9,A0A0A0MS01,GO:0000002,0.123


In [16]:
class config:
    """Static configuration shared by the notebook (paths, hyper-parameters, device)."""

    # Input files, all located under DATA_DIR.
    train_sequences_path = DATA_DIR  + "/Train/train_sequences.fasta"
    train_labels_path = DATA_DIR + "/Train/train_terms.tsv"
    test_sequences_path = DATA_DIR + "/Test (Targets)/testsuperset.fasta"

    # Number of most frequent GO terms used as classification targets.
    num_labels = 500
    # Training hyper-parameters.
    n_epochs = 25
    batch_size = 128
    lr = 0.002

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torch.cuda.get_device_name() only accepts CUDA devices, so it must not be
    # called when the CPU fallback was selected.
    device_name = torch.cuda.get_device_name(device) if device.type == 'cuda' else 'CPU'
    print(f'Device: {device} - {device_name}')

Device: cuda - NVIDIA GeForce GTX 1650 Ti


In [None]:
# ______________________ GET PROT BERT EMBEDDINGS WITH HUGGING FACE __________________________________
#
# # PROT BERT LOADING :
# tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
# model = BertModel.from_pretrained("Rostlab/prot_bert").to(config.device)
#
# def get_bert_embedding(
#     sequence : str,
#     len_seq_limit : int
# ):
#     """
#     Function to collect last hidden state embedding vector from pre-trained ProtBERT Model
#
#     INPUTS:
#     - sequence (str) : protein sequence (ex : AAABBB) from fasta file
#     - len_seq_limit (int) : maximum sequence length (i.e. number of letters) for truncation
#
#     OUTPUTS:
#     - output_hidden : embedding vector of length 1024, taken from the first token of the last hidden state for the input sequence
#     """
#     sequence_w_spaces = ' '.join(list(sequence))
#     encoded_input = tokenizer(
#         sequence_w_spaces,
#         truncation=True,
#         max_length=len_seq_limit,
#         padding='max_length',
#         return_tensors='pt').to(config.device)
#     output = model(**encoded_input)
#     output_hidden = output['last_hidden_state'][:,0][0].detach().cpu().numpy()
#     assert len(output_hidden)==1024
#     return output_hidden
#
# ### COLLECTING FOR TRAIN SAMPLES :
# print("Loading train set ProtBERT Embeddings...")
# fasta_train = SeqIO.parse(config.train_sequences_path, "fasta")
# print("Total Nb of Elements : ", len(list(fasta_train)))
# fasta_train = SeqIO.parse(config.train_sequences_path, "fasta")
# ids_list = []
# embed_vects_list = []
# t0 = time.time()
# checkpoint = 0
# for item in tqdm(fasta_train):
#     ids_list.append(item.id)
#     embed_vects_list.append(
#         get_bert_embedding(sequence = item.seq, len_seq_limit = 1200))
#     checkpoint+=1
#     if checkpoint>=100:
#         df_res = pd.DataFrame(data={"id" : ids_list, "embed_vect" : embed_vects_list})
#         np.save('/kaggle/working/train_ids.npy',np.array(ids_list))
#         np.save('/kaggle/working/train_embeddings.npy',np.array(embed_vects_list))
#         checkpoint=0
#
# np.save('/kaggle/working/train_ids.npy',np.array(ids_list))
# np.save('/kaggle/working/train_embeddings.npy',np.array(embed_vects_list))
# print('Total Elapsed Time:',time.time()-t0)
#
# ### COLLECTING FOR TEST SAMPLES :
# print("Loading test set ProtBERT Embeddings...")
# fasta_test = SeqIO.parse(config.test_sequences_path, "fasta")
# print("Total Nb of Elements : ", len(list(fasta_test)))
# fasta_test = SeqIO.parse(config.test_sequences_path, "fasta")
# ids_list = []
# embed_vects_list = []
# t0 = time.time()
# checkpoint=0
# for item in tqdm(fasta_test):
#     ids_list.append(item.id)
#     embed_vects_list.append(
#         get_bert_embedding(sequence = item.seq, len_seq_limit = 1200))
#     checkpoint+=1
#     if checkpoint>=100:
#         np.save('/kaggle/working/test_ids.npy',np.array(ids_list))
#         np.save('/kaggle/working/test_embeddings.npy',np.array(embed_vects_list))
#         checkpoint=0
#
# np.save('/kaggle/working/test_ids.npy',np.array(ids_list))
# np.save('/kaggle/working/test_embeddings.npy',np.array(embed_vects_list))
# print('Total Elapsed Time:',time.time()-t0)

In [5]:
##### SCRIPT FOR LABELS (TARGETS) VECTORS COLLECTING #####
# Builds one multi-hot target vector per training protein over the
# `config.num_labels` most frequent GO terms and pickles the result in WORK_DIR.

print("GENERATE TARGETS FOR ENTRY IDS ("+str(config.num_labels)+" MOST COMMON GO TERMS)")
ids = np.load(f"{PROTBERT_DIR}/train_ids.npy")
labels = pd.read_csv(config.train_labels_path, sep = "\t")

# Rank GO terms by number of annotated entries and keep the most frequent ones.
top_terms = labels.groupby("term")["EntryID"].count().sort_values(ascending=False)
labels_names = top_terms[:config.num_labels].index.values
train_labels_sub = labels[(labels.term.isin(labels_names)) & (labels.EntryID.isin(ids))]
id_labels = train_labels_sub.groupby('EntryID')['term'].apply(list).to_dict()

# Column index of every retained GO term in the target matrix.
go_terms_map = {label: i for i, label in enumerate(labels_names)}
# Must start from zeros: np.empty returns uninitialised memory, so every entry
# that is never set to 1 below would otherwise hold an arbitrary value.
labels_matrix = np.zeros((len(ids), len(labels_names)))

for index, entry_id in enumerate(tqdm(ids)):
    # A protein without any of the retained terms keeps an all-zero row
    # instead of raising a KeyError.
    columns = [go_terms_map[go] for go in id_labels.get(entry_id, [])]
    if columns:
        labels_matrix[index, columns] = 1

# One target vector (a row of the matrix) per protein, in the order of `ids`.
labels_list = list(labels_matrix)

labels_df = pd.DataFrame(data={"EntryID":ids, "labels_vect":labels_list})
# Make sure the output folder exists before writing the pickle.
os.makedirs(WORK_DIR, exist_ok=True)
labels_df.to_pickle(f"{WORK_DIR}/train_targets_top"+str(config.num_labels)+".pkl")
print("GENERATION FINISHED!")
labels_df.head(5)

GENERATE TARGETS FOR ENTRY IDS (500 MOST COMMON GO TERMS)


142246it [00:37, 3829.12it/s]


GENERATION FINISHED!


Unnamed: 0,EntryID,labels_vect
0,P20536,"[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, ..."
1,O73864,"[1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, ..."
2,O95231,"[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, ..."
3,A0A0B4J1F4,"[1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, ..."
4,P54366,"[1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, ..."


In [6]:
# Directories for the different embedding vectors :
embeds_map = {
    "T5" : "t5embeds",
    "ProtBERT" : "protbert-embeddings-for-cafa5",
    "EMS2" : "cafa-5-ems-2-embeddings-numpy"
}

# Length of the different embedding vectors :
embeds_dim = {
    "T5" : 1024,
    "ProtBERT" : 1024,
    "EMS2" : 1280
}

In [8]:
class ProteinSequenceDataset(Dataset):
    """Dataset pairing pre-computed protein embeddings with targets or entry ids.

    Parameters
    ----------
    datatype : str
        "train" (items are ``(embedding, targets)``) or
        "test" (items are ``(embedding, entry_id)``).
    embeddings_source : str
        Key of ``embeds_map`` selecting which embedding files are loaded.

    Raises
    ------
    ValueError
        If ``datatype`` or ``embeddings_source`` is not a supported value.
    """

    def __init__(self, datatype, embeddings_source):
        super().__init__()
        if datatype not in ("train", "test"):
            raise ValueError(f"datatype must be 'train' or 'test', got {datatype!r}")
        if embeddings_source not in embeds_map:
            raise ValueError(
                f"embeddings_source must be one of {list(embeds_map)}, got {embeddings_source!r}")
        self.datatype = datatype

        source_dir = os.path.join(MAIN_DIR, embeds_map[embeddings_source])
        # The T5 files are named "<split>_embeds.npy"; the other sources use
        # "<split>_embeddings.npy".
        embeds_suffix = "_embeds.npy" if embeddings_source == "T5" else "_embeddings.npy"
        embeds = np.load(os.path.join(source_dir, datatype + embeds_suffix))
        ids = np.load(os.path.join(source_dir, datatype + "_ids.npy"))

        # One embedding vector (a row of the matrix) per DataFrame row.
        self.df = pd.DataFrame(data={"EntryID": ids, "embed": list(embeds)})

        if datatype == "train":
            # Targets pickle written by the target-generation cell of this notebook.
            df_labels = pd.read_pickle(
                f"{WORK_DIR}/train_targets_top"+str(config.num_labels)+".pkl")
            self.df = self.df.merge(df_labels, on="EntryID")

    def __len__(self):
        """Return the number of proteins in the dataset."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(embedding, targets)`` for train data, ``(embedding, entry_id)`` otherwise."""
        row = self.df.iloc[index]
        embed = torch.tensor(row["embed"], dtype=torch.float32)

        if self.datatype == "train":
            targets = torch.tensor(row["labels_vect"], dtype=torch.float32)
            return embed, targets

        return embed, row["EntryID"]


# Build the training dataset from the T5 embeddings and preview its first rows.
dataset = ProteinSequenceDataset(
    datatype="train",
    embeddings_source="T5",
)
dataset.df.head(10)

Unnamed: 0,EntryID,embed,labels_vect
0,P20536,"[0.04948842525482178, -0.03293515741825104, 0....","[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, ..."
1,O73864,"[-0.04461636394262314, 0.06492499262094498, -0...","[1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, ..."
2,O95231,"[-0.02012803591787815, -0.04977943375706673, 0...","[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, ..."
3,A0A0B4J1F4,"[-0.00751461973413825, 0.06062775477766991, 0....","[1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, ..."
4,P54366,"[0.013468174263834953, 0.04151567816734314, 0....","[1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, ..."
5,P33681,"[0.001116646104492247, -0.01536268275231123, 0...","[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, ..."
6,P77596,"[0.03678780049085617, 0.052980050444602966, 0....","[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
7,Q16787,"[0.007108339574187994, 0.01562744379043579, 0....","[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, ..."
8,Q59VP0,"[-0.006104866974055767, -0.026720179244875908,...","[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
9,P13508,"[-0.0071898759342730045, -0.02323203906416893,...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."


In [9]:
# Show the (embedding, targets) pair returned for the first training protein.
embeddings, labels = dataset[0]
print("COMPONENTS FOR FIRST PROTEIN:  ")
for title, tensor in (
    ("EMBEDDINGS VECTOR: \n ", embeddings),
    ("TARGETS LABELS VECTOR: \n ", labels),
):
    print(title, tensor, "\n")

COMPONENTS FOR FIRST PROTEIN:  
EMBEDDINGS VECTOR: 
  tensor([ 0.0495, -0.0329,  0.0325,  ..., -0.0435,  0.0965,  0.0731]) 

TARGETS LABELS VECTOR: 
  tensor([0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
        0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1.,
        0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0

In [17]:
class MultiLayerPerceptron(torch.nn.Module):
    """Three-layer fully connected network: input_dim -> 1000 -> 800 -> num_classes.

    The forward pass returns raw scores (no sigmoid is applied).

    Parameters
    ----------
    input_dim : int
        Size of the input embedding vector.
    num_classes : int
        Number of output scores per sample.
    """

    def __init__(self, input_dim, num_classes):
        super(MultiLayerPerceptron, self).__init__()

        # The original also built a Linear(input_dim, input_dim) under the same
        # attribute name; it was overwritten before use, so it is dropped here.
        self.linear1 = torch.nn.Linear(input_dim, 1000)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(1000, 800)
        self.activation2 = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(800, num_classes)

    def forward(self, x):
        """Map a batch of shape (batch_size, input_dim) to (batch_size, num_classes)."""
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        x = self.activation2(x)
        x = self.linear3(x)
        return x

In [11]:
class CNN1D(nn.Module):
    """1-D CNN over an embedding vector treated as a single-channel signal.

    Two conv + ReLU + max-pool stages are followed by two fully connected
    layers. The forward pass returns raw scores (no sigmoid is applied).

    Parameters
    ----------
    input_dim : int
        Length of the input embedding vector.
    num_classes : int
        Number of output scores per sample.
    """

    def __init__(self, input_dim, num_classes):
        super(CNN1D, self).__init__()
        # (batch_size, 1, input_dim)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 3, input_dim)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 3, input_dim // 2)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 8, input_dim // 2)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 8, (input_dim // 2) // 2)
        # Each pooling stage floors the length, so the flattened size is derived
        # with integer division; it equals 8 * input_dim / 4 whenever input_dim
        # is divisible by 4 and stays correct when it is not.
        pooled_len = (input_dim // 2) // 2
        self.fc1 = nn.Linear(in_features=8 * pooled_len, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=num_classes)

    def forward(self, x):
        """Map a batch of shape (batch_size, input_dim) to (batch_size, num_classes)."""
        # Insert the channel axis expected by Conv1d.
        x = x.reshape(x.shape[0], 1, x.shape[1])
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
def train_model(embeddings_source, model_type="linear", train_size=0.9):
    """Train a multi-label classifier on pre-computed protein embeddings.

    Parameters
    ----------
    embeddings_source : str
        Key of ``embeds_map`` / ``embeds_dim`` selecting the embeddings to use.
    model_type : str
        "linear" for ``MultiLayerPerceptron`` or "conv" for ``CNN1D``.
    train_size : float
        Fraction of the training data used for fitting; the rest is used for
        validation.

    Returns
    -------
    tuple
        ``(model, losses_history, scores_history)`` where both histories are
        dicts with "train" and "val" lists holding one averaged value per epoch.

    Raises
    ------
    ValueError
        If ``model_type`` is neither "linear" nor "conv".
    """
    if model_type not in ("linear", "conv"):
        raise ValueError(f"model_type must be 'linear' or 'conv', got {model_type!r}")

    train_dataset = ProteinSequenceDataset(datatype="train", embeddings_source = embeddings_source)

    n_train = int(len(train_dataset)*train_size)
    train_set, val_set = random_split(train_dataset, lengths = [n_train, len(train_dataset)-n_train])
    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    # Shuffling brings nothing for evaluation, so the validation order is kept fixed.
    val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=config.batch_size, shuffle=False)

    if model_type == "linear":
        model = MultiLayerPerceptron(input_dim=embeds_dim[embeddings_source], num_classes=config.num_labels).to(config.device)
    else:
        model = CNN1D(input_dim=embeds_dim[embeddings_source], num_classes=config.num_labels).to(config.device)

    optimizer = torch.optim.Adam(model.parameters(), lr = config.lr)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=1)
    MultiLabelLoss = torch.nn.BCEWithLogitsLoss()
    f1_score = MultilabelF1Score(num_labels=config.num_labels).to(config.device)
    n_epochs = config.n_epochs

    print("BEGIN TRAINING...")
    train_loss_history=[]
    val_loss_history=[]

    train_f1score_history=[]
    val_f1score_history=[]

    for epoch in range(n_epochs):
        print("EPOCH ", epoch+1)

        ## TRAIN PHASE :
        model.train()
        losses, scores = [], []

        for embed, targets in tqdm(train_dataloader):
            embed, targets = embed.to(config.device), targets.to(config.device)
            preds = model(embed)
            loss= MultiLabelLoss(preds, targets)

            score=f1_score(preds, targets)
            losses.append(loss.item())
            scores.append(score.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        print("Running Average TRAIN Loss : ", avg_loss)
        print("Running Average TRAIN F1-Score : ", avg_score)
        train_loss_history.append(avg_loss)
        train_f1score_history.append(avg_score)

        ## VALIDATION PHASE :
        # Evaluation mode and no autograd graph: nothing is back-propagated here.
        model.eval()
        losses, scores = [], []

        with torch.no_grad():
            for embed, targets in val_dataloader:
                embed, targets = embed.to(config.device), targets.to(config.device)
                preds = model(embed)

                loss= MultiLabelLoss(preds, targets)
                score=f1_score(preds, targets)
                losses.append(loss.item())
                scores.append(score.item())

        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        print("Running Average VAL Loss : ", avg_loss)
        print("Running Average VAL F1-Score : ", avg_score)
        val_loss_history.append(avg_loss)
        val_f1score_history.append(avg_score)

        # The learning rate is reduced when the validation loss stops improving.
        scheduler.step(avg_loss)
        print("\n")

    print("TRAINING FINISHED")
    print("FINAL TRAINING SCORE : ", train_f1score_history[-1])
    print("FINAL VALIDATION SCORE : ", val_f1score_history[-1])

    losses_history = {"train" : train_loss_history, "val" : val_loss_history}
    scores_history = {"train" : train_f1score_history, "val" : val_f1score_history}

    return model, losses_history, scores_history

In [None]:
# Fit the fully connected model on the T5 embeddings and keep its training histories.
t5_model, t5_losses, t5_scores = train_model(
    embeddings_source="T5",
    model_type="linear",
)

BEGIN TRAINING...
EPOCH  1


100%|██████████| 1001/1001 [00:38<00:00, 25.72it/s]


Running Average TRAIN Loss :  0.14723475045942283
Running Average TRAIN F1-Score :  0.0372122754938305
Running Average VAL Loss :  0.13460040618000285
Running Average VAL F1-Score :  0.0732388557433816


EPOCH  2


100%|██████████| 1001/1001 [00:38<00:00, 25.94it/s]


Running Average TRAIN Loss :  0.13037609190612168
Running Average TRAIN F1-Score :  0.0806840056596281
Running Average VAL Loss :  0.12976928959999764
Running Average VAL F1-Score :  0.09229968850766974


EPOCH  3


100%|██████████| 1001/1001 [00:37<00:00, 26.43it/s]


Running Average TRAIN Loss :  0.12601814652001345
Running Average TRAIN F1-Score :  0.10553580846685867
Running Average VAL Loss :  0.1273318100720644
Running Average VAL F1-Score :  0.10245534395133811


EPOCH  4


100%|██████████| 1001/1001 [00:37<00:00, 26.57it/s]


Running Average TRAIN Loss :  0.12247991531372784
Running Average TRAIN F1-Score :  0.12793126974727487
Running Average VAL Loss :  0.12556536103199636
Running Average VAL F1-Score :  0.1269398009005402


EPOCH  5


100%|██████████| 1001/1001 [00:38<00:00, 25.82it/s]


Running Average TRAIN Loss :  0.11940935002518939
Running Average TRAIN F1-Score :  0.1490349325460273
Running Average VAL Loss :  0.12478991903896842
Running Average VAL F1-Score :  0.1656294320044773


EPOCH  6


100%|██████████| 1001/1001 [00:38<00:00, 26.33it/s]


Running Average TRAIN Loss :  0.11648161369157242
Running Average TRAIN F1-Score :  0.16883610393647308
Running Average VAL Loss :  0.12283315349902425
Running Average VAL F1-Score :  0.15888071632278816


EPOCH  7


100%|██████████| 1001/1001 [00:37<00:00, 26.60it/s]


Running Average TRAIN Loss :  0.11356009525078517
Running Average TRAIN F1-Score :  0.18733627639420622
Running Average VAL Loss :  0.12277722804407988
Running Average VAL F1-Score :  0.17766014406723635


EPOCH  8


100%|██████████| 1001/1001 [00:39<00:00, 25.16it/s]


Running Average TRAIN Loss :  0.11078822561404802
Running Average TRAIN F1-Score :  0.20694158645300242
Running Average VAL Loss :  0.12223972473293543
Running Average VAL F1-Score :  0.19981384024556195


EPOCH  9


100%|██████████| 1001/1001 [00:37<00:00, 26.79it/s]


Running Average TRAIN Loss :  0.10804230017172588
Running Average TRAIN F1-Score :  0.225697842198652
Running Average VAL Loss :  0.12107926447476659
Running Average VAL F1-Score :  0.19289390676255738


EPOCH  10


100%|██████████| 1001/1001 [00:39<00:00, 25.51it/s]


Running Average TRAIN Loss :  0.10532545226883817
Running Average TRAIN F1-Score :  0.2443946681983702
Running Average VAL Loss :  0.12118569076327342
Running Average VAL F1-Score :  0.19555765683097498


EPOCH  11


100%|██████████| 1001/1001 [00:38<00:00, 26.05it/s]


Running Average TRAIN Loss :  0.10278876154840767
Running Average TRAIN F1-Score :  0.26213499251660055


In [82]:
# Raw model outputs for the first training protein, fed as a batch of one.
first_embedding = dataset[0][0]
single_batch = first_embedding.reshape(1, -1).to(config.device)
t5_model(single_batch)

tensor([[ 3.3614,  4.4422,  3.3582,  4.4130,  3.5421,  4.0963,  3.2633,  3.2357,
          2.6108,  3.1686,  3.3675,  3.1777,  2.2656,  2.2527,  1.6129,  1.1627,
          3.4669,  4.2954,  2.3734,  2.2344,  2.4893,  4.0412,  2.5327,  1.0220,
          4.4346,  4.0681,  3.2098,  2.4389,  2.6691,  1.3291,  0.3863,  0.3933,
          0.0571,  0.1333,  2.7415,  0.2767,  2.2885, -0.7042,  2.2657,  1.0918,
         -2.0445,  0.1190,  0.1533, -0.7778, -0.7773, -0.7578,  0.2327, -0.5024,
          0.0200, -0.0803,  3.5225,  2.5142, -0.5563,  2.6985, -0.9831, -0.8499,
          1.1326, -0.9097,  0.9827,  1.3431, -2.7158,  0.8884,  1.2357,  0.9186,
         -0.6056,  3.4156, -1.1062,  1.7882,  0.7484,  2.0584,  0.0669,  2.9543,
         -1.1269,  3.6058, -0.3335,  2.8464, -1.8762, -1.0112,  3.5109,  0.6140,
         -0.8772, -1.1279,  0.4499, -1.8721,  3.3256,  2.7646, -0.1361, -1.1882,
         -1.2810, -1.3150,  3.4049,  0.1128,  0.5204, -1.0620, -1.0783, -1.1083,
         -2.1880, -0.0204,  

In [None]:
def predict(embeddings_source):
    """Score every test protein against the most frequent GO terms.

    Parameters
    ----------
    embeddings_source : str
        "T5", "ProtBERT" or "EMS2"; selects both the test embeddings and the
        trained model (``t5_model``, ``protbert_model`` or ``ems2_model``),
        which must already exist in the notebook namespace.

    Returns
    -------
    pandas.DataFrame
        Columns "Id", "GO term" and "Confidence", with ``config.num_labels``
        consecutive rows per test protein, in dataset order.

    Raises
    ------
    ValueError
        If ``embeddings_source`` is not one of the supported values.
    """
    if embeddings_source == "T5":
        model = t5_model
    elif embeddings_source == "ProtBERT":
        model = protbert_model
    elif embeddings_source == "EMS2":
        model = ems2_model
    else:
        raise ValueError(
            f"embeddings_source must be 'T5', 'ProtBERT' or 'EMS2', got {embeddings_source!r}")

    test_dataset = ProteinSequenceDataset(datatype="test", embeddings_source = embeddings_source)
    # Batched inference; shuffle stays off so rows follow the dataset order.
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    model.eval()

    # Same GO-term ranking as the one used to build the training targets, so
    # output column j of the model corresponds to labels_names[j].
    labels = pd.read_csv(config.train_labels_path, sep = "\t")
    top_terms = labels.groupby("term")["EntryID"].count().sort_values(ascending=False)
    labels_names = top_terms[:config.num_labels].index.values
    print("GENERATE PREDICTION FOR TEST SET...")

    n_labels = config.num_labels
    n_rows = len(test_dataset) * n_labels
    ids_ = np.empty(shape=(n_rows,), dtype=object)
    go_terms_ = np.empty(shape=(n_rows,), dtype=object)
    confs_ = np.empty(shape=(n_rows,), dtype=np.float32)

    n_done = 0  # number of proteins already written to the output arrays
    with torch.no_grad():
        for embed, entry_ids in tqdm(test_dataloader):
            probs = torch.sigmoid(model(embed.to(config.device))).cpu().numpy()
            n_batch = probs.shape[0]
            start, stop = n_done * n_labels, (n_done + n_batch) * n_labels
            # Row-major flattening: all terms of protein 0, then protein 1, ...
            confs_[start:stop] = probs.reshape(-1)
            ids_[start:stop] = np.repeat(np.asarray(entry_ids, dtype=object), n_labels)
            go_terms_[start:stop] = np.tile(labels_names, n_batch)
            n_done += n_batch

    submission_df = pd.DataFrame(data={"Id" : ids_, "GO term" : go_terms_, "Confidence" : confs_})
    print("PREDICTIONS DONE")
    return submission_df

In [None]:
# Run inference with the T5 model and preview the first 50 prediction rows.
submission_df = predict(embeddings_source="T5")
submission_df.head(50)