The goal of this is to apply the fine-tuned model on all human proteins w/o annotations.

# Setup

In [None]:
#import dependencies
import os.path
#os.chdir("set working path here")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader

import re
import numpy as np
import pandas as pd
import copy

import transformers, datasets
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, set_seed
from transformers import DataCollatorForTokenClassification

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

# for custom DataCollator
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

import peft
from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig

from evaluate import load
from datasets import Dataset

from tqdm import tqdm
import random

from scipy import stats
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt

In [2]:
print("Torch version: ",torch.__version__)
print("Cuda version: ",torch.version.cuda)
print("Numpy version: ",np.__version__)
print("Pandas version: ",pd.__version__)
print("Transformers version: ",transformers.__version__)
print("Datasets version: ",datasets.__version__)

Torch version:  2.5.1+cu124
Cuda version:  12.4
Numpy version:  2.0.2
Pandas version:  2.2.3
Transformers version:  4.51.3
Datasets version:  3.1.0


In [3]:
ESMs = ["facebook/esm2_t6_8M_UR50D",
         "facebook/esm2_t12_35M_UR50D",
         "facebook/esm2_t30_150M_UR50D",
         "facebook/esm2_t33_650M_UR50D",
         "facebook/esm2_t36_3B_UR50D"]

ProtT5 = ["Rostlab/prot_t5_xl_uniref50"]

selected_checkpoint = "facebook/esm2_t36_3B_UR50D"

# model architecture

In [4]:
class EsmForTokenClassificationCustom(EsmPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        # changed to ignore special tokens at the seq start and end 
        # as well as invalid positions (labels -100)
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)

            active_labels = torch.where(
              active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)
            )

            valid_logits=active_logits[active_labels!=-100]
            valid_labels=active_labels[active_labels!=-100]
            
            valid_labels=valid_labels.type(torch.LongTensor).to('cuda:0')
            
            loss = loss_fct(valid_logits, valid_labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    
# based on transformers DataCollatorForTokenClassification
@dataclass
class DataCollatorForTokenClassificationESM(DataCollatorMixin):
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.
    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def torch_call(self, features):
        import torch

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            batch[label_name] = [
                # to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
                # changed to pad the special tokens at the beginning and end of the sequence
                [self.label_pad_token_id] + to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)-1) for label in labels
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float)
        return batch

def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    import torch

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    length_of_first = examples[0].size(0)

    # Check if padding is necessary.

    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(x.size(0) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0] :] = example
    return result

def tolist(x):
    if isinstance(x, list):
        return x
    elif hasattr(x, "numpy"):  # Checks for TF tensors without needing the import
        x = x.numpy()
    return x.tolist()


# functions to load the fine-tuned weights

In [5]:
#load ESM2 models
def load_esm_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=False):
    
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    
    if half_precision and deepspeed:
        model = EsmForTokenClassificationCustom.from_pretrained(checkpoint, num_labels = num_labels, torch_dtype = torch.float16)
    else:
        model = EsmForTokenClassificationCustom.from_pretrained(checkpoint, num_labels = num_labels)
        
    if full == True:
        return model, tokenizer 
        
    peft_config = LoraConfig(
        r=4, lora_alpha=1, bias="all", target_modules=["query","key","value","dense"]
    )
    
    model = inject_adapter_in_model(peft_config, model)
    
    # Unfreeze the prediction head
    for (param_name, param) in model.classifier.named_parameters():
                param.requires_grad = True  
    
    return model, tokenizer

In [6]:
def load_model(checkpoint, filepath, num_labels=2, half_precision = True, full = False, deepspeed=False):
# Creates a new PT5 model and loads the finetuned weights from a file
#example: tokenizer, model_reload = load_model(checkpoint, f"./fine_tuned_models/{checkpoint}/{all_features_re[0]}.pth", num_labels=2)
    
    # load model
    if "esm" in checkpoint:
        model, tokenizer = load_esm_model_classification(checkpoint, num_labels, half_precision, full, deepspeed)
    else:
        model, tokenizer = load_T5_model_classification(checkpoint, num_labels, half_precision, full, deepspeed)
    
    # Load the non-frozen parameters from the saved file
    non_frozen_params = torch.load(filepath)

    # Assign the non-frozen parameters to the corresponding parameters of the model
    for param_name, param in model.named_parameters():
        if param_name in non_frozen_params:
            param.data = non_frozen_params[param_name].data

    return tokenizer, model

In [7]:
df = pd.read_csv("../data/uniprot_all_human_proteins.txt.gz", sep='\t')
df.head(), len(df)

