In [3]:
from torchmetrics.classification import MultilabelAccuracy
from torchmetrics.classification import MultilabelF1Score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import random_split
from torch import nn
from torch.utils.data import Dataset
import torch
import numpy as np
from tqdm import tqdm
import time
import pandas as pd
import os
import matplotlib.pyplot as plt

plt.style.use("ggplot")

# TORCH MODULES FOR METRICS COMPUTATION :

In [6]:
data_dir = "/gscratch/rao/aresf/Code/CAFA5/data"


class config:
    train_sequences_path = data_dir + "/Train/train_sequences.fasta"
    train_labels_path = data_dir + "/Train/train_terms.tsv"
    test_sequences_path = data_dir + "/Test/testsuperset.fasta"

    num_labels = 500
    n_epochs = 5
    batch_size = 128
    lr = 0.001

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
def generate_labels(config):
    print(
        "GENERATE TARGETS FOR ENTRY IDS ("
        + str(config.num_labels)
        + " MOST COMMON GO TERMS)"
    )
    ids = np.load("../data/train_ids.npy")
    labels = pd.read_csv(config.train_labels_path, sep="\t")

    top_terms = labels.groupby(
        "term")["EntryID"].count().sort_values(ascending=False)
    labels_names = top_terms[: config.num_labels].index.values
    train_labels_sub = labels[
        (labels.term.isin(labels_names)) & (labels.EntryID.isin(ids))
    ]
    id_labels = train_labels_sub.groupby(
        "EntryID")["term"].apply(list).to_dict()

    go_terms_map = {label: i for i, label in enumerate(labels_names)}
    labels_matrix = np.empty((len(ids), len(labels_names)))

    for index, id in tqdm(enumerate(ids)):
        id_gos_list = id_labels[id]
        temp = [go_terms_map[go] for go in labels_names if go in id_gos_list]
        labels_matrix[index, temp] = 1

    labels_list = []
    for l in range(labels_matrix.shape[0]):
        labels_list.append(labels_matrix[l, :])

    labels_df = pd.DataFrame(data={"EntryID": ids, "labels_vect": labels_list})
    labels_df.to_pickle("../data/train_targets_top" +
                        str(config.num_labels) + ".pkl")
    print("GENERATION FINISHED!")

In [11]:
# Directories for the different embedding vectors :
embeds_map = {
    "T5": "t5embeds",
    "ProtBERT": "protbert-embeddings-for-cafa5",
    "EMS2": "cafa-5-ems-2-embeddings-numpy",
}

# Length of the different embedding vectors :
embeds_dim = {"T5": 1024, "ProtBERT": 1024, "EMS2": 1280}


class ProteinSequenceDataset(Dataset):

    def __init__(self, datatype, embeddings_source):
        super(ProteinSequenceDataset).__init__()
        self.datatype = datatype

        if embeddings_source in ["ProtBERT", "EMS2"]:
            embeds = np.load("../data" + "/" + datatype + "_embeddings.npy")
            ids = np.load("../data" + "/" + datatype + "_ids.npy")

        embeds_list = []
        for l in range(embeds.shape[0]):
            embeds_list.append(embeds[l, :])
        self.df = pd.DataFrame(data={"EntryID": ids, "embed": embeds_list})

        if datatype == "train":
            df_labels = pd.read_pickle(
                "../data/train_targets_top" + str(config.num_labels) + ".pkl"
            )
            self.df = self.df.merge(df_labels, on="EntryID")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        embed = torch.tensor(self.df.iloc[index]["embed"], dtype=torch.float32)
        if self.datatype == "train":
            targets = torch.tensor(
                self.df.iloc[index]["labels_vect"], dtype=torch.float32
            )
            return embed, targets
        if self.datatype == "test":
            id = self.df.iloc[index]["EntryID"]
            return embed, id

In [19]:
class MultiLayerPerceptron(torch.nn.Module):

    def __init__(self, input_dim, num_classes):
        super(MultiLayerPerceptron, self).__init__()

        self.linear1 = torch.nn.Linear(input_dim, 1012)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(1012, 712)
        self.activation2 = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(712, num_classes)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        x = self.activation2(x)
        x = self.linear3(x)
        return x


