In [1]:
"""
In this notebook, we will fine-tune the ProtBert model (Rostlab/prot_bert) on the dataset.
We first have to create a dataset class that will be used to load the data.
"""
from typing import List, Optional, Tuple, Union

from torch.utils.data import Dataset, DataLoader
from transformers import BertForMaskedLM, BertTokenizer

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

# import lightning utils
import lightning as L
from lightning.pytorch.loggers import WandbLogger

# init wandb logger
import wandb


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ProteinDataset(Dataset):
    """
    Dataset class for the protein sequences.

    Each item is a dict holding the tokenized sequence of one row of the
    sequences table, the matching row of the target table (if a target file
    was given) and the two scalar columns `organism` and `animal`.
    """
    def __init__(self, parquet_file_input, parquet_target=None) -> None:
        """
        params:
        - parquet_file_input: path to the parquet file containing the sequences.
          The columns read by this class are `sequence`, `other_entry`,
          `organism` and `animal`.
        - parquet_target: optional path to the parquet file containing the
          targets. Its index is looked up with the `other_entry` value of each
          sequence row. When omitted, items carry `target=None`.
        """
        super().__init__()

        # we read the parquet files
        self.sequences = pd.read_parquet(parquet_file_input)
        self.target = pd.read_parquet(parquet_target) if parquet_target else None

        self.tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )

    def __len__(self) -> int:
        """Number of rows in the sequences table."""
        return len(self.sequences)

    def __getitem__(self, index: int) -> dict:
        """
        Build the sample stored at positional `index`.

        returns a dict with the keys:
        - 'input_ids': dict of tensors produced by the tokenizer
          (note: the whole tokenizer output, not only its `input_ids` entry)
        - 'target': values of the target row, or None without a target table
        - 'organism', 'animal': scalar columns of the sequence row
        """
        # we fetch the row once instead of re-materialising it for every column
        row = self.sequences.iloc[index]

        # we retrieve EntryID
        entry_id = row["other_entry"]

        # we tokenize the sequence
        tokenized_sequence = self.tokenizer(row["sequence"], return_tensors='pt')

        # we drop only the leading axis of every tensor: a bare `squeeze()` would
        # also collapse any other axis of size 1 (e.g. a one-token encoding)
        tokenized_sequence = {key: value.squeeze(0) for key, value in tokenized_sequence.items()}

        # we get the target value if it exists
        # NOTE(review): a None target cannot be batched by the default DataLoader
        # collate function - confirm how target-less data is meant to be loaded
        target = self.target.loc[entry_id].values if self.target is not None else None

        return {
            'input_ids': tokenized_sequence,
            'target': target,
            'organism': row["organism"],
            'animal': row["animal"]
        }

In [3]:
# on-disk location of the training sequences and of their labels
path_parquet_input = "/workspaces/protein_ontologies/dataset_kg/Train/train_sequences.parquet"
path_parquet_target = "/workspaces/protein_ontologies/dataset_kg/Train/train_labels.parquet"

# the dataset pairs every sequence with its label row
dataset = ProteinDataset(
    parquet_file_input=path_parquet_input,
    parquet_target=path_parquet_target,
)

# the loader shuffles the samples and serves them one at a time
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=1)

In [4]:
# tokenizer output of the first sample (ProteinDataset stores it under the 'input_ids' key)
dataset[0]["input_ids"]

