In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline
)
import re
from tqdm import tqdm

In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
# Longest sequence chunk handed to the model in a single call (default for split_text).
# NOTE(review): 1022 presumably leaves room for two special tokens in a
# 1024-token window -- confirm against the model configuration.
max_sequence_length = 1022

# Device index passed as `device=` to the pipeline: 0 when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1
device

0

In [None]:
# Load the tab-separated protein table (gzip-compressed) into a DataFrame.
df = pd.read_csv(
    "../data/uniprot_all_human_proteins.txt.gz",
    sep="\t",
)
# Show the first rows together with the total row count.
(df.head(), len(df))

(        Entry  Length                                           Sequence  \
 0  A0A087X1C5     515  MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...   
 1  A0A0B4J2F0      54  MFRRLTFAQLLFATVLGIAGGVYIFQPVFEQYAKDQKELKEKMQLV...   
 2  A0A0B4J2F2     783  MVIMSEFSADPAGQGQGQQKPLRVGFYDIERTLGKGNFAVVKLARH...   
 3  A0A0C5B5G6      16                                   MRWQEMGYIFYPRKLR   
 4  A0A0K2S4Q6     201  MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...   
 
   Gene Names (primary)                                   Protein families  \
 0               CYP2D7                             Cytochrome P450 family   
 1              PIGBOS1                                                NaN   
 2                SIK1B  Protein kinase superfamily, CAMK Ser/Thr prote...   
 3              MT-RNR1                                                NaN   
 4               CD300H                                       CD300 family   
 
                                          Active site  \
 0       

In [None]:
def build_labels_region(sequence, feature, feature_re):
    """Encode range and single-residue annotations as a per-residue 0/1 string.

    Parameters
    ----------
    sequence : str
        Protein sequence; only its length is used.
    feature : str
        Annotation text containing entries of the form ``<CODE> 12..34;``
        (a range) and/or ``<CODE> 56;`` (a single position). Positions are
        1-based and ranges are inclusive. Any non-string value (e.g. NaN for a
        missing annotation) yields an all-zero label string.
    feature_re : str
        Regex fragment that matches the feature code, e.g. ``ACT_SITE``.

    Returns
    -------
    str
        One character per residue: '1' for residues covered by any matched
        range or position, '0' otherwise.

    Raises
    ------
    AssertionError
        If a matched position lies outside the sequence.
    """
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)

    if isinstance(feature, str):
        # Raw f-strings keep the regex escapes intact without relying on
        # Python tolerating invalid string escape sequences.
        region_re = rf"{feature_re}\s(\d+)\.\.(\d+);"
        residue_re = rf"{feature_re}\s(\d+);"

        for start, end in re.findall(region_re, feature):
            # Convert the 1-based inclusive range to a 0-based half-open slice.
            start = int(start) - 1
            end = int(end)
            assert end <= len(sequence), "region end lies beyond the sequence"
            labels[start:end] = 1

        for pos in re.findall(residue_re, feature):
            pos = int(pos) - 1
            # 0-based index: the last valid value is len(sequence) - 1, and a
            # negative value would silently wrap around to the sequence end.
            assert 0 <= pos < len(sequence), "residue position lies outside the sequence"
            labels[pos] = 1

    return ''.join(map(str, labels))


def build_labels_bonds(sequence, feature, feature_re):
    """Encode bond annotations as a per-residue 0/1 string marking both endpoints.

    Parameters
    ----------
    sequence : str
        Protein sequence; only its length is used.
    feature : str
        Annotation text containing entries of the form ``<CODE> 12..34;``,
        where the two 1-based positions are the bonded residues. Any
        non-string value (e.g. NaN for a missing annotation) yields an
        all-zero label string. Entries with a single position are not matched.
    feature_re : str
        Regex fragment that matches the feature code, e.g. ``DISULFID``.

    Returns
    -------
    str
        One character per residue: '1' at each bond endpoint, '0' elsewhere
        (residues between the endpoints are NOT marked).

    Raises
    ------
    AssertionError
        If a bond endpoint lies outside the sequence.
    """
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)

    if isinstance(feature, str):
        # Raw f-string keeps the regex escapes intact.
        region_re = rf"{feature_re}\s(\d+)\.\.(\d+);"
        found_feature = re.findall(region_re, feature)
    else:
        # Missing annotation (NaN / None): nothing to mark.
        found_feature = []

    for start, end in found_feature:
        # Both endpoints become 0-based indices, so the largest valid value
        # is len(sequence) - 1.
        start = int(start) - 1
        end = int(end) - 1
        assert 0 <= start < len(sequence), "bond start lies outside the sequence"
        assert 0 <= end < len(sequence), "bond end lies outside the sequence"
        labels[start] = 1
        labels[end] = 1
    return ''.join(map(str, labels))

In [None]:
def generate_label(sequence, feature, feature_re, colname):
    """Build the per-residue label string for one feature column.

    The 'Disulfide bond' column is handled by build_labels_bonds; every other
    column is handled by build_labels_region.
    """
    builder = build_labels_bonds if colname == 'Disulfide bond' else build_labels_region
    return builder(sequence, feature, feature_re)


# Function to apply to each row
def process_row(row, colname, feature_re):
    """Return a two-element Series (source flag, label string) for one row.

    When the row has a value in `colname`, the flag is 'actual' and the label
    string is built from that annotation. When the value is missing, the flag
    is 'pred' and the label is NaN.
    """
    annotation = row[colname]
    if not pd.isna(annotation):
        label = generate_label(row['Sequence'], annotation, feature_re, colname)
        return pd.Series(['actual', label])
    return pd.Series(['pred', np.nan])