class CNN1D(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(CNN1D, self).__init__()
        # (batch_size, channels, embed_size)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3,
                               kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 3, embed_size)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 3, embed_size/2 = 512)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=8,
                               kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 8, embed_size/2 = 512)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 8, embed_size/4 = 256)
        self.fc1 = nn.Linear(in_features=int(
            8 * input_dim/4), out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        x = x.reshape(x.shape[0], 1, x.shape[1])
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [20]:
def train_model(embeddings_source, model_type="linear", train_size=0.9):

    train_dataset = ProteinSequenceDataset(
        datatype="train", embeddings_source=embeddings_source)

    train_set, val_set = random_split(train_dataset, lengths=[int(len(
        train_dataset)*train_size), len(train_dataset)-int(len(train_dataset)*train_size)])
    train_dataloader = torch.utils.data.DataLoader(
        train_set, batch_size=config.batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(
        val_set, batch_size=config.batch_size, shuffle=True)

    if model_type == "linear":
        model = MultiLayerPerceptron(
            input_dim=embeds_dim[embeddings_source], num_classes=config.num_labels).to(config.device)
    if model_type == "convolutional":
        model = CNN1D(input_dim=embeds_dim[embeddings_source],
                      num_classes=config.num_labels).to(config.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=1)
    CrossEntropy = torch.nn.CrossEntropyLoss()
    f1_score = MultilabelF1Score(
        num_labels=config.num_labels).to(config.device)
    n_epochs = config.n_epochs

    print("BEGIN TRAINING...")
    train_loss_history = []
    val_loss_history = []

    train_f1score_history = []
    val_f1score_history = []
    for epoch in range(n_epochs):
        print("EPOCH ", epoch+1)
        # TRAIN PHASE :
        losses = []
        scores = []
        for embed, targets in tqdm(train_dataloader):
            embed, targets = embed.to(config.device), targets.to(config.device)
            optimizer.zero_grad()
            preds = model(embed)
            loss = CrossEntropy(preds, targets)
            score = f1_score(preds, targets)
            losses.append(loss.item())
            scores.append(score.item())
            loss.backward()
            optimizer.step()
        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        print("Running Average TRAIN Loss : ", avg_loss)
        print("Running Average TRAIN F1-Score : ", avg_score)
        train_loss_history.append(avg_loss)
        train_f1score_history.append(avg_score)

        # VALIDATION PHASE :
        losses = []
        scores = []
        for embed, targets in val_dataloader:
            embed, targets = embed.to(config.device), targets.to(config.device)
            preds = model(embed)
            loss = CrossEntropy(preds, targets)
            score = f1_score(preds, targets)
            losses.append(loss.item())
            scores.append(score.item())
        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        print("Running Average VAL Loss : ", avg_loss)
        print("Running Average VAL F1-Score : ", avg_score)
        val_loss_history.append(avg_loss)
        val_f1score_history.append(avg_score)

        scheduler.step(avg_loss)
        print("\n")

    print("TRAINING FINISHED")
    print("FINAL TRAINING SCORE : ", train_f1score_history[-1])
    print("FINAL VALIDATION SCORE : ", val_f1score_history[-1])

    losses_history = {"train": train_loss_history, "val": val_loss_history}
    scores_history = {"train": train_f1score_history,
                      "val": val_f1score_history}

    return model, losses_history, scores_history

In [21]:
model, losses, scores = train_model(
    embeddings_source="ProtBERT", model_type="convolutional"
)

BEGIN TRAINING...
EPOCH  1


100%|██████████| 1001/1001 [00:17<00:00, 55.83it/s]


Running Average TRAIN Loss :  141.17737231030688
Running Average TRAIN F1-Score :  0.05762168635185305
Running Average VAL Loss :  139.7490392412458
Running Average VAL F1-Score :  0.08033394248091749


EPOCH  2


100%|██████████| 1001/1001 [00:15<00:00, 66.56it/s]


Running Average TRAIN Loss :  138.0984937075254
Running Average TRAIN F1-Score :  0.10020145817429989
Running Average VAL Loss :  137.7405114855085
Running Average VAL F1-Score :  0.11651615951476353


EPOCH  3


100%|██████████| 1001/1001 [00:14<00:00, 68.27it/s]


Running Average TRAIN Loss :  136.762160645141
Running Average TRAIN F1-Score :  0.12057363572386237
Running Average VAL Loss :  136.51682104383195
Running Average VAL F1-Score :  0.12397308521238821


EPOCH  4


100%|██████████| 1001/1001 [00:14<00:00, 66.92it/s]


Running Average TRAIN Loss :  136.02096755021103
Running Average TRAIN F1-Score :  0.13082024394036768
Running Average VAL Loss :  136.07698331560408
Running Average VAL F1-Score :  0.13541990106127091


EPOCH  5


100%|██████████| 1001/1001 [00:15<00:00, 65.36it/s]


Running Average TRAIN Loss :  135.45669509933427
Running Average TRAIN F1-Score :  0.13802178239042348
Running Average VAL Loss :  135.967351096017
Running Average VAL F1-Score :  0.14094462125961268


TRAINING FINISHED
FINAL TRAINING SCORE :  0.13802178239042348
FINAL VALIDATION SCORE :  0.14094462125961268