{'input_ids': tensor([ 2,  2, 21, 17, 10,  8, 15,  8, 10, 22,  6, 16, 20, 15, 11, 15, 20, 22,
         14, 14, 24,  9, 16,  8, 21, 10, 18,  5,  8,  9, 19, 20, 17,  9,  8,  6,
         10, 24,  5,  5, 13, 14,  9, 15, 10, 16, 11, 16, 14, 12, 19, 19, 11, 18,
          5, 12, 18, 16,  5, 13, 17, 12, 13,  8, 23,  8, 23,  7, 11, 14, 16, 20,
         16, 12, 14,  7, 15,  7,  8, 16, 19,  9, 10, 16, 17, 19, 15, 12, 12, 10,
         11, 12,  9, 11,  6, 10, 10, 11, 10, 13,  5, 15,  7,  8, 11, 14, 20, 12,
          7, 20, 17,  5, 17, 11, 11, 14,  7,  8, 11, 16, 24, 17, 20, 20,  5, 10,
         23, 12,  5,  7,  9, 15, 12, 10, 22,  6, 11, 20, 24, 14, 12, 11, 10, 12,
          5,  5,  5, 18, 22, 11, 15, 12, 22,  8, 10,  8,  5, 20, 23,  5,  7, 12,
         15, 14, 19, 10, 17, 11, 13,  6, 12,  5,  9, 10, 16,  8, 15, 15, 11,  8,
          7, 20, 22, 16,  6,  6, 13, 14, 13, 18, 19,  9, 12, 14, 13, 10, 19,  9,
         11, 11, 17,  8,  5,  5,  9,  5, 14, 17, 12,  8, 16, 11, 17, 24,  6, 18,
          7, 19

In [5]:
# quick check of the dataloader: inspect the first batch only
for batch in dataloader:
    tokens, labels = batch["input_ids"], batch["target"]

    # raw content of the tokenizer output and of the target
    print(tokens)
    print(labels)

    # shapes of the token ids and of the target
    print(tokens["input_ids"].shape)
    print(labels.shape)

    break

{'input_ids': tensor([[ 2,  2, 21,  5,  6,  7, 13, 20, 23,  5, 15, 15, 22, 14,  6, 16, 15, 13,
         22, 11,  9, 24,  5, 16,  5, 20, 19, 22, 10, 13, 22, 20,  5,  5, 10, 19,
         19,  5, 19, 10, 19,  5, 19, 19, 19, 16, 23,  8,  8, 23,  8, 23, 18, 15,
          9, 12,  7, 17, 12,  6, 16,  6, 21,  5, 13,  8, 10, 13,  5, 10,  5,  6,
         11,  7, 15, 20, 10,  5, 19, 21, 12,  9, 18, 12, 17, 17, 16,  6,  5, 12,
          7,  5, 16,  8,  6, 12, 13,  7, 18,  6, 15,  6, 12,  5, 20, 13,  9,  5,
         10,  8, 15,  9, 13,  9,  9,  5,  6, 12, 13,  6, 12,  6,  6, 16, 10,  6,
         15, 13, 12, 12, 10,  8, 16, 12, 15, 12,  7, 12,  9, 12,  8, 15,  7,  7,
          7,  7, 12, 13, 12,  6, 10,  9, 20, 15,  9, 19,  8, 12, 10, 17, 11, 10,
         12, 20, 10, 17,  5, 16, 18, 13,  9, 13, 21, 15,  6,  8,  6, 12,  5, 24,
         12, 18, 18, 12, 18, 21, 13, 12,  3,  3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0

In [71]:
from typing import Any
from torch import nn
import transformers

# checkpoint name passed to `from_pretrained` for both the config and the weights of ProteinModel
global_model = "Rostlab/prot_bert"
# output size of the linear head of ProteinModel (length of the prediction vector)
NB_OUTPUT = 500

class ProteinModel(L.LightningModule):
    """
    Module to finetune the ProtBert model.

    The last two hidden states of the pretrained network are concatenated,
    averaged over the sequence axis and fed to a linear head with NB_OUTPUT
    outputs. `forward` returns sigmoid probabilities; the loss is computed on
    the raw logits so that it stays usable under mixed precision.
    """
    def __init__(self) -> None:
        super().__init__()

        # we ask the backbone to return all hidden states, not only its final output
        self.config = transformers.AutoConfig.from_pretrained(global_model)
        self.config.update({'output_hidden_states':True})

        self.model = BertForMaskedLM.from_pretrained(global_model, config=self.config)

        # the input is twice the hidden size because two layers are concatenated
        self.head = nn.Sequential(
            nn.Linear(self.config.hidden_size * 2, NB_OUTPUT),
            nn.Sigmoid()
        )

        # BCELoss on sigmoid outputs is rejected by autocast (16-bit mixed
        # precision); the logits variant is the numerically safe equivalent
        self.loss_fn = nn.BCEWithLogitsLoss()

    def _pooled_features(self, batch: dict) -> Any:
        """
        Run the backbone and pool its last two hidden states over the sequence.

        When the tokenizer output contains an `attention_mask`, masked positions
        are excluded from the average; with an all-ones mask this equals the
        plain mean over the sequence axis.
        """
        inputs = batch["input_ids"]
        hidden_states = self.model(**inputs)["hidden_states"]

        # concat the 2 last hidden states along the feature axis
        features = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)

        mask = inputs.get("attention_mask")
        if mask is None:
            return torch.mean(features, dim=1)

        # masked average over the sequence axis
        mask = mask.unsqueeze(-1).to(features.dtype)
        summed = (features * mask).sum(dim=1)
        # clamp guards against a division by zero on a fully masked row
        counts = mask.sum(dim=1).clamp(min=1)
        return summed / counts

    def compute_loss(self, batch: dict) -> Any:
        """
        Compute the loss for a batch.

        The binary cross-entropy is evaluated between the head logits (head
        without its final sigmoid) and `batch["target"]` cast to float.
        """
        # every layer of the head except the trailing sigmoid
        logits = self.head[:-1](self._pooled_features(batch))

        # we get the target
        target = batch["target"]

        # we compute the loss
        loss = self.loss_fn(logits, target.float())

        return loss

    def forward(self, batch: dict) -> Any:
        """
        Forward pass.

        params:
        - batch: dict whose 'input_ids' entry is the tokenizer output
          (it is unpacked as keyword arguments of the backbone)

        returns the sigmoid outputs of the head, one row per sample with
        NB_OUTPUT values.
        """
        return self.head(self._pooled_features(batch))

    def training_step(self, batch: dict, batch_idx: int) -> dict:
        """
        Training step: computes and logs the loss of the batch.
        """
        loss = self.compute_loss(batch)

        # batch_size is given explicitly because the batch is a nested dict
        self.log("train_loss", loss, prog_bar=True, batch_size=batch["target"].shape[0])

        return {
            "loss": loss
        }

    def validation_step(self, batch: dict, batch_idx: int) -> dict:
        """
        Validation step: computes and logs the loss of the batch.
        """
        loss = self.compute_loss(batch)

        # without logging, the validation loss would be computed and discarded
        self.log("val_loss", loss, prog_bar=True, batch_size=batch["target"].shape[0])

        return {
            "loss": loss
        }

    def configure_optimizers(self) -> Any:
        """
        Configure the optimizer: Adam over all parameters (backbone and head).
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-6)

        return optimizer

In [72]:
# we initialize the model (this loads the pretrained checkpoint)
model = ProteinModel()

# test the model on `batch`, the variable left over from the dataloader test loop;
# gradients are disabled because this is only a smoke test
with torch.no_grad():
    # prediction of the model for the batch
    output = model(batch)

    # loss of the same batch (this runs a second forward pass)
    loss = model.compute_loss(batch)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [73]:


# init wandb
wandb.init(project="protein-ontologies", name="protobert")

# the logger class lives in `lightning.pytorch.loggers`; it is not an attribute
# of the top-level `lightning` package, so `L.WandbLogger` fails
# NOTE(review): a run is already started by `wandb.init` above without `entity`;
# confirm the run ends up under the entity given here
logger = WandbLogger(project="protein-ontologies", name="protobert", entity='forbu14')

# init the trainer
trainer = L.Trainer(
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    # gradients of 8 batches are accumulated before each optimizer step
    accumulate_grad_batches=8,
    # training time limit, format DD:HH:MM:SS (here: one hour)
    max_time="00:01:00:00",
    logger=logger,
    gradient_clip_val=0.5,
)

torch.Size([1, 500])

In [74]:
# display the loss returned by `model.compute_loss(batch)` in the model smoke test
loss

tensor(0.6905)