In [1]:
#!pip install datasets evaluate accelerate



In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from datasets import Dataset, load_metric
import pickle
from tqdm import tqdm
import re
from evaluate import load

In [3]:
def seed_everything(seed: int):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and every CUDA device), and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # Recorded for any interpreter started later; the hash seed of the
    # already-running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one (a no-op when CUDA is absent).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick different kernels from run
    # to run, which defeats the deterministic flag above, so it must be off.
    torch.backends.cudnn.benchmark = False

seed_everything(42)

In [4]:
# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [5]:
# Hub id of the base checkpoint whose tokenizer is shared by every evaluation.
model_checkpoint = "facebook/esm2_t12_35M_UR50D"
# Short model name: whatever follows the last "/" of the hub id.
model_name = model_checkpoint.rpartition("/")[-1]
# Number of classes passed to the token-classification head.
num_labels = 2
# Proteins longer than this are filtered out when the table is loaded.
max_sequence_length = 1024

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# The collator pads each batch at prediction time.
data_collator = DataCollatorForTokenClassification(tokenizer)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [6]:
# Read the tab-separated protein table and keep only the entries whose
# "Length" does not exceed max_sequence_length.
df = pd.read_csv("../data/uniprot_all_human_proteins.txt.gz", sep='\t')
df = df.loc[df["Length"] <= max_sequence_length]
df.head(), len(df)