In [None]:
# Each pair is (column name in the protein table, feature code used inside the
# annotation text and in the per-feature column/file names). The pairs are
# unzipped into two parallel lists so both always have the same length and order.
all_features, all_features_re = map(list, zip(
    ('Active site', 'ACT_SITE'),
    ('Binding site', 'BINDING'),
    ('DNA binding', 'DNA_BIND'),
    ('Topological domain', 'TOPO_DOM'),
    ('Transmembrane', 'TRANSMEM'),
    ('Disulfide bond', 'DISULFID'),
    ('Propeptide', 'PROPEP'),
    ('Signal peptide', 'SIGNAL'),
    ('Transit peptide', 'TRANSIT'),
    ('Beta strand', 'STRAND'),
    ('Helix', 'HELIX'),
    ('Coiled coil', 'COILED'),
    ('Compositional bias', 'COMPBIAS'),
    ('Domain', 'DOMAIN'),
    ('Region', 'REGION'),
    ('Repeat', 'REPEAT'),
    ('Zinc finger', 'ZN_FING'),
))

In [None]:
# actual
# For every feature, add two columns to df: a flag telling whether the row has
# an annotation ('actual') or needs a prediction ('pred'), and the label string
# built from the annotation (NaN when there is none).
for colname, feature_re in zip(all_features, all_features_re):
    new_cols = [f'actual_or_pred_{feature_re}', f'annotation_actual_{feature_re}']
    df[new_cols] = df.apply(process_row, axis=1, args=(colname, feature_re))


In [None]:
# split long sequences for ESM2 inference
def split_text(row, n=max_sequence_length):
    """Split one row's 'Sequence' into chunks of at most `n` characters.

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row as a pandas Series)
        Must contain a 'Sequence' entry holding a string.
    n : int
        Maximum chunk length.

    Returns
    -------
    pandas.DataFrame
        One row per chunk. Every field of `row` is repeated unchanged except
        'Sequence', which holds the chunk. A 'split_indicator' column is
        appended: 0 when the sequence fit in a single chunk, otherwise the
        1-based chunk number.
    """
    text = row['Sequence']
    if len(text) <= n:
        # No split needed: a single chunk flagged with indicator 0.
        chunks = [text]
        indicators = [0]
    else:
        # Split needed: consecutive, non-overlapping chunks numbered from 1.
        chunks = [text[i:i + n] for i in range(0, len(text), n)]
        indicators = list(range(1, len(chunks) + 1))

    # Iterate over the row's own keys instead of a global DataFrame's columns,
    # so the function works on any frame it is applied to.
    part_data = {
        col: chunks if col == 'Sequence' else [row[col]] * len(chunks)
        for col in row.keys()
    }
    part_data['split_indicator'] = indicators
    return pd.DataFrame(part_data)

In [None]:
# pred
# For every feature, run the matching fine-tuned token classifier on the
# proteins that have no annotation for that feature, attach the predicted
# label string to df, and export one file per feature.
for i, (colname, feature_re) in enumerate(zip(all_features, all_features_re)):
  print(f'i = {i}')
  pred_col = f'annotation_pred_{feature_re}'

  # Make the cell safe to re-run: merging onto an already existing prediction
  # column would create suffixed `_x`/`_y` columns and break the export below.
  if pred_col in df.columns:
    df = df.drop(columns=pred_col)

  # Only rows with a missing annotation for this feature get a prediction.
  df_pred = df[df[colname].isna()].copy()

  if df_pred.empty:
    # Nothing to predict; pd.concat on an empty list would raise, so just add
    # an empty prediction column and skip loading the model.
    df[pred_col] = np.nan
  else:
    model_checkpoint = f"AliSaadatV/esm2_t12_35M_UR50D-finetuned-{feature_re}_earlystop"
    token_classifier = pipeline(
        "token-classification", model=model_checkpoint, device=device
    )

    # Cut long sequences into chunks (one DataFrame per input row), then stack them.
    split_df = df_pred.apply(split_text, axis=1)
    df_pred = pd.concat(list(split_df)).reset_index(drop=True)

    with torch.no_grad():
      outputs = token_classifier(list(df_pred["Sequence"].values))

    # Strip the 'LABEL_' prefix from every token's entity and concatenate the
    # remainders into one string per input sequence.
    df_pred[pred_col] = [
        ''.join(entry['entity'].replace('LABEL_', '') for entry in output)
        for output in outputs
    ]

    # Join the chunk predictions of each Entry back into a single string;
    # rows within a group keep their chunk order.
    df_pred = df_pred.groupby('Entry')[pred_col].agg(''.join).reset_index()
    df = pd.merge(df, df_pred, on='Entry', how='left')

  df[['Entry', f'actual_or_pred_{feature_re}', f'annotation_actual_{feature_re}', f'annotation_pred_{feature_re}']].to_csv(f'uniprot_all_human_proteins_annotated_{feature_re}.txt', sep='\t', index=False)


df.to_csv('uniprot_all_human_proteins_annotated.txt', sep='\t', index=False)

i = 0


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 1


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 2


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 3


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 4


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 5


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 6


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 7


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 8


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 9


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 10


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 11


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 12


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 13


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 14


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 15


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 16


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 17


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 18


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 19


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 20


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


i = 21


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
