<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Fine-Tuning/ProtBert-BFD-FineTuning-Localization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Notes:**
This notebook was tested on Colab Pro with a P100 GPU (16 GB of memory).
1. If your GPU has less memory, consider the following:
    * Reduce the batch size.
    * Disable fine-tuning of the pre-trained model.
    * Reduce the maximum sequence length.
    * Enable mixed precision.
    * Run it in Google Colab with TPUs.
2. If your GPU has more memory, consider the following:
    * Disable gradient checkpointing.
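
As a rough illustration of the low-memory options above, the sketch below overrides a few of the training flags that this notebook defines later (`--batch_size`, `--max_length`, `--precision`). It uses the standard library `argparse.ArgumentParser` as a stand-in for `HyperOptArgumentParser`, and the override values 512 and 16 are assumptions chosen for the example, not tested recommendations.

```python
import argparse

# Stand-in parser: the notebook itself uses test_tube's HyperOptArgumentParser.
sketch_parser = argparse.ArgumentParser(description="Low-memory overrides (sketch)")
sketch_parser.add_argument("--batch_size", type=int, default=1)
sketch_parser.add_argument("--max_length", type=int, default=1536)
sketch_parser.add_argument("--precision", type=int, default=32)

# Hypothetical overrides for a smaller GPU: shorter sequences, 16-bit precision.
low_mem_args = ["--max_length", "512", "--precision", "16"]
hparams_low_mem, _ = sketch_parser.parse_known_args(low_mem_args)

print(hparams_low_mem)
```

Shortening `max_length` truncates long proteins, so it trades memory for information at the end of the sequence.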

**1. Load necessary libraries, including Hugging Face Transformers and PyTorch Lightning**

In [None]:
tpu = False
if tpu:
  !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
  !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
!pip install -q test_tube transformers pytorch-nlp git+https://github.com/PyTorchLightning/pytorch-lightning.git



In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, RandomSampler

import pytorch_lightning as pl
from pytorch_lightning.loggers import TestTubeLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.metrics.sklearns import Accuracy

from transformers import BertTokenizer, BertModel

from torchnlp.encoders import LabelEncoder
from torchnlp.datasets.dataset import Dataset
from torchnlp.utils import collate_tensors

import pandas as pd
from test_tube import HyperOptArgumentParser
import os
import re
import requests
from tqdm.auto import tqdm
from datetime import datetime
from collections import OrderedDict
import logging as log
import numpy as np
import glob

**2. Create the DeepLoc localization dataset functions**

In [None]:
class Loc_dataset():
    """
    Downloads the DeepLoc per-protein CSV files and loads them as datasets.
    The files are fetched into the data/ folder when the class is instantiated;
    load_dataset(path) then returns a Dataset of {"seq", "label"} dictionaries.
    """
    def __init__(self) -> None:
        self.downloadDeeplocDataset()
    
    def downloadDeeplocDataset(self):
        deeplocDatasetTrainUrl = 'https://www.dropbox.com/s/vgdqcl4vzqm9as0/deeploc_per_protein_train.csv?dl=1'
        deeplocDatasetValidUrl = 'https://www.dropbox.com/s/jfzuokrym7nflkp/deeploc_per_protein_test.csv?dl=1'

        datasetFolderPath = 'data/'
        trainFilePath = os.path.join(datasetFolderPath, 'deeploc_per_protein_train.csv')
        testFilePath = os.path.join(datasetFolderPath, 'deeploc_per_protein_test.csv')


        if not os.path.exists(datasetFolderPath):
            os.makedirs(datasetFolderPath)

        def download_file(url, filename):
            response = requests.get(url, stream=True)
            with tqdm.wrapattr(open(filename, "wb"), "write", miniters=1,
                              total=int(response.headers.get('content-length', 0)),
                              desc=filename) as fout:
                for chunk in response.iter_content(chunk_size=4096):
                    fout.write(chunk)

        if not os.path.exists(trainFilePath):
            download_file(deeplocDatasetTrainUrl, trainFilePath)

        if not os.path.exists(testFilePath):
            download_file(deeplocDatasetValidUrl, testFilePath)
            
    def collate_lists(self, seq: list, label: list) -> list:
        """ Converts each (sequence, label) pair into a dictionary. """
        collated_dataset = []
        for i in range(len(seq)):
            collated_dataset.append({"seq": str(seq[i]), "label": str(label[i])})
        return collated_dataset

    def load_dataset(self,path):
        df = pd.read_csv(path,names=['input','loc','membrane'],skiprows=1)
        
        seq = list(df['input'])
        label = list(df['loc'])

        assert len(seq) == len(label)
        return Dataset(self.collate_lists(seq, label))
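
To make the record format concrete, here is a minimal sketch of what `load_dataset` and `collate_lists` produce, written with the standard library `csv` module instead of pandas and torchnlp. The two rows, including the values in the `membrane` column, are made up for illustration; only the column order (`input`, `loc`, `membrane`) and the skipped header row come from the code above.

```python
import csv
import io

# Hypothetical two-row file using the column order the notebook reads:
# input (sequence), loc (localization label), membrane.
toy_csv = io.StringIO(
    "input,loc,membrane\n"
    "M S T D,Mitochondrion,M\n"
    "M R C L,Extracellular,S\n"
)

reader = csv.reader(toy_csv)
next(reader)  # skip the header row, like skiprows=1 in load_dataset

# Keep only the sequence and the localization label for each protein.
records = [{"seq": row[0], "label": row[1]} for row in reader]
print(records)
```

The `membrane` column is read but not used for this task; only `loc` becomes the label.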

**3. Create the ProtBert PyTorch Lightning class**

In [None]:
class ProtBertBFDClassifier(pl.LightningModule):
    """
    # https://github.com/minimalist-nlp/lightning-text-classification.git
    
    Sample model showing how to use ProtBert-BFD to classify protein sequences.
    
    :param hparams: ArgumentParser containing the hyperparameters.
    """

    def __init__(self, hparams) -> None:
        super(ProtBertBFDClassifier, self).__init__()
        self.hparams = hparams
        self.batch_size = self.hparams.batch_size

        self.modelFolderPath = 'models/ProtBert-BFD/'
        self.vocabFilePath = os.path.join(self.modelFolderPath, 'vocab.txt')
        
        self.dataset = Loc_dataset()

        self.metric_acc = Accuracy()

        # download the pre-trained model files
        self.__download_model()

        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if self.hparams.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = self.hparams.nr_frozen_epochs

    def __download_model(self) -> None:
        modelUrl = 'https://www.dropbox.com/s/luv2r115bumo90z/pytorch_model.bin?dl=1'
        configUrl = 'https://www.dropbox.com/s/33en5mbl4wf27om/bert_config.json?dl=1'
        vocabUrl = 'https://www.dropbox.com/s/tffddoqfubkfcsw/vocab.txt?dl=1'

        modelFolderPath = self.modelFolderPath
        modelFilePath = os.path.join(modelFolderPath, 'pytorch_model.bin')
        configFilePath = os.path.join(modelFolderPath, 'config.json')
        vocabFilePath = self.vocabFilePath

        if not os.path.exists(modelFolderPath):
            os.makedirs(modelFolderPath)

        def download_file(url, filename):
            response = requests.get(url, stream=True)
            with tqdm.wrapattr(open(filename, "wb"), "write", miniters=1,
                              total=int(response.headers.get('content-length', 0)),
                              desc=filename) as fout:
                for chunk in response.iter_content(chunk_size=4096):
                    fout.write(chunk)

        if not os.path.exists(modelFilePath):
            download_file(modelUrl, modelFilePath)

        if not os.path.exists(configFilePath):
            download_file(configUrl, configFilePath)

        if not os.path.exists(vocabFilePath):
            download_file(vocabUrl, vocabFilePath)
            

    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        self.ProtBertBFD = BertModel.from_pretrained(self.modelFolderPath,gradient_checkpointing=self.hparams.gradient_checkpointing)
        self.encoder_features = 1024

        # Tokenizer
        self.tokenizer = BertTokenizer(self.vocabFilePath, do_lower_case=False)

        # Label Encoder
        self.label_encoder = LabelEncoder(
            self.hparams.label_set.split(","), reserved_labels=[]
        )
        self.label_encoder.unknown_index = None

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(self.encoder_features*4, self.label_encoder.vocab_size),
            nn.Tanh(),
        )

    def __build_loss(self):
        """ Initializes the loss function/s. """
        self._loss = nn.CrossEntropyLoss()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer. """
        if self._frozen:
            log.info(f"\n-- Encoder model fine-tuning")
            for param in self.ProtBertBFD.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.ProtBertBFD.parameters():
            param.requires_grad = False
        self._frozen = True

    def predict(self, sample: dict) -> dict:
        """ Predict function.
        :param sample: dictionary with the text we want to classify.
        Returns:
            Dictionary with the input text and the predicted label.
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            model_input, _ = self.prepare_sample([sample], prepare_target=False)
            model_out = self.forward(**model_input)
            logits = model_out["logits"].cpu().numpy()
            predicted_labels = [
                self.label_encoder.index_to_token[prediction]
                for prediction in np.argmax(logits, axis=1)
            ]
            sample["predicted_label"] = predicted_labels[0]

        return sample
    
    # https://github.com/UKPLab/sentence-transformers/blob/eb39d0199508149b9d32c1677ee9953a84757ae4/sentence_transformers/models/Pooling.py
    def pool_strategy(self, features,
                      pool_cls=True, pool_max=True, pool_mean=True,
                      pool_mean_sqrt=True):
        token_embeddings = features['token_embeddings']
        cls_token = features['cls_token_embeddings']
        attention_mask = features['attention_mask']

        ## Pooling strategy
        output_vectors = []
        if pool_cls:
            output_vectors.append(cls_token)
        if pool_max:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
            max_over_time = torch.max(token_embeddings, 1)[0]
            output_vectors.append(max_over_time)
        if pool_mean or pool_mean_sqrt:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            #If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if 'token_weights_sum' in features:
                sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)

            if pool_mean:
                output_vectors.append(sum_embeddings / sum_mask)
            if pool_mean_sqrt:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))

        output_vector = torch.cat(output_vectors, 1)
        return output_vector
    
    def forward(self, input_ids, token_type_ids, attention_mask):
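        """ Usual pytorch forward function.
        :param input_ids: token ids [batch_size x seq_len]
        :param token_type_ids: segment ids [batch_size x seq_len] (unused by this model)
        :param attention_mask: 1 for real tokens, 0 for padding [batch_size x seq_len]
        Returns:
            Dictionary with model outputs (e.g: logits)
        """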
        input_ids = torch.tensor(input_ids, device=self.device)
        attention_mask = torch.tensor(attention_mask,device=self.device)

        word_embeddings = self.ProtBertBFD(input_ids,
                                           attention_mask)[0]

        pooling = self.pool_strategy({"token_embeddings": word_embeddings,
                                      "cls_token_embeddings": word_embeddings[:, 0],
                                      "attention_mask": attention_mask,
                                      })
        return {"logits": self.classification_head(pooling)}

    def loss(self, predictions: dict, targets: dict) -> torch.tensor:
        """
        Computes Loss value according to a loss function.
        :param predictions: model specific output. Must contain a key 'logits' with
            a tensor [batch_size x num_labels] with model predictions
        :param targets: dictionary with a key 'labels' holding label indices [batch_size]
        Returns:
            torch.tensor with loss value.
        """
        return self._loss(predictions["logits"], targets["labels"])

    def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.
        
        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        sample = collate_tensors(sample)

        inputs = self.tokenizer.batch_encode_plus(sample["seq"],
                                                  add_special_tokens=True,
                                                  padding=True,
                                                  truncation=True,
                                                  max_length=self.hparams.max_length)

        if not prepare_target:
            return inputs, {}

        # Prepare target:
        try:
            targets = {"labels": self.label_encoder.batch_encode(sample["label"])}
            return inputs, targets
        except RuntimeError:
            print(sample["label"])
            raise Exception("Label encoder found an unknown label.")

    def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ 
        Runs one training step. This usually consists in the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is
        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        tqdm_dict = {"train_loss": loss_val}
        output = OrderedDict(
            {"loss": loss_val, "progress_bar": tqdm_dict, "log": tqdm_dict})

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ Similar to the training step but with the model in eval mode.
        Returns:
            - dictionary passed to the validation_epoch_end function.
        """
        inputs, targets = batch

        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        y = targets["labels"]
        y_hat = model_out["logits"]
        
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = self.metric_acc(labels_hat, y)
        
        output = OrderedDict({"val_loss": loss_val, "val_acc": val_acc,})

        return output
        
    
    def validation_epoch_end(self, outputs: list) -> dict:
        """ Function that takes as input a list of dictionaries returned by the validation_step
        function and measures the model performance across the entire validation set.
        
        Returns:
            - Dictionary with metrics to be added to the lightning logger.  
        """

        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc_mean = torch.stack([x['val_acc'] for x in outputs]).mean()

        tqdm_dict = {"val_loss": val_loss_mean, "val_acc": val_acc_mean}
        result = {
            "progress_bar": tqdm_dict,
            "log": tqdm_dict,
            "val_loss": val_loss_mean,
        }
        return result
    

    def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ Similar to the validation step but run on the test set.
        Returns:
            - dictionary passed to the test_epoch_end function.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_test = self.loss(model_out, targets)

        y = targets["labels"]
        y_hat = model_out["logits"]
        
        labels_hat = torch.argmax(y_hat, dim=1)
        test_acc = self.metric_acc(labels_hat, y)
        
        output = OrderedDict({"test_loss": loss_test, "test_acc": test_acc,})

        return output

    def test_epoch_end(self, outputs: list) -> dict:
        """ Function that takes as input a list of dictionaries returned by the test_step
        function and measures the model performance across the entire test set.
        
        Returns:
            - Dictionary with metrics to be added to the lightning logger.  
        """
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        test_acc_mean = torch.stack([x['test_acc'] for x in outputs]).mean()

        tqdm_dict = {"test_loss": test_loss_mean, "test_acc": test_acc_mean}
        result = {
            "progress_bar": tqdm_dict,
            "log": tqdm_dict,
            "test_loss": test_loss_mean,
        }
        return result

    def configure_optimizers(self):
        """ Sets different Learning rates for different parameter groups. """
        parameters = [
            {"params": self.classification_head.parameters()},
            {
                "params": self.ProtBertBFD.parameters(),
                "lr": self.hparams.encoder_learning_rate,
            },
        ]
        optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate)
        return [optimizer], []

    def on_epoch_end(self):
        """ Pytorch lightning hook """
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            self.unfreeze_encoder()

    def __retrieve_dataset(self, train=True, val=True, test=True):
        """ Retrieves task specific dataset """
        if train:
            return self.dataset.load_dataset(self.hparams.train_csv)
        elif val:
            return self.dataset.load_dataset(self.hparams.dev_csv)
        elif test:
            return self.dataset.load_dataset(self.hparams.test_csv)
        else:
            print('Incorrect dataset split')

    def train_dataloader(self) -> DataLoader:
        """ Function that loads the train set. """
        self._train_dataset = self.__retrieve_dataset(val=False, test=False)
        return DataLoader(
            dataset=self._train_dataset,
            sampler=RandomSampler(self._train_dataset),
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """ Function that loads the validation set. """
        self._dev_dataset = self.__retrieve_dataset(train=False, test=False)
        return DataLoader(
            dataset=self._dev_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """ Function that loads the test set. """
        self._test_dataset = self.__retrieve_dataset(train=False, val=False)
        return DataLoader(
            dataset=self._test_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )

    @classmethod
    def add_model_specific_args(
        cls, parser: HyperOptArgumentParser
    ) -> HyperOptArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters. 
        :param parser: HyperOptArgumentParser obj
        Returns:
            - updated parser
        """
        parser.opt_list(
            "--max_length",
            default=1536,
            type=int,
            help="Maximum sequence length.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=5e-06,
            type=float,
            help="Encoder specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.opt_list(
            "--nr_frozen_epochs",
            default=1,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
            tunable=True,
            options=[0, 1, 2, 3, 4, 5],
        )
        # Data Args:
        parser.add_argument(
            "--label_set",
            default="Cell.membrane,Cytoplasm,Endoplasmic.reticulum,Extracellular,Golgi.apparatus,Lysosome/Vacuole,Mitochondrion,Nucleus,Peroxisome,Plastid",
            type=str,
            help="Classification labels set.",
        )
        parser.add_argument(
            "--train_csv",
            default="data/deeploc_per_protein_train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_csv",
            default="data/deeploc_per_protein_test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_csv",
            default="data/deeploc_per_protein_test.csv",
            type=str,
            help="Path to the file containing the test data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        parser.add_argument(
            "--gradient_checkpointing",
            default=True,
            type=bool,
            help="Enable or disable gradient checkpointing, which saves GPU memory \
                by recomputing activations during the backward pass.",
        )
        return parser
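
The pooling step in `pool_strategy` is easier to follow on a tiny example. The sketch below is a framework-free re-statement for a single sequence, using plain Python lists instead of tensors; the function name `masked_pool` and the toy numbers are assumptions for illustration. It concatenates the CLS vector, the max over real tokens, the mean over real tokens, and the sum divided by the square root of the token count, ignoring padded positions.

```python
import math

def masked_pool(token_embeddings, attention_mask):
    """Concatenate CLS, max, mean and sum/sqrt(n) pooling for one sequence.

    token_embeddings: list of per-token vectors (list of lists of floats)
    attention_mask:   list of 0/1 flags, 1 for real tokens and 0 for padding
    Assumes at least one real token (the CLS position).
    """
    dim = len(token_embeddings[0])
    # Drop padded positions so they cannot influence any statistic.
    kept = [vec for vec, m in zip(token_embeddings, attention_mask) if m == 1]
    n = len(kept)
    cls_vec = list(token_embeddings[0])
    max_vec = [max(vec[d] for vec in kept) for d in range(dim)]
    sum_vec = [sum(vec[d] for vec in kept) for d in range(dim)]
    mean_vec = [s / n for s in sum_vec]
    mean_sqrt_vec = [s / math.sqrt(n) for s in sum_vec]
    return cls_vec + max_vec + mean_vec + mean_sqrt_vec

# Three token positions with 2-dimensional embeddings; the last one is padding.
toy_embeddings = [[1.0, 2.0], [3.0, 0.0], [9.0, 9.0]]
toy_mask = [1, 1, 0]
pooled = masked_pool(toy_embeddings, toy_mask)
print(len(pooled), pooled)
```

The output is four times the embedding width, which is why the classification head takes `self.encoder_features*4` inputs. With `encoder_features = 1024` and the ten labels in `--label_set`, the linear layer has 4096 × 10 + 10 = 40,970 parameters.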

**4. Create the experiments folder**

In [None]:
def setup_testube_logger() -> TestTubeLogger:
    """ Function that sets the TestTubeLogger to be used. """
    now = datetime.now()
    dt_string = now.strftime("%d-%m-%Y--%H-%M-%S")

    return TestTubeLogger(
        save_dir="experiments/",
        version=dt_string,
        name="lightning_logs",
    )

logger = setup_testube_logger()
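
The timestamp used as the logger version also determines where checkpoints end up, because the training cell joins `save_dir`, `name`, and `version_<timestamp>`. The sketch below shows that composition with a fixed date so the result is reproducible; `checkpoint_dir` is a hypothetical helper, not part of the notebook or of PyTorch Lightning.

```python
import os
from datetime import datetime

def checkpoint_dir(save_dir: str, name: str, when: datetime) -> str:
    """Build the checkpoint folder for a run started at `when` (hypothetical helper)."""
    version = when.strftime("%d-%m-%Y--%H-%M-%S")
    return os.path.join(save_dir, name, f"version_{version}", "checkpoints")

# Fixed example date; the notebook uses datetime.now() instead.
example_dir = checkpoint_dir("experiments/", "lightning_logs",
                             datetime(2020, 8, 15, 8, 15, 0))
print(example_dir)
```

Because the version string changes on every run, each run writes to its own folder and earlier checkpoints are not overwritten.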

**5. Create training arguments**

In [None]:
# these are project-wide arguments
parser = HyperOptArgumentParser(
    strategy="random_search",
    description="Minimalist ProtBERT Classifier",
    add_help=True,
)
parser.add_argument("--seed", type=int, default=3, help="Training seed.")
parser.add_argument(
    "--save_top_k",
    default=1,
    type=int,
    help="The best k models according to the quantity monitored will be saved.",
)
# Early Stopping
parser.add_argument(
    "--monitor", default="val_acc", type=str, help="Quantity to monitor."
)
parser.add_argument(
    "--metric_mode",
    default="max",
    type=str,
    help="If we want to min/max the monitored quantity.",
    choices=["auto", "min", "max"],
)
parser.add_argument(
    "--patience",
    default=5,
    type=int,
    help=(
        "Number of epochs with no improvement "
        "after which training will be stopped."
    ),
)
parser.add_argument(
    "--min_epochs",
    default=1,
    type=int,
    help="Limits training to a minimum number of epochs",
)
parser.add_argument(
    "--max_epochs",
    default=100,
    type=int,
    help="Limits training to a maximum number of epochs",
)

# Batching
parser.add_argument(
    "--batch_size", default=1, type=int, help="Batch size to be used."
)
parser.add_argument(
    "--accumulate_grad_batches",
    default=8,
    type=int,
    help=(
        "Accumulates gradients over K small batches of size N before "
        "doing an optimizer step."
    ),
)

# gpu/tpu args
parser.add_argument("--gpus", type=int, default=1, help="How many gpus")
parser.add_argument("--tpu_cores", type=int, default=None, help="How many tpus")
parser.add_argument(
    "--val_percent_check",
    default=1.0,
    type=float,
    help=(
        "If you don't want to use the entire dev set (for debugging or "
        "if it's huge), set how much of the dev set you want to use with this flag."
    ),
)

# mixed precision
parser.add_argument("--precision", type=int, default=32, help="full precision (32) or mixed precision (16) mode")
parser.add_argument("--amp_level", type=str, default="O1", help="mixed precision type")

# each LightningModule defines arguments relevant to it
parser = ProtBertBFDClassifier.add_model_specific_args(parser)
hparams = parser.parse_known_args()[0]
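
One caveat applies if `--gradient_checkpointing` is ever passed explicitly: with `type=bool`, argparse calls `bool()` on the raw string, and any non-empty string (including `"False"`) is truthy. The defaults used in this notebook are unaffected. The sketch below demonstrates the behavior with the standard library parser and a hypothetical `str2bool` helper as one possible alternative.

```python
import argparse

def str2bool(value: str) -> bool:
    """Hypothetical helper: parse common true/false spellings."""
    text = value.strip().lower()
    if text in ("true", "1", "yes"):
        return True
    if text in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

# Same flag declared two ways: with type=bool and with the helper above.
naive = argparse.ArgumentParser()
naive.add_argument("--gradient_checkpointing", default=True, type=bool)

strict = argparse.ArgumentParser()
strict.add_argument("--gradient_checkpointing", default=True, type=str2bool)

cli = ["--gradient_checkpointing", "False"]
naive_value = naive.parse_args(cli).gradient_checkpointing
strict_value = strict.parse_args(cli).gradient_checkpointing
print(naive_value, strict_value)
```

The first value stays `True` even though `False` was requested, so disabling gradient checkpointing from the command line needs a parser like the second one (or a changed default).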

**6. Main Training steps**

In [None]:
"""
Main training routine specific for this project
:param hparams:
"""
seed_everything(hparams.seed)

# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------
model = ProtBertBFDClassifier(hparams)

# ------------------------
# 2 INIT EARLY STOPPING
# ------------------------
early_stop_callback = EarlyStopping(
    monitor=hparams.monitor,
    min_delta=0.0,
    patience=hparams.patience,
    verbose=True,
    mode=hparams.metric_mode,
)

# --------------------------------
# 3 INIT MODEL CHECKPOINT CALLBACK
# -------------------------------
ckpt_path = os.path.join(
    logger.save_dir,
    logger.name,
    f"version_{logger.version}",
    "checkpoints",
)
# initialize Model Checkpoint Saver
checkpoint_callback = ModelCheckpoint(
    filepath=ckpt_path + "/" + "{epoch}-{val_loss:.2f}-{val_acc:.2f}",
    save_top_k=hparams.save_top_k,
    verbose=True,
    monitor=hparams.monitor,
    period=1,
    mode=hparams.metric_mode,
)

# ------------------------
# 4 INIT TRAINER
# ------------------------
trainer = Trainer(
    gpus=hparams.gpus,
    tpu_cores=hparams.tpu_cores,
    logger=logger,
    early_stop_callback=early_stop_callback,
    distributed_backend="dp",
    max_epochs=hparams.max_epochs,
    min_epochs=hparams.min_epochs,
    accumulate_grad_batches=hparams.accumulate_grad_batches,
    val_percent_check=hparams.val_percent_check,
    checkpoint_callback = checkpoint_callback,
    precision=hparams.precision,
    amp_level=hparams.amp_level,
    deterministic=True
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
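
With `batch_size=1` and `accumulate_grad_batches=8`, each optimizer step is based on 8 sequences. The sketch below illustrates the idea on a made-up one-parameter least-squares problem: averaging the gradients of 8 single-sample micro-batches gives the same value as the gradient of one batch of 8. How a given framework scales the accumulated gradient is implementation-specific; the averaging here is an assumption of the sketch.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

# Toy data (made up): 8 samples, matching batch_size=1 x accumulate_grad_batches=8.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]
w = 0.5
batch_size = 1
accumulate_grad_batches = 8

# Accumulate one micro-batch gradient at a time, then average before the update.
accumulated = 0.0
for start in range(0, len(xs), batch_size):
    stop = start + batch_size
    accumulated += grad_mse(w, xs[start:stop], ys[start:stop])
accumulated /= accumulate_grad_batches

full_batch = grad_mse(w, xs, ys)
effective_batch_size = batch_size * accumulate_grad_batches
print(effective_batch_size, accumulated, full_batch)
```

Accumulation keeps peak memory at the level of a single sequence while giving the optimizer a less noisy gradient.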


In [None]:
# ------------------------
# 5 START TENSORBOARD
# ------------------------
%tensorboard --logdir "experiments/lightning_logs/"

In [None]:
# ------------------------
# 6 START TRAINING
# ------------------------
# 0.77633 
#         same but with dropout 0.2
trainer.fit(model)


  | Name                | Type             | Params
---------------------------------------------------------
0 | metric_acc          | Accuracy         | 0     
1 | ProtBertBFD         | BertModel        | 419 M 
2 | classification_head | Sequential       | 40 K  
3 | _loss               | CrossEntropyLoss | 0     


Epoch 00000: val_acc reached 0.56026 (best 0.56026), saving model to /content/experiments/lightning_logs/version_15-08-2020--08-15-00/checkpoints/epoch=0-val_loss=1.61-val_acc=0.56.ckpt as top 1
Epoch 00001: val_acc reached 0.71064 (best 0.71064), saving model to /content/experiments/lightning_logs/version_15-08-2020--08-15-00/checkpoints/epoch=1-val_loss=1.26-val_acc=0.71.ckpt as top 1
Epoch 00002: val_acc reached 0.75679 (best 0.75679), saving model to /content/experiments/lightning_logs/version_15-08-2020--08-15-00/checkpoints/epoch=2-val_loss=1.20-val_acc=0.76.ckpt as top 1
Epoch 00003: val_acc  was not in top 1
Epoch 00004: val_acc  was not in top 1
Epoch 00005: val_acc reached 0.76873 (best 0.76873), saving model to /content/experiments/lightning_logs/version_15-08-2020--08-15-00/checkpoints/epoch=5-val_loss=1.20-val_acc=0.77.ckpt as top 1
Epoch 00006: val_acc  was not in top 1
Epoch 00007: val_acc  was not in top 1
Epoch 00008: val_acc  was not in top 1
Epoch 00009: val_acc  was not in top 1
Epoch 00010: val_acc reached 0.77144 (best 0.77144), saving model to /content/experiments/lightning_logs/version_15-08-2020--08-15-00/checkpoints/epoch=10-val_loss=1.20-val_acc=0.77.ckpt as top 1
Epoch 00011: val_acc reached 0.77633 (best 0.77633), saving model to /content/experiments/lightning_logs/version_15-08-2020--08-15-00/checkpoints/epoch=11-val_loss=1.19-val_acc=0.78.ckpt as top 1
Epoch 00012: val_acc  was not in top 1
Epoch 00013: val_acc  was not in top 1
Epoch 00014: val_acc  was not in top 1
Epoch 00015: val_acc  was not in top 1
Epoch 00016: val_acc  was not in top 1
Epoch 00017: early stopping triggered.





1
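
Training stopped because `val_acc` did not improve for `patience=5` consecutive epochs after its best value. The sketch below is a simplified model of that rule for a metric that should be maximized; it is not PyTorch Lightning's implementation, and the score list is hypothetical rather than taken from this run.

```python
def early_stop_epoch(scores, patience, min_delta=0.0):
    """Return the index of the epoch at which training would stop, or None.

    Simplified rule: count consecutive epochs whose score does not beat the
    best score so far by more than min_delta; stop once the count hits patience.
    """
    best = float("-inf")
    wait = 0
    for epoch, score in enumerate(scores):
        if score > best + min_delta:
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Hypothetical validation accuracies, one per epoch.
toy_scores = [0.56, 0.71, 0.76, 0.75, 0.74, 0.77, 0.76, 0.76, 0.75, 0.76, 0.76]
stop_at = early_stop_epoch(toy_scores, patience=5)
print(stop_at)
```

In the toy series the best score is at index 5 and the five following epochs fail to beat it, so the rule fires at index 10.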

**7. Test the best checkpoint**

In [None]:
best_checkpoint_path = glob.glob(ckpt_path + "/*")[0]
print(best_checkpoint_path)

experiments/lightning_logs/version_15-08-2020--08-15-00/checkpoints/epoch=11-val_loss=1.19-val_acc=0.78.ckpt
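
Taking the first `glob` result works here because `save_top_k` defaults to 1, so the checkpoint folder holds a single file. If more checkpoints were kept, the order returned by `glob` would not be guaranteed. The sketch below shows one way to pick the best file by parsing `val_acc` out of the file name; `best_checkpoint` is a hypothetical helper, and the file names follow the pattern visible in the training log.

```python
import re

def best_checkpoint(paths):
    """Return the path whose file name reports the highest val_acc (hypothetical helper)."""
    def val_acc(path):
        match = re.search(r"val_acc=(\d+\.\d+)", path)
        # Files without a val_acc field sort last.
        return float(match.group(1)) if match else float("-inf")
    return max(paths, key=val_acc)

# File names following the pattern in the training log.
candidates = [
    "epoch=2-val_loss=1.20-val_acc=0.76.ckpt",
    "epoch=5-val_loss=1.20-val_acc=0.77.ckpt",
    "epoch=11-val_loss=1.19-val_acc=0.78.ckpt",
]
print(best_checkpoint(candidates))
```

The file name only carries two decimals, so checkpoints with equal rounded accuracy cannot be told apart this way.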


In [None]:
trainer.resume_from_checkpoint = best_checkpoint_path

In [None]:
trainer.test(model)






--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.7763, device='cuda:0'),
 'test_loss': tensor(1.1938, device='cuda:0')}
--------------------------------------------------------------------------------



[{'test_acc': 0.7763301134109497, 'test_loss': 1.1937729120254517}]

**8. Predict new sequence**

In [None]:
model = model.load_from_checkpoint(best_checkpoint_path)

model.eval()
model.freeze()

In [None]:
sample = {
  "seq": "M S T D T G V S L P S Y E E D Q G S K L I R K A K E A P F V P V G I A G F A A I V A Y G L Y K L K S R G N T K M S I H L I H M R V A A Q G F V V G A M T V G M G Y S M Y R E F W A K P K P",
}

predictions = model.predict(sample)

print("Sequence Localization Ground Truth is: {} - prediction is: {}".format('Mitochondrion',predictions['predicted_label']))



Sequence Localization Ground Truth is: Mitochondrion - prediction is: Mitochondrion
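
The sample passed to `predict` has its residues separated by single spaces. The sketch below shows one way to produce that format from a raw sequence string; `format_sequence` is a hypothetical helper. The optional `map_rare` step, which replaces the residues U, Z, O and B with X, is an assumption of this sketch and is not something this notebook does.

```python
import re

def format_sequence(raw: str, map_rare: bool = False) -> str:
    """Upper-case a raw sequence and separate residues with single spaces."""
    residues = re.sub(r"\s+", "", raw).upper()
    if map_rare:
        # Assumption: replace rarely seen residues with X before tokenization.
        residues = re.sub(r"[UZOB]", "X", residues)
    return " ".join(residues)

toy_sample = {"seq": format_sequence("mstdtgvslp")}
print(toy_sample["seq"])
```

A dictionary built this way has the same shape as the `sample` dictionaries used in this section.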


In [None]:
sample = {
  "seq": "M R C L P V F I I L L L L I P S A P S V D A Q P T T K D D V P L A S L H D N A K R A L Q M F W N K R D C C P A K L L C C N P",
}

predictions = model.predict(sample)

print("Sequence Localization Ground Truth is: {} - prediction is: {}".format('Extracellular',predictions['predicted_label']))



Sequence Localization Ground Truth is: Extracellular - prediction is: Extracellular
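
The last step of `predict` turns the logits into a label by taking the index of the largest value. The sketch below re-states that step with plain Python, assuming label indices follow the order of the default `--label_set` string; the logits are made up for illustration.

```python
label_set = ("Cell.membrane,Cytoplasm,Endoplasmic.reticulum,Extracellular,"
             "Golgi.apparatus,Lysosome/Vacuole,Mitochondrion,Nucleus,Peroxisome,Plastid")
# Assumption: index i corresponds to the i-th label in the comma-separated list.
index_to_label = label_set.split(",")

def decode(logits):
    """Return the label whose logit is largest."""
    best_index = max(range(len(logits)), key=lambda i: logits[i])
    return index_to_label[best_index]

# Made-up logits for one sequence, one value per label.
toy_logits = [-0.9, 0.1, -0.4, 0.2, -0.8, -0.7, 0.95, 0.3, -0.6, -0.5]
print(decode(toy_logits))
```

The toy values stay within [-1, 1] because the classification head ends in `nn.Tanh()`.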