(        Entry  Length                                           Sequence  \
 0  A0A087X1C5     515  MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...   
 1  A0A0B4J2F0      54  MFRRLTFAQLLFATVLGIAGGVYIFQPVFEQYAKDQKELKEKMQLV...   
 2  A0A0B4J2F2     783  MVIMSEFSADPAGQGQGQQKPLRVGFYDIERTLGKGNFAVVKLARH...   
 3  A0A0C5B5G6      16                                   MRWQEMGYIFYPRKLR   
 4  A0A0K2S4Q6     201  MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...   
 
   Gene Names (primary)                                   Protein families  \
 0               CYP2D7                             Cytochrome P450 family   
 1              PIGBOS1                                                NaN   
 2                SIK1B  Protein kinase superfamily, CAMK Ser/Thr prote...   
 3              MT-RNR1                                                NaN   
 4               CD300H                                       CD300 family   
 
                                          Active site  \
 0       

In [7]:
def split_data(df, colname, test_percent=20):
    """Split the proteins annotated for ``colname`` into train/test by family.

    Rows whose ``colname`` value is missing are dropped. The distinct values
    of the 'Protein families' column are shuffled with the global NumPy RNG
    and whole families are moved into the test set, one at a time, until the
    test set holds at least ``test_percent`` percent of the annotated rows.
    No family is therefore shared between the two returned frames.

    Rows with a missing 'Protein families' value are never moved to the test
    set; they always stay in the training frame. The missing value still
    counts as one "family" in the progress read-out.

    NOTE: exactly one ``np.random.shuffle`` call is made per invocation, so
    the resulting split depends on the current state of the global NumPy RNG
    (seed and every earlier random call).

    Args:
        df: Table with a 'Protein families' column and a ``colname`` column.
        colname: Annotation column; only rows where it is non-null are split.
        test_percent: Minimum percentage of annotated rows to hold out.

    Returns:
        ``(train_df, test_df)``. ``train_df`` keeps the original index;
        ``test_df`` is re-indexed from 0 and ordered family by family in
        shuffled order. Both are empty when no row is annotated.
    """
    print(colname)

    annotated = df[~df[colname].isna()]

    # Get a unique list of protein families and randomize their order
    unique_families = annotated['Protein families'].unique().tolist()
    np.random.shuffle(unique_families)

    total_entries = len(annotated)
    total_families = len(unique_families)

    # Group once instead of re-filtering the whole frame for every family.
    # groupby drops missing keys, so a NaN family maps to no rows at all.
    positions_by_family = annotated.groupby('Protein families', sort=False).indices
    no_rows = annotated.iloc[0:0]

    test_data = []
    test_families = []
    n_test_entries = 0
    is_test = np.zeros(total_entries, dtype=bool)

    # Set up tqdm progress bar
    with tqdm(total=total_families) as pbar:
        for family in unique_families:
            # Move every protein of the current family into the test data
            if family in positions_by_family:
                positions = positions_by_family[family]
                family_data = annotated.iloc[positions]
                is_test[positions] = True
            else:
                family_data = no_rows
            test_data.append(family_data)
            test_families.append(family)
            n_test_entries += len(family_data)

            # Share of rows and share of families held out so far
            percent_test_data = n_test_entries / total_entries * 100
            percent_test_families = len(test_families) / total_families * 100

            # Update tqdm progress bar with readout of percentages
            pbar.set_description(f'% Test Data: {percent_test_data:.2f}% | % Test Families: {percent_test_families:.2f}%')
            pbar.update(1)

            # Stop as soon as the requested share of test data is reached
            if percent_test_data >= test_percent:
                break

    if test_data:
        test_df = pd.concat(test_data, ignore_index=True)
    else:
        # Nothing annotated: pd.concat([]) would raise, return an empty frame.
        test_df = no_rows.reset_index(drop=True)
    train_df = annotated[~is_test]

    print(f"nrow train: {len(train_df)}")
    print(f"nrow test: {len(test_df)}")

    return train_df, test_df  # Return the remaining data and the test data

In [8]:
def build_labels_region(sequence, feature, feature_re):
    """Build a per-residue 0/1 label vector from a feature annotation string.

    Every ``"<feature_re> <start>..<end>;"`` range marks residues
    ``start..end`` (1-based, inclusive) as positive, and every
    ``"<feature_re> <pos>;"`` entry marks the single residue ``pos``.

    Args:
        sequence: Protein sequence; only its length is used.
        feature: Annotation text, or a float (NaN) when the annotation is
            missing, in which case all labels are 0.
        feature_re: Feature key inserted verbatim into the regular expression.

    Returns:
        ``np.ndarray`` of dtype int64 and length ``len(sequence)``.

    Raises:
        AssertionError: If an annotated position lies beyond the sequence.
    """
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)

    # Missing annotation (NaN): nothing to mark, same as build_labels_bonds.
    if isinstance(feature, float):
        return labels

    # Raw strings keep \s and \d as regex escapes rather than (invalid)
    # Python string escapes.
    region_re = rf"{feature_re}\s(\d+)\.\.(\d+);"
    residue_re = rf"{feature_re}\s(\d+);"

    for start, end in re.findall(region_re, feature):
        start = int(start) - 1  # 1-based inclusive start -> 0-based index
        end = int(end)          # 1-based inclusive end == exclusive slice end
        assert end <= len(sequence), f"region end {end} exceeds sequence length {len(sequence)}"
        labels[start:end] = 1

    for pos in re.findall(residue_re, feature):
        pos = int(pos) - 1
        # A 0-based index must be strictly below the length.
        assert pos < len(sequence), f"residue {pos + 1} exceeds sequence length {len(sequence)}"
        labels[pos] = 1

    return labels


def build_labels_bonds(sequence, feature, feature_re):
    """Label the two end-point residues of every annotated bond.

    Every ``"<feature_re> <start>..<end>;"`` entry marks residue ``start`` and
    residue ``end`` (both 1-based) as positive; residues in between stay 0.
    Entries that give a single position instead of a range are not matched.

    Args:
        sequence: Protein sequence; only its length is used.
        feature: Annotation text, or a float (NaN) when the annotation is
            missing, in which case all labels are 0.
        feature_re: Feature key inserted verbatim into the regular expression.

    Returns:
        ``np.ndarray`` of dtype int64 and length ``len(sequence)``.

    Raises:
        AssertionError: If the second bonded residue lies beyond the sequence.
    """
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)

    if isinstance(feature, float):  # Indicates missing (NaN)
        return labels

    # Raw string keeps \s and \d as regex escapes.
    bond_re = rf"{feature_re}\s(\d+)\.\.(\d+);"

    for start, end in re.findall(bond_re, feature):
        first = int(start) - 1
        # The annotation is 1-based and names the bonded residue itself, so
        # its 0-based index is end - 1.
        second = int(end) - 1
        assert second < len(sequence), f"bond residue {second + 1} exceeds sequence length {len(sequence)}"
        labels[first] = 1
        labels[second] = 1
    return labels


