In [2]:
"""
In this notebook, we fine-tune the ProtBert model (Rostlab/prot_bert) on the protein dataset.
We first create a dataset class that is used to load the data.
"""

from torch.utils.data import Dataset, DataLoader
from transformers import BertForMaskedLM, BertTokenizer

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

# bio  package
from Bio import SeqIO

# import lightning utils
import lightning as L


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class ProteinDataset(Dataset):
    """
    Dataset over protein sequences stored in a parquet file.

    Each item pairs the tokenized sequence with the row's ``organism`` and
    ``animal`` values and, when a target file was given, the values of the
    target row labelled by the sequence's ``other_entry`` value.
    """
    def __init__(self, parquet_file_input, parquet_target=None) -> None:
        """
        params:
        - parquet_file_input: path to the parquet file containing the sequences
        - parquet_target: optional path to the parquet file containing the
          targets; when omitted (or falsy), every item carries ``target=None``
        """
        super().__init__()

        # both tables are loaded eagerly and kept in memory
        self.sequences = pd.read_parquet(parquet_file_input)
        self.target = pd.read_parquet(parquet_target) if parquet_target else None

        # do_lower_case=False: the tokenizer must not lowercase the input text
        self.tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )

    def __len__(self) -> int:
        """Number of rows in the sequences table."""
        return len(self.sequences)

    def __getitem__(self, index: int) -> dict:
        """
        Build the sample stored at position ``index``.

        params:
        - index: positional (iloc) index into the sequences table

        returns a dict with the keys:
        - 'input_ids': the full tokenizer output for the sequence (the object
          returned by the tokenizer, not only the token-id tensor)
        - 'target': values of the target row labelled by ``other_entry``, or
          None when no target file was given
        - 'organism': the row's ``organism`` value
        - 'animal': the row's ``animal`` value
        """
        # fetch the row once instead of re-materializing it for every field
        row = self.sequences.iloc[index]

        sequence = row.sequence

        # EntryID used to look up the matching row of the target table
        entry_id = row.other_entry

        # return_tensors='pt' asks the tokenizer for PyTorch tensors
        tokenized_sequence = self.tokenizer(sequence, return_tensors='pt')

        # scalar inputs
        scalar_organism = row.organism
        animal = row.animal

        # label-based lookup: the target table is expected to be indexed by entry id
        target = self.target.loc[entry_id].values if self.target is not None else None

        return {
            'input_ids': tokenized_sequence,
            'target': target,
            'organism': scalar_organism,
            'animal': animal
        }

In [4]:
# locations of the training sequences and of their labels
path_parquet_input = "/workspaces/protein_ontologies/dataset_kg/Train/train_sequences.parquet"
path_parquet_target = "/workspaces/protein_ontologies/dataset_kg/Train/train_labels.parquet"

# build the dataset from both parquet files
dataset = ProteinDataset(
    parquet_file_input=path_parquet_input,
    parquet_target=path_parquet_target,
)

In [6]:
# peek at the first sample: its 'input_ids' entry holds the tokenizer's output for the sequence
dataset[0]["input_ids"]

{'input_ids': tensor([[ 2,  2, 21, 17, 10,  8, 15,  8, 10, 22,  6, 16, 20, 15, 11, 15, 20, 22,
         14, 14, 24,  9, 16,  8, 21, 10, 18,  5,  8,  9, 19, 20, 17,  9,  8,  6,
         10, 24,  5,  5, 13, 14,  9, 15, 10, 16, 11, 16, 14, 12, 19, 19, 11, 18,
          5, 12, 18, 16,  5, 13, 17, 12, 13,  8, 23,  8, 23,  7, 11, 14, 16, 20,
         16, 12, 14,  7, 15,  7,  8, 16, 19,  9, 10, 16, 17, 19, 15, 12, 12, 10,
         11, 12,  9, 11,  6, 10, 10, 11, 10, 13,  5, 15,  7,  8, 11, 14, 20, 12,
          7, 20, 17,  5, 17, 11, 11, 14,  7,  8, 11, 16, 24, 17, 20, 20,  5, 10,
         23, 12,  5,  7,  9, 15, 12, 10, 22,  6, 11, 20, 24, 14, 12, 11, 10, 12,
          5,  5,  5, 18, 22, 11, 15, 12, 22,  8, 10,  8,  5, 20, 23,  5,  7, 12,
         15, 14, 19, 10, 17, 11, 13,  6, 12,  5,  9, 10, 16,  8, 15, 15, 11,  8,
          7, 20, 22, 16,  6,  6, 13, 14, 13, 18, 19,  9, 12, 14, 13, 10, 19,  9,
         11, 11, 17,  8,  5,  5,  9,  5, 14, 17, 12,  8, 16, 11, 17, 24,  6, 18,
          7, 1

In [None]:
from typing import Any
from torch import nn

class ProteinModel(L.LightningModule):
    """
    Lightning module wrapping the pretrained ProtBert model for finetuning.

    NOTE(review): the wrapped network is ``BertForMaskedLM``; per the
    transformers documentation its forward pass returns a ``MaskedLMOutput``
    carrying token-level vocabulary logits, whereas ``nn.BCELoss`` expects a
    tensor of probabilities with the same shape as the target. A pooling /
    classification head (or a different loss) is presumably still needed
    before ``compute_loss`` can run end to end -- TODO confirm the intended
    architecture.
    """
    def __init__(self, nb_layer) -> None:
        """
        params:
        - nb_layer: not used by this class yet; kept so existing callers
          keep working
        """
        super().__init__()

        self.model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")

        # must live on the instance: compute_loss reads self.loss_fn
        # (it was previously bound to a local variable and lost)
        self.loss_fn = nn.BCELoss()

    def compute_loss(self, batch: dict) -> Any:
        """
        Compute the loss for a batch.

        params:
        - batch: dict providing at least the keys "input_ids" and "target"

        returns the value produced by ``self.loss_fn(output, target)``
        """
        # single code path for running the network: go through forward()
        output = self(batch)

        target = batch["target"]

        # NOTE(review): `output` is whatever the wrapped model returns; see the
        # class docstring regarding its compatibility with BCELoss
        loss = self.loss_fn(output, target)

        return loss

    def forward(self, batch: dict) -> Any:
        """
        Forward pass: run the wrapped model on ``batch["input_ids"]``.
        """
        return self.model(batch["input_ids"])

    def training_step(self, batch: dict, batch_idx: int) -> dict:
        """
        Training step: returns the batch loss under the "loss" key.
        """
        loss = self.compute_loss(batch)

        return {
            "loss": loss
        }

    def validation_step(self, batch: dict, batch_idx: int) -> dict:
        """
        Validation step: returns the batch loss under the "loss" key.
        """
        loss = self.compute_loss(batch)

        return {
            "loss": loss
        }