In [4]:
# Dependencies
!pip3 install pytorch_lightning torch torchvision torchaudio transformers sentencepiece accelerate --extra-index-url https://download.pytorch.org/whl/cu116

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/, https://download.pytorch.org/whl/cu116
Collecting pytorch_lightning
  Downloading pytorch_lightning-2.0.0-py3-none-any.whl (715 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m715.6/715.6 KB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.7.0
  Downloading lightning_utilities-0.8.0-py3-none-any.whl (20 kB)
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 KB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
Collecting aiohttp!=4.0.0a0,!=4.0.0a1
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
Collecting async-timeout<5.0,>=4.0.0a3
  Downloading async_timeout-4.0.2-py3-none

In [5]:
import re
from collections import OrderedDict

import pytorch_lightning as pl
import torch
from torchnlp.datasets.dataset import Dataset
from torchnlp.encoders import LabelEncoder
from torchnlp.utils import collate_tensors
from transformers import T5Tokenizer, T5EncoderModel

# Pick the first CUDA device when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("Using device: {}".format(device))

Using device: cuda:0


In [6]:
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50 in half-precision)

transformer_link = "Rostlab/prot_t5_xl_half_uniref50-enc"
print("Loading: {}".format(transformer_link))
model = T5EncoderModel.from_pretrained(transformer_link)

# "Full precision" means float32 weights; "half precision" means float16 weights.
# Half precision is kept on the GPU; without a GPU the weights are cast to float32,
# since CPU support for float16 kernels is limited.
# Two details matter here:
#   * a `torch.device` never compares equal to a plain string, so `device.type` is tested;
#   * the float32 cast on an `nn.Module` is `.float()` -- there is no `.full()` method.
model = model.float() if device.type == 'cpu' else model.half()
model = model.to(device)
model = model.eval()  # inference mode: disables dropout
tokenizer = T5Tokenizer.from_pretrained(transformer_link, do_lower_case=False)

Loading: Rostlab/prot_t5_xl_half_uniref50-enc


Downloading pytorch_model.bin:   0%|          | 0.00/2.42G [00:00<?, ?B/s]

In [7]:
# This cell shows how to extract per-residue and per-protein embeddings from the encoder.

# fake data from the notebook example
sequence_examples = ["PRTEINO", "SEQWENCE"]

# Number of residues in each sequence, recorded before whitespace is inserted.
# The substitution below swaps one character for one character, so these lengths
# remain valid after preprocessing.
sequence_lengths = [len(sequence) for sequence in sequence_examples]

# TODO (ask collaborator): do we need to be concerned with "rare/ambiguous" amino acids?
#   Are there other edge cases or sections of data that we'll need to think
#   about when preparing the data?
# this will replace all rare/ambiguous amino acids (U, Z, O, B) by X and introduce
# white-space between all amino acids
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples]


# tokenize sequences and pad up to the longest sequence in the batch
# `add_special_tokens=True` lets the tokenizer add its special tokens (for T5 this is
# presumably the end-of-sequence token appended after the residues -- TODO confirm).
# `padding="longest"` pads every sequence to the length of the longest one in the batch
# and returns an `attention_mask` alongside the `input_ids`.
# TODO: read up on `batch_encode_plus` until we can explain it to someone else.
ids = tokenizer.batch_encode_plus(sequence_examples, add_special_tokens=True, padding="longest")
input_ids = torch.tensor(ids['input_ids']).to(device)
attention_mask = torch.tensor(ids['attention_mask']).to(device)


# generate embeddings
with torch.no_grad():
    embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)

# `last_hidden_state` is indexed as [sequence in batch, token position, feature].
# Assuming special tokens and padding come after the residues (right padding is the
# tokenizer default -- confirm if `padding_side` is ever changed), the first
# `sequence_lengths[i]` positions of row i are exactly the residue embeddings.
# Slicing by the real sequence length therefore drops special and padded tokens
# without hard-coding 7 and 8.
emb_0 = embedding_repr.last_hidden_state[0, :sequence_lengths[0]] # shape (7 x 1024)
print(f"Shape of per-residue embedding of first sequences: {emb_0.shape}")

# do the same for the second ([1,:]) sequence in the batch, which has a different length
emb_1 = embedding_repr.last_hidden_state[1, :sequence_lengths[1]] # shape (8 x 1024)

# if you want to derive a single representation (per-protein embedding) for the whole protein,
# average over the residue axis
emb_0_per_protein = emb_0.mean(dim=0) # shape (1024)

print(f"Shape of per-protein embedding of first sequences: {emb_0_per_protein.shape}")