In [9]:
def get_train_test_dataset(train_df, test_df, colname, labeler_func, feature_re):
    """Turn the train and test frames into tokenized datasets with labels.

    For each row, ``labeler_func(sequence, annotation, feature_re)`` produces
    the label vector, where ``sequence`` comes from the "Sequence" column and
    ``annotation`` from ``colname``. Sequences are tokenized with the
    notebook-level ``tokenizer`` without special tokens, and the label
    vectors are attached as a "labels" column.

    Returns:
        ``(train_dataset, test_dataset)``.
    """

    def _collect(frame):
        # Walk the two relevant columns in row order.
        sequences, labels = [], []
        for sequence, annotation in zip(frame["Sequence"], frame[colname]):
            labels.append(labeler_func(sequence, annotation, feature_re))
            sequences.append(sequence)
        return sequences, labels

    def _to_dataset(encoded, labels):
        return Dataset.from_dict(encoded).add_column("labels", labels)

    # Label both splits first, then tokenize, then assemble the datasets.
    train_sequences, train_labels = _collect(train_df)
    test_sequences, test_labels = _collect(test_df)

    train_encoded = tokenizer(train_sequences, add_special_tokens=False)
    test_encoded = tokenizer(test_sequences, add_special_tokens=False)

    return (
        _to_dataset(train_encoded, train_labels),
        _to_dataset(test_encoded, test_labels),
    )