(        Entry  Length                                           Sequence  \
 0  A0A087X1C5     515  MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...   
 1  A0A0B4J2F0      54  MFRRLTFAQLLFATVLGIAGGVYIFQPVFEQYAKDQKELKEKMQLV...   
 2  A0A0B4J2F2     783  MVIMSEFSADPAGQGQGQQKPLRVGFYDIERTLGKGNFAVVKLARH...   
 3  A0A0C5B5G6      16                                   MRWQEMGYIFYPRKLR   
 4  A0A0K2S4Q6     201  MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...   
 
   Gene Names (primary)                                   Protein families  \
 0               CYP2D7                             Cytochrome P450 family   
 1              PIGBOS1                                                NaN   
 2                SIK1B  Protein kinase superfamily, CAMK Ser/Thr prote...   
 3              MT-RNR1                                                NaN   
 4               CD300H                                       CD300 family   
 
                                          Active site  \
 0       

# functions to convert the uniprot text into labels (for annotated proteins)

In [8]:
def build_labels_region(sequence, feature, feature_re):
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)
    region_re = f"{feature_re}\s(\d+)\.\.(\d+)\;"
    residue_re = f'{feature_re}\s(\d+);'

    found_region = re.findall(region_re, feature)

    for start, end in found_region:
        start = int(start) - 1
        end = int(end)
        assert end <= len(sequence)
        labels[start: end] = 1

    found_residue = re.findall(residue_re, feature)
    for pos in found_residue:
        pos = int(pos) -1
        assert pos <= len(sequence)
        labels[pos] = 1

    return ''.join(map(str, labels))


def build_labels_bonds(sequence, feature, feature_re):
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)

    region_re = f"{feature_re}\s(\d+)\.\.(\d+)\;"

    if isinstance(feature, float): # Indicates missing (NaN)
        found_feature = []
    else:
        found_feature = re.findall(region_re, feature)
    for start, end in found_feature:
        start = int(start) - 1
        end = int(end) -1
        assert end <= len(sequence)
        labels[start] = 1
        labels[end] = 1 
    return ''.join(map(str, labels))

In [9]:
def generate_label(sequence, feature, feature_re, colname):
  if colname == 'Disulfide bond':
    return build_labels_bonds(sequence, feature, feature_re)
  else:
    return build_labels_region(sequence, feature, feature_re)


# Function to apply to each row
def process_row(row, colname, feature_re):
    if pd.isna(row[colname]):
        return pd.Series(['pred', np.nan])
    else:
        return pd.Series(['actual', generate_label(row['Sequence'], row[colname], feature_re, colname)])


In [10]:
all_features = ['Active site', 'Binding site', 'DNA binding', 
                'Topological domain', 'Transmembrane',
                'Disulfide bond', 'Modified residue', 'Propeptide', 'Signal peptide', 'Transit peptide',
                'Beta strand', 'Helix', 'Turn',
                'Coiled coil', 'Compositional bias', 'Domain [FT]', 'Motif', 'Region', 'Repeat', 'Zinc finger']

all_features_re = ['ACT_SITE', 'BINDING', 'DNA_BIND', 
                   'TOPO_DOM', 'TRANSMEM',
                   'DISULFID', 'MOD_RES',  'PROPEP', 'SIGNAL', 'TRANSIT',
                   'STRAND', 'HELIX', 'TURN',
                   'COILED', 'COMPBIAS', 'DOMAIN', 'MOTIF', 'REGION', 'REPEAT', 'ZN_FING']

In [11]:
# actual
for i in range(len(all_features)):
  colname = all_features[i]
  feature_re = all_features_re[i]

  # Apply the function and assign new columns
  df[[f'actual_or_pred_{feature_re}', f'annotation_actual_{feature_re}']] = df.apply(lambda row: process_row(row, colname, feature_re), axis=1)


In [12]:
# split long sequences for ESM2 inference (max size = 1022 AA)
def split_text(row, n=1022):
    text = row['Sequence']
    if len(text) <= n:
        # No split needed, return the row with an added 'split_indicator' column
        return pd.DataFrame({**row, 'split_indicator': [0]}).iloc[[0]]
    else:
        # Split needed, generate parts and a corresponding DataFrame
        splits = [text[i:i+n] for i in range(0, len(text), n)]
        part_data = {col: [row[col]] * len(splits) if col != 'Sequence' else splits for col in df.columns}
        part_data['split_indicator'] = range(1, len(splits) + 1)  # numbering the splits
        return pd.DataFrame(part_data)

# run inference for all features and store them

In [None]:
def single_inference(model, tokenizer, aa_seq, checkpoint):
    # Preprocess the input sequence
    aa_seq = aa_seq.replace("O", "X").replace("B", "X").replace("U", "X").replace("Z", "X")
    
    if "prot_t5" in checkpoint:
        aa_seq = " ".join(aa_seq)

    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Tokenize input
    encoded = tokenizer(
        aa_seq,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        is_split_into_words=False,
    )

    # Move inputs to device
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Forward pass
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: (1, seq_len, num_classes)
        probs = torch.softmax(logits, dim=-1)

    return probs[:,:,1].squeeze().tolist()  #probability of being 1