In [None]:
# Next Steps:
# We need to fine-tune this model to predict protein family.
#   We can use an existing dataset for this, but it should either cover
#   a large variety of proteins labeled by protein family, or contain many
#   proteins belonging to families relevant for this task. My guess is we'll want
#   a dataset with a larger variety, because we probably don't know which families are relevant for this task.
# 
# With the ProtTrans architecture for pre-trained models, it will be easy to extract the vector representations.
#
# Next, how can we map those representations into a vector space. I don't have much experience with this - may take more research.
#
# Test the architecture: are we able to see that the vector representations of test-data
# proteins cluster closely with other proteins in the same family?
#
# Finally, find where the vector representations of the unknown proteins land in the vector space.

In [None]:
# Fine-tuned ProtT5 model

class ProtT5Classifier(pl.LightningModule):
  """ProtT5 encoder with a classification head for protein-family prediction.

  Work in progress: the dataset, the dataloaders and `configure_optimizers` are not
  implemented yet, so this module cannot be handed to a `Trainer` as-is.
  """

  def __init__(self, hparams):
    """
    :param hparams: namespace or dict of hyper-parameters. Must provide `batch_size`
        and `label_set` (comma-separated class labels).
    """
    super().__init__()
    # In pytorch_lightning 2.x `hparams` is a read-only property, so the values are
    # registered with `save_hyperparameters` instead of being assigned directly.
    self.save_hyperparameters(hparams)
    self.batch_size = self.hparams.batch_size

    # TODO: self.data_set = ... (we need to find a dataset labelled by protein family)

    self.loss = torch.nn.CrossEntropyLoss()

    self._build_model()

    # Freeze all encoder parameters except the last two.
    # `parameters()` is a generator, so it is materialised before slicing.
    # NOTE(review): this keeps the last two *parameter tensors* trainable, which is
    # not necessarily the last two layers -- confirm this is what we want to fine-tune.
    for param in list(self.model.parameters())[:-2]:
      param.requires_grad = False

  def _build_model(self):
    """ Loads the pre-trained encoder and tokenizer, then builds the label encoder and the head. """
    transformer_link = "Rostlab/prot_t5_xl_half_uniref50-enc"
    print("Loading: {}".format(transformer_link))
    self.model = T5EncoderModel.from_pretrained(transformer_link)

    # "Full precision" is float32, "half precision" is float16. Keep float16 only when a
    # GPU is available; otherwise cast the encoder to float32.
    if torch.cuda.is_available():
      self.model.half()
    else:
      self.model.float()

    # Width of the encoder's hidden states (expected to be 1024 for this checkpoint).
    self.encoder_features = self.model.config.d_model

    # Stored on the instance so that sample preparation can reach it.
    self.tokenizer = T5Tokenizer.from_pretrained(transformer_link, do_lower_case=False)

    # this is pulled in directly from https://github.com/agemagician/ProtTrans/blob/master/Fine-Tuning/ProtBert-BFD-FineTuning-PyTorchLightning-Localization.ipynb
    # Label Encoder: built from the comma-separated `label_set`; the commented-out
    # helpers below use it to turn label strings into indices (`batch_encode`) and
    # indices back into labels (`index_to_token`).
    self.label_encoder = LabelEncoder(
        self.hparams.label_set.split(","), reserved_labels=[]
    )
    self.label_encoder.unknown_index = None

    # Classification head: one output per known label (protein family).
    # NOTE(review): the Tanh bounds the values passed to CrossEntropyLoss; it is kept
    # from the reference notebook -- confirm it is intended.
    self.classification_head = torch.nn.Sequential(
        torch.nn.Linear(self.encoder_features, self.label_encoder.vocab_size),
        torch.nn.Tanh(),
    )

  def forward(self, input_ids, attention_mask, **kwargs) -> dict:
    """
    Embeds a batch of tokenized sequences and classifies each one.

    :param input_ids: token ids, as nested lists or a tensor (batch x tokens).
    :param attention_mask: 1 for real tokens and 0 for padding, same layout as `input_ids`.
    Returns:
        - dictionary with the class scores under the key "logits".
    """
    input_ids = torch.as_tensor(input_ids, device=self.device)
    attention_mask = torch.as_tensor(attention_mask, device=self.device)

    token_embeddings = self.model(
        input_ids=input_ids, attention_mask=attention_mask
    ).last_hidden_state

    # The encoder may run in float16 while the head is float32, so pool in the head's dtype.
    head_dtype = self.classification_head[0].weight.dtype
    token_embeddings = token_embeddings.to(head_dtype)
    mask = attention_mask.unsqueeze(-1).to(head_dtype)

    # Mean over the non-padded positions gives one vector per protein.
    # NOTE(review): any special token covered by the attention mask is included in the mean.
    pooled = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    return {"logits": self.classification_head(pooled)}

  @staticmethod
  def metric_acc(labels_hat, y):
    """ Fraction of predicted labels that equal the target labels. """
    return (labels_hat == y).float().mean()

  # def predict(self, sample: dict) -> dict:
  #   """ Predict function.
  #   :param sample: dictionary with the text we want to classify.
  #   Returns:
  #       Dictionary with the input text and the predicted label.
  #   """
  #   if self.training:
  #       self.eval()

  #   with torch.no_grad():
  #       model_input, _ = self.prepare_sample([sample], prepare_target=False)
  #       model_out = self.forward(**model_input)
  #       logits = model_out["logits"].numpy()
  #       predicted_labels = [
  #           self.label_encoder.index_to_token[prediction]
  #           for prediction in np.argmax(logits, axis=1)
  #       ]
  #       sample["predicted_label"] = predicted_labels[0]

  #   return sample

  # def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict):
  #   """
  #   Function that prepares a sample to input the model.
  #   :param sample: list of dictionaries.
  #
  #   Returns:
  #       - dictionary with the expected model inputs.
  #       - dictionary with the expected target labels.
  #   """
  #   sample = collate_tensors(sample)

  #   inputs = self.tokenizer.batch_encode_plus(sample["seq"],
  #                                             add_special_tokens=True,
  #                                             padding=True,
  #                                             truncation=True,
  #                                             max_length=self.hparams.max_length)

  #   if not prepare_target:
  #       return inputs, {}

  #   # Prepare target:
  #   try:
  #       targets = {"labels": self.label_encoder.batch_encode(sample["label"])}
  #       return inputs, targets
  #   except RuntimeError:
  #       print(sample["label"])
  #       raise Exception("Label encoder found an unknown label.")

  def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
    """
    Runs one training step. This usually consists in the forward function followed
        by the loss function.

    :param batch: The output of your dataloader: an `(inputs, targets)` pair where
        `targets["labels"]` holds the encoded class indices.
    :param batch_nb: Integer displaying which batch this is
    Returns:
        - dictionary containing the loss and the metrics to be added to the lightning logger.
    """
    inputs, targets = batch
    model_out = self.forward(**inputs)
    # The loss works on tensors, not on the surrounding dictionaries.
    loss_val = self.loss(model_out["logits"], targets["labels"])

    # Log through Lightning as well as through the returned dictionary.
    self.log("train_loss", loss_val, prog_bar=True)

    tqdm_dict = {"train_loss": loss_val}
    output = OrderedDict(
        {"loss": loss_val, "progress_bar": tqdm_dict, "log": tqdm_dict})

    # can also return just a scalar instead of a dict (return loss_val)
    return output

  def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
    """ Similar to the training step but with the model in eval mode.

    :param batch: `(inputs, targets)` pair, same layout as in `training_step`.
    :param batch_nb: Integer displaying which batch this is
    Returns:
        - dictionary with the test loss and the test accuracy.
    """
    inputs, targets = batch
    model_out = self.forward(**inputs)

    y = targets["labels"]
    y_hat = model_out["logits"]
    loss_test = self.loss(y_hat, y)

    labels_hat = torch.argmax(y_hat, dim=1)
    test_acc = self.metric_acc(labels_hat, y)

    self.log_dict({"test_loss": loss_test, "test_acc": test_acc})

    output = OrderedDict({"test_loss": loss_test, "test_acc": test_acc,})

    return output

    # def train_dataloader(self) -> DataLoader:
    #     """ Function that loads the train set. """
    #     self._train_dataset = self.__retrieve_dataset(val=False, test=False)
    #     return DataLoader(
    #         dataset=self._train_dataset,
    #         sampler=RandomSampler(self._train_dataset),
    #         batch_size=self.hparams.batch_size,
    #         collate_fn=self.prepare_sample,
    #         num_workers=self.hparams.loader_workers,
    #     )

    # def test_dataloader(self) -> DataLoader:
    #     """ Function that loads the validation set. """
    #     self._test_dataset = self.__retrieve_dataset(train=False, val=False)
    #     return DataLoader(
    #         dataset=self._test_dataset,
    #         batch_size=self.hparams.batch_size,
    #         collate_fn=self.prepare_sample,
    #         num_workers=self.hparams.loader_workers,
    #     )