In [10]:
acc = load("accuracy")
def compute_metrics(dataset, model):
    """Evaluate ``model`` on ``dataset`` and return token-level metrics.

    Predictions are obtained through a ``Trainer`` built around the
    notebook-level ``data_collator``. Positions whose label is -100 are
    excluded before any metric is computed.

    Args:
        dataset: Dataset passed to ``Trainer.predict``.
        model: Token-classification model producing one logit per class.

    Returns:
        Dict with keys 'accuracy', 'precision', 'recall', 'f1', 'auc' and
        'mcc'. 'auc' is NaN when the kept labels contain a single class.
    """
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    logits, labels, _ = trainer.predict(test_dataset=dataset)

    # Flatten to one row per token and drop the ignored (-100) positions.
    labels = labels.reshape((-1,))
    keep = labels != -100
    logits = np.asarray(logits)
    logits = logits.reshape((-1, logits.shape[-1]))[keep].astype(np.float64)
    labels = labels[keep]

    predictions = np.argmax(logits, axis=1)

    # ROC AUC needs a continuous score: ranking hard 0/1 predictions would
    # collapse the curve to a single operating point. Use the softmax
    # probability of class 1 (max-shifted for numerical stability).
    shifted = np.exp(logits - logits.max(axis=1, keepdims=True))
    positive_scores = shifted[:, 1] / shifted.sum(axis=1)

    # zero_division=0 yields the same 0.0 as the default, without the warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='binary', zero_division=0
    )
    if np.unique(labels).size < 2:
        # roc_auc_score is undefined (and raises) with a single class.
        auc = float("nan")
    else:
        auc = roc_auc_score(labels, positive_scores)
    mcc = matthews_corrcoef(labels, predictions)
    accuracy = acc.compute(predictions=predictions, references=labels)
    return {'accuracy': accuracy['accuracy'], 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

In [11]:
# One row per evaluated feature: (column name in the protein table,
# feature key searched for inside that column's annotation text).
_FEATURE_TABLE = (
    ('Active site', 'ACT_SITE'),
    ('Binding site', 'BINDING'),
    ('DNA binding', 'DNA_BIND'),
    ('Topological domain', 'TOPO_DOM'),
    ('Transmembrane', 'TRANSMEM'),
    ('Disulfide bond', 'DISULFID'),
    ('Modified residue', 'MOD_RES'),
    ('Propeptide', 'PROPEP'),
    ('Signal peptide', 'SIGNAL'),
    ('Transit peptide', 'TRANSIT'),
    ('Beta strand', 'STRAND'),
    ('Helix', 'HELIX'),
    ('Turn', 'TURN'),
    ('Coiled coil', 'COILED'),
    ('Compositional bias', 'COMPBIAS'),
    ('Domain [FT]', 'DOMAIN'),
    ('Motif', 'MOTIF'),
    ('Region', 'REGION'),
    ('Repeat', 'REPEAT'),
    ('Zinc finger', 'ZN_FING'),
)

# Parallel lists, index-aligned, derived from the table above.
all_features = [column for column, _code in _FEATURE_TABLE]
all_features_re = [code for _column, code in _FEATURE_TABLE]

In [16]:
# Evaluate the fine-tuned model of every feature on its train and test split.
# Results are stored under the keys "train_<KEY>" and "test_<KEY>".
all_metrics = {}

# zip() would silently truncate if the two lists ever drifted apart.
assert len(all_features) == len(all_features_re)

# NOTE: split_data shuffles with the global NumPy RNG, so the split obtained
# for a feature depends on the seed and on how many features were processed
# before it. Changing the feature list or re-running this cell without
# re-seeding produces different splits.
for feature_name, feature_code in zip(all_features, all_features_re):
    # split_data prints the feature name itself, so no extra print is needed.
    train_df, test_df = split_data(df, feature_name)

    # Disulfide bonds label only the two bonded residues; everything else
    # labels whole ranges / single residues.
    labeler = build_labels_bonds if feature_name == 'Disulfide bond' else build_labels_region
    train_dataset, test_dataset = get_train_test_dataset(
        train_df, test_df, feature_name, labeler, feature_code
    )

    # Use a dedicated name so the base `model_checkpoint` defined in the
    # configuration cell is not overwritten by the loop.
    finetuned_checkpoint = f"AliSaadatV/esm2_t12_35M_UR50D-finetuned-{feature_code}_earlystop"

    model = AutoModelForTokenClassification.from_pretrained(finetuned_checkpoint, num_labels=num_labels)
    train_metrics = compute_metrics(train_dataset, model)
    test_metrics = compute_metrics(test_dataset, model)

    all_metrics[f'train_{feature_code}'] = train_metrics
    all_metrics[f'test_{feature_code}'] = test_metrics



Active site
Active site


% Test Data: 20.04% | % Test Families: 19.84%:  20%|█▉        | 124/625 [00:00<00:01, 295.12it/s]


nrow train: 1608
nrow test: 403


Binding site
Binding site


% Test Data: 20.01% | % Test Families: 20.93%:  21%|██        | 278/1328 [00:01<00:04, 235.56it/s]


nrow train: 3078
nrow test: 770


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

DNA binding
DNA binding


% Test Data: 21.69% | % Test Families: 32.29%:  32%|███▏      | 31/96 [00:00<00:00, 397.53it/s]

nrow train: 444
nrow test: 123





config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Site
Site


% Test Data: 20.05% | % Test Families: 18.46%:  18%|█▊        | 110/596 [00:00<00:01, 353.81it/s]


nrow train: 973
nrow test: 244


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

  _warn_prf(average, modifier, msg_start, len(result))


  _warn_prf(average, modifier, msg_start, len(result))


Intramembrane
Intramembrane


% Test Data: 20.23% | % Test Families: 16.67%:  17%|█▋        | 20/120 [00:00<00:00, 387.20it/s]

nrow train: 138
nrow test: 35





config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Topological domain
Topological domain


% Test Data: 20.02% | % Test Families: 33.51%:  34%|███▎      | 312/931 [00:01<00:02, 230.94it/s]


nrow train: 2841
nrow test: 711


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Transmembrane
Transmembrane


% Test Data: 20.14% | % Test Families: 29.97%:  30%|██▉       | 368/1228 [00:01<00:03, 232.44it/s]


nrow train: 3748
nrow test: 945


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Cross-link
Cross-link


% Test Data: 20.00% | % Test Families: 30.17%:  30%|███       | 207/686 [00:00<00:01, 329.87it/s]


nrow train: 1216
nrow test: 304


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Disulfide bond
Disulfide bond


% Test Data: 20.46% | % Test Families: 35.92%:  36%|███▌      | 222/618 [00:00<00:01, 265.05it/s]


nrow train: 2620
nrow test: 674


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Glycosylation
Glycosylation


% Test Data: 20.04% | % Test Families: 31.82%:  32%|███▏      | 288/905 [00:01<00:02, 234.84it/s]


nrow train: 3228
nrow test: 809


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Lipidation
Lipidation


% Test Data: 20.03% | % Test Families: 35.57%:  36%|███▌      | 90/253 [00:00<00:00, 373.33it/s]


nrow train: 587
nrow test: 147


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

  _warn_prf(average, modifier, msg_start, len(result))


  _warn_prf(average, modifier, msg_start, len(result))


Modified residue
Modified residue


% Test Data: 20.01% | % Test Families: 26.30%:  26%|██▋       | 822/3126 [00:05<00:14, 161.94it/s]


nrow train: 6262
nrow test: 1566


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Peptide
Peptide


% Test Data: 20.25% | % Test Families: 14.08%:  14%|█▍        | 10/71 [00:00<00:00, 451.01it/s]

nrow train: 130
nrow test: 33





config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Propeptide
Propeptide


% Test Data: 20.00% | % Test Families: 26.98%:  27%|██▋       | 51/189 [00:00<00:00, 377.05it/s]

nrow train: 548
nrow test: 137





config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Signal peptide
Signal peptide


% Test Data: 20.49% | % Test Families: 26.91%:  27%|██▋       | 180/669 [00:00<00:01, 266.22it/s]


nrow train: 2526
nrow test: 651


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Transit peptide
Transit peptide


% Test Data: 20.33% | % Test Families: 20.37%:  20%|██        | 77/378 [00:00<00:00, 387.31it/s]


nrow train: 439
nrow test: 112


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Beta strand
Beta strand


% Test Data: 20.09% | % Test Families: 25.83%:  26%|██▌       | 615/2381 [00:03<00:09, 188.02it/s]


nrow train: 4629
nrow test: 1164


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Helix
Helix


% Test Data: 20.01% | % Test Families: 22.31%:  22%|██▏       | 570/2555 [00:03<00:12, 164.26it/s]


nrow train: 4969
nrow test: 1243


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Turn
Turn


% Test Data: 20.11% | % Test Families: 25.17%:  25%|██▌       | 542/2153 [00:02<00:07, 207.06it/s]


nrow train: 4005
nrow test: 1008


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Coiled coil
Coiled coil


% Test Data: 20.03% | % Test Families: 33.33%:  33%|███▎      | 197/591 [00:00<00:01, 329.20it/s]


nrow train: 1214
nrow test: 304


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Compositional bias
Compositional bias


% Test Data: 20.01% | % Test Families: 31.96%:  32%|███▏      | 669/2093 [00:03<00:07, 180.67it/s]


nrow train: 5193
nrow test: 1299


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Domain [FT]
Domain [FT]


% Test Data: 20.01% | % Test Families: 29.75%:  30%|██▉       | 401/1348 [00:02<00:05, 174.55it/s]


nrow train: 5734
nrow test: 1434


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Motif
Motif


% Test Data: 20.04% | % Test Families: 30.28%:  30%|███       | 245/809 [00:00<00:01, 301.52it/s]


nrow train: 1596
nrow test: 400


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Region
Region


% Test Data: 20.00% | % Test Families: 28.98%:  29%|██▉       | 926/3195 [00:07<00:17, 131.17it/s]


nrow train: 8479
nrow test: 2120


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Repeat
Repeat


% Test Data: 20.53% | % Test Families: 35.03%:  35%|███▌      | 138/394 [00:00<00:00, 330.37it/s]


nrow train: 1161
nrow test: 300


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

Zinc finger
Zinc finger


% Test Data: 20.54% | % Test Families: 61.23%:  61%|██████    | 139/227 [00:00<00:00, 325.23it/s]


nrow train: 1226
nrow test: 317


config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]

In [20]:
# Metric names are taken from one stored result; every entry shares them.
columns = all_metrics["train_ACT_SITE"].keys()

# One row per feature key, one column per metric. Each cell holds the train
# and test value, both rounded to three decimals: "train, test".
metrics_df = pd.DataFrame(
    {
        metric: [
            f"{all_metrics[f'train_{item}'][metric]:.3f}, {all_metrics[f'test_{item}'][metric]:.3f}"
            for item in all_features_re
        ]
        for metric in columns
    },
    index=all_features_re,
)

In [22]:
# Save the metrics table as a tab-separated file; the relative path resolves
# against the current working directory.
metrics_df.to_csv("metrics_df.tsv", sep="\t")