# pred
for i in range(len(all_features)):

    print(f'i = {i}')
    colname = all_features[i]
    feature_re = all_features_re[i]

    # load model
    finetuned_params_path = f'../res/models/ft_{feature_re}_{selected_checkpoint.split("/")[1]}.pth'
    tokenizer, model = load_model(selected_checkpoint, finetuned_params_path)
    
    df_pred = df[df[colname].isna()].copy()
    split_df = df_pred.apply(split_text, axis=1)
    df_pred = pd.concat([item for item in split_df]).reset_index(drop=True)

    temp_results_list = []
    for aa_seq in tqdm(df_pred["Sequence"].values, desc="Predicting"):
        probs = single_inference(model, tokenizer, aa_seq, selected_checkpoint)  # list of probabilities
        pred_str = ''.join(['1' if p > 0.5 else '0' for p in probs])
        temp_results_list.append(pred_str)

    df_pred[f'annotation_pred_{feature_re}'] = temp_results_list
    df_pred = df_pred[['Entry', f'annotation_pred_{feature_re}']]
    df_pred = df_pred.groupby('Entry')[f'annotation_pred_{feature_re}'].agg(''.join).reset_index()
    df = pd.merge(df, df_pred, on='Entry', how='left')

    df[['Entry', f'actual_or_pred_{feature_re}', f'annotation_actual_{feature_re}', f'annotation_pred_{feature_re}']].to_csv(
        f'../res/proteome_inference/uniprot_all_human_proteins_annotated_{feature_re}.txt.gz',
        sep='\t', index=False
    )


df.to_csv('../res/proteome_inference/uniprot_all_human_proteins_annotated.txt.gz', sep='\t', index=False)


i = 0


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of EsmForTokenClassificationCustom were not initialized from the model checkpoint at facebook/esm2_t36_3B_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  non_frozen_params = torch.load(filepath)
Predicting:   0%|                                                                                                                                                                        | 0/20813 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Predicting: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20813/20813 [52:39<00:00,  6.59it/s]


i = 1


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of EsmForTokenClassificationCustom were not initialized from the model checkpoint at facebook/esm2_t36_3B_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  non_frozen_params = torch.load(filepath)
Predicting:   0%|                                                                                                                                                                        | 0/17964 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Predicting: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17964/17964 [44:00<00:00,  6.80it/s]


i = 2


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of EsmForTokenClassificationCustom were not initialized from the model checkpoint at facebook/esm2_t36_3B_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  non_frozen_params = torch.load(filepath)
Predicting:   0%|                                                                                                                                                                        | 0/22810 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Predicting:  32%|██████████████████████████████████████████████████▏                                                                                                          | 7290/22810 [19:14<28:59,  8.92it/s]

In [None]:
''' batch inference

def batch_inference(model, tokenizer, sequences, checkpoint, batch_size=16):
    # Preprocess sequences
    sequences = [
        seq.replace("O", "X").replace("B", "X").replace("U", "X").replace("Z", "X")
        for seq in sequences
    ]
    
    if "prot_t5" in checkpoint:
        sequences = [" ".join(seq) for seq in sequences]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    all_probs = []

    for i in tqdm(range(0, len(sequences), batch_size), desc="Batch inference"):
        batch_seqs = sequences[i:i+batch_size]

        # Tokenize
        encoded = tokenizer(
            batch_seqs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            is_split_into_words=False,
        )

        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            probs = torch.softmax(logits, dim=-1) # probs.shape: (batch_size, seq_len, num_classes)
            all_probs.extend(probs[:, :, 1].cpu().numpy())  # class 1 probability

    return all_probs  # list of numpy arrays, one per sequence


for i in range(len(all_features)):
    print(f'i = {i}')
    colname = all_features[i]
    feature_re = all_features_re[i]

    # load model
    finetuned_params_path = f'../res/models/ft_{feature_re}_{selected_checkpoint.split("/")[1]}.pth'
    tokenizer, model = load_model(selected_checkpoint, finetuned_params_path)

    df_pred = df[df[colname].isna()].copy()
    split_df = df_pred.apply(split_text, axis=1)
    df_pred = pd.concat([item for item in split_df]).reset_index(drop=True)

    seqs = list(df_pred["Sequence"].values)
    probs_list = batch_inference(model, tokenizer, seqs, selected_checkpoint, batch_size=16)

    temp_results_list = []
    for probs in tqdm(probs_list, desc="Postprocessing"):
        pred_str = ''.join(['1' if p > 0.5 else '0' for p in probs])
        temp_results_list.append(pred_str)

    df_pred[f'annotation_pred_{feature_re}'] = temp_results_list
    df_pred = df_pred[['Entry', f'annotation_pred_{feature_re}']]
    df_pred = df_pred.groupby('Entry')[f'annotation_pred_{feature_re}'].agg(''.join).reset_index()
    df = pd.merge(df, df_pred, on='Entry', how='left')

    df[['Entry', f'actual_or_pred_{feature_re}', f'annotation_actual_{feature_re}', f'annotation_pred_{feature_re}']].to_csv(
        f'../res/proteome_inference/uniprot_all_human_proteins_annotated_{feature_re}.txt.gz',
        sep='\t', index=False
    )

df.to_csv('../res/proteome_inference/uniprot_all_human_proteins_annotated.txt.gz', sep='\t', index=False)
'''