In [11]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [12]:
%pip install -r "/content/drive/MyDrive/Deep learning models/requirements.txt"



In [13]:
%pip install datasets



In [14]:
import torch
import ast
import torch.nn as nn
import pandas as pd
import numpy as np

from Bio import SeqIO
from glob import glob
from transformers import (AutoModelForTokenClassification, AutoTokenizer,
                          AutoModelForMaskedLM, DataCollatorForTokenClassification,
                           EsmForMaskedLM, EsmTokenizer,
                           TrainingArguments, Trainer
                        )
from transformers.trainer_callback import ProgressCallback
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             matthews_corrcoef, roc_auc_score)
from sklearn.model_selection import train_test_split
from pprint import pprint
from datasets import Dataset
from datetime import datetime

In [15]:
def convert_to_binary_list(original_binding_sites_lst, sequence_len):
    """Convert a Binding-Active site string to a binary list based on the sequence length."""
    binary_list = [0] * sequence_len  # Initialize a list of zeros

    # Ensure original_binding_sites_lst is a list and not empty
    if isinstance(original_binding_sites_lst, list) and len(original_binding_sites_lst) > 0:
        for idx in original_binding_sites_lst:
            if isinstance(idx, int) and 1 <= idx <= sequence_len:  # Ensure index is valid
                binary_list[idx - 1] = 1

    return binary_list

In [16]:
binding_sites_df = pd.read_csv("/content/drive/MyDrive/Deep learning models/data/development_set/all_binding_sites_complete.csv")
binding_sites_df['binding_sites'] = binding_sites_df['binding_sites'].apply(ast.literal_eval)
# New concated binding sites df
display(binding_sites_df)

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
0,P02185,"[65, 36, 37, 69, 39, 42, 44, 13, 120, 25, 123,...",metal,,0,[]
1,P09211,"[8, 14, 78, 82, 114, 148, 117, 86, 118, 30, 31]",metal,MPPYTVVYFPVRGRCAALRMLLADQGQSWKEEVVTVETWQEGSLKA...,210,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, ..."
2,P00817,"[193, 121, 102, 59, 79, 116, 148, 118, 153, 15...",metal,,0,[]
3,P01112,"[3, 137, 138, 13, 16, 17, 153, 154, 28, 30, 31...",metal,,0,[]
4,P07378,"[39, 399, 376, 377, 219]",metal,MTLNEKKSINECDLKGKKVLIRVDFNVPVKNGKITNDYRIRSALPT...,440,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...
18048,C7G9B5,"[64, 102, 134, 136, 110, 62, 95]",small,MVFTGIGPVLTPYLENAVVYADENENSGVKKVFTADQLKVAWGDAD...,1356,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
18049,Q91159,"[76, 77, 80, 81, 91, 125, 126, 127]",small,,0,[]
18050,Q4K977,"[101, 102, 71, 14, 17, 152, 121, 123, 124]",small,MAVQIGFLLFPEVQQLDLTGPHDVLASLPDVQVHLIWKEPGPVVAS...,228,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
18051,Q9Y697,"[128, 257, 258, 255, 232, 234, 203, 235, 207, ...",small,,0,[]


In [17]:
needed_binding_sites_df = binding_sites_df[binding_sites_df['sequence_length'] > 0]
needed_binding_sites_df['binary_binding_sites'] = [
            convert_to_binary_list(row['binding_sites'], row['sequence_length'])
            for _, row in needed_binding_sites_df.iterrows()
        ]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  needed_binding_sites_df['binary_binding_sites'] = [


In [18]:
for _, row in needed_binding_sites_df.iterrows():
    print(row['binding_sites'], type(row['binding_sites']))
    print(row['sequence_length'], type(row['sequence_length']))
    break

[8, 14, 78, 82, 114, 148, 117, 86, 118, 30, 31] <class 'list'>
210 <class 'int'>


In [19]:
display(needed_binding_sites_df)

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
1,P09211,"[8, 14, 78, 82, 114, 148, 117, 86, 118, 30, 31]",metal,MPPYTVVYFPVRGRCAALRMLLADQGQSWKEEVVTVETWQEGSLKA...,210,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, ..."
4,P07378,"[39, 399, 376, 377, 219]",metal,MTLNEKKSINECDLKGKKVLIRVDFNVPVKNGKITNDYRIRSALPT...,440,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
5,P00720,"[96, 128, 132, 105, 11, 76, 142, 144, 80, 114,...",metal,MNIFEMLRIDERLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSEL...,164,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ..."
6,P00698,"[128, 131, 132, 136, 139, 142, 29, 32, 33, 42,...",metal,MRSLLILVLCFLPLAALGKVFGRCELAAAMKRHGLDNYRGYSLGNW...,147,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
8,P0AEC3,"[760, 730, 756]",metal,MKQIRLLAQYYVDLMMKLGLVRFSMLLALALVVLAIVVQMAVTMVL...,778,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...
18046,Q39VU6,"[98, 133, 134, 135, 136, 137, 138, 139, 269, 1...",small,MANMHQLLTELVNRGGSDLHLTTNSPPQIRIDGKLLPLDMPPLNAV...,365,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
18047,B5ZA44,"[256, 137, 138, 13, 14, 16, 283, 284, 287, 35,...",small,MAKITTVIDIGSNSVRLAVFKKTSQFGFYLLFETKSKVRISEGCYA...,484,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, ..."
18048,C7G9B5,"[64, 102, 134, 136, 110, 62, 95]",small,MVFTGIGPVLTPYLENAVVYADENENSGVKKVFTADQLKVAWGDAD...,1356,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
18050,Q4K977,"[101, 102, 71, 14, 17, 152, 121, 123, 124]",small,MAVQIGFLLFPEVQQLDLTGPHDVLASLPDVQVHLIWKEPGPVVAS...,228,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."


In [20]:
needed_binding_sites_df.ligand_type.value_counts()

Unnamed: 0_level_0,count
ligand_type,Unnamed: 1_level_1
small,6391
metal,4163
nuclear,440


In [21]:
# Data verification

sample_idx = 10
seq_label = needed_binding_sites_df.iloc[sample_idx]['binary_binding_sites']
seq_len = needed_binding_sites_df.iloc[sample_idx]['sequence_length']
seq = needed_binding_sites_df.iloc[sample_idx]['sequence']
seq_binding_sites = needed_binding_sites_df.iloc[sample_idx]['binding_sites']

print(seq_label)
print(seq)
print(len(seq_label))
print(seq_len)
print(seq_binding_sites)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
MLVMTEYLLSAGICMAIVSILLIGMAISNVSKGQYAKRFFFFATSCLVLTLVVVSSLSSSANASQTDNGVNRSGSEDPTVYSATSTKKLHKEPATLIKAIDGDTVKLMYKGQPMTFRLLLVDTPETKHPKKGVEKYGPEASAFTKKMVENAKKIEVEFDKGQRTDKYGRGLAYIYADGKMVNEALVRQGLAKVAYVYKPNNTHEQHLRKSEAQAKKEKLNIWSEDNADSGQ
231
231
[101, 103, 104, 169, 117, 122, 123, 124, 125]


In [22]:
np.random.seed(42)

train_df, test_df = train_test_split(
    needed_binding_sites_df,
    test_size=0.2,  # 80% train, 20% test
    stratify=needed_binding_sites_df['ligand_type'],
    random_state=42
)

# Verify the split ratios
print("Training set ligand distribution:")
print(train_df['ligand_type'].value_counts())
print("\nTest set ligand distribution:")
print(test_df['ligand_type'].value_counts())

# Calculate percentages to verify similar ratios
print("\nTraining set percentages:")
print(train_df['ligand_type'].value_counts(normalize=True))
print("\nTest set percentages:")
print(test_df['ligand_type'].value_counts(normalize=True))

Training set ligand distribution:
ligand_type
small      5113
metal      3330
nuclear     352
Name: count, dtype: int64

Test set ligand distribution:
ligand_type
small      1278
metal       833
nuclear      88
Name: count, dtype: int64

Training set percentages:
ligand_type
small      0.581353
metal      0.378624
nuclear    0.040023
Name: proportion, dtype: float64

Test set percentages:
ligand_type
small      0.581173
metal      0.378809
nuclear    0.040018
Name: proportion, dtype: float64


In [23]:
test_df.sequence_length.describe()

Unnamed: 0,sequence_length
count,2199.0
mean,374.134152
std,277.577322
min,31.0
25%,214.0
50%,318.0
75%,450.0
max,3863.0


In [24]:
train_df.head()

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
12928,C4LSE7,"[7, 8, 9, 10, 11, 44, 45, 12, 13, 123, 125, 126]",small,MKLLFVCLGNICRSPAAEAVMKKVIQNHHLTEKYICDSAGTCSYHE...,157,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, ..."
2801,Q30V63,"[88, 124, 317, 150]",metal,MTQLELLLERIIDRVNVNLRHQKFDVGDYVRRQTPHLHYSKFYAFY...,477,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
148,P10824,"[200, 43, 47, 178, 181]",metal,MGCTLSAEDKAAVERSKMIDRNLREDGEKAAREVKLLLLGAGESGK...,354,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
13713,Q9X1H7,"[33, 34, 35, 36, 37, 38, 174, 142, 144, 176, 1...",small,MIIRDVELVKVARTPGDYPPPLKGEVAFVGRSNVGKSSLLNALFNR...,195,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
9828,Q6PYX1,"[161, 196, 198, 8, 145, 147, 159]",small,YYYGMDVWGQGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCL...,348,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ..."


### Split sequences into chunks

In [25]:
def split_into_chunks(sequences, labels, chunk_size = 1000):
    """Split sequences and labels into chunks of size "chunk_size" or less."""
    new_sequences = []
    new_labels = []
    for seq, lbl in zip(sequences, labels):
        if len(seq) > chunk_size:
            # Split the sequence and labels into chunks of size "chunk_size" or less
            for i in range(0, len(seq), chunk_size):
                new_sequences.append(seq[i:i+chunk_size])
                new_labels.append(lbl[i:i+chunk_size])
        else:
            new_sequences.append(seq)
            new_labels.append(lbl)

    return new_sequences, new_labels

### Prepare training and testing sequences, labels

In [26]:
# Initial sequences
test_seq = test_df['sequence'].tolist()
test_labels = test_df['binary_binding_sites'].tolist()
train_seq = train_df['sequence'].tolist()
train_labels = train_df['binary_binding_sites'].tolist()

In [27]:
# Apply new sequences by chunking
chunk_size = 1000
test_seq, test_labels = split_into_chunks(test_seq, test_labels, chunk_size)
train_seq, train_labels = split_into_chunks(train_seq, train_labels, chunk_size)

In [28]:
print(len(train_seq[0]))
print(len(train_labels[0]))
print(train_labels[0])

157
157
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


### Tokenization and helper functions

In [29]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
max_sequence_length = 1000 # here note that data was processed with chunks (context window) of 1000 residues - adapt accordingly

train_tokenized = tokenizer(train_seq, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_seq, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [30]:
train_tokenized

{'input_ids': tensor([[ 0, 20, 15,  ...,  1,  1,  1],
        [ 0, 20, 11,  ...,  1,  1,  1],
        [ 0, 20,  6,  ...,  1,  1,  1],
        ...,
        [ 0, 20, 11,  ...,  1,  1,  1],
        [ 0,  8, 17,  ...,  1,  1,  1],
        [ 0, 20,  7,  ...,  1,  1,  1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [31]:
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

def compute_metrics_train(p):
    """Compute metrics for evaluation."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove padding (-100 labels)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()

    # Compute accuracy
    accuracy = accuracy_score(labels, predictions)

    # Compute precision, recall, F1 score, and AUC
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)

    # Compute MCC
    mcc = matthews_corrcoef(labels, predictions)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def compute_loss(model, inputs):
    """Custom compute_loss function."""
    logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss()
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss

In [32]:
# Verify some encoded and dedcode input, output
idx = 100
print(len(train_seq[idx]))
print(train_seq[idx])
encoded_input = tokenizer(train_seq[idx])
print(encoded_input)
decoded_output = tokenizer.decode(encoded_input['input_ids'])
print(decoded_output)

print(len(train_seq[idx]), len(encoded_input["input_ids"]))

322
MAASGEPQRQWQEEVAAVVVVGSCMTDLVSLTSRLPKTGETIHGHKFFIGFGGKGANQCVQAARLGAMTSMVCKVGKDSFGNDYIENLKQNDISTEFTYQTKDAATGTASIIVNNEGQNIIVIVAGANLLLNTEDLRAAANVISRAKVMVCQLEITPATSLEALTMARRSGVKTLFNPAPAIADLDPQFYTLSDVFCCNESEAEILTGLTVGSAADAGEAALVLLKRGCQVVIITLGAEGCVVLSQTEPEPKHIPTEKVKAVDTTGAGDSFVGALAFYLAYYPNLSLEDMLNRSNFIAAVSVQAAGTQSSYPYKKDLPLTLF
{'input_ids': [0, 20, 5, 5, 8, 6, 9, 14, 16, 10, 16, 22, 16, 9, 9, 7, 5, 5, 7, 7, 7, 7, 6, 8, 23, 20, 11, 13, 4, 7, 8, 4, 11, 8, 10, 4, 14, 15, 11, 6, 9, 11, 12, 21, 6, 21, 15, 18, 18, 12, 6, 18, 6, 6, 15, 6, 5, 17, 16, 23, 7, 16, 5, 5, 10, 4, 6, 5, 20, 11, 8, 20, 7, 23, 15, 7, 6, 15, 13, 8, 18, 6, 17, 13, 19, 12, 9, 17, 4, 15, 16, 17, 13, 12, 8, 11, 9, 18, 11, 19, 16, 11, 15, 13, 5, 5, 11, 6, 11, 5, 8, 12, 12, 7, 17, 17, 9, 6, 16, 17, 12, 12, 7, 12, 7, 5, 6, 5, 17, 4, 4, 4, 17, 11, 9, 13, 4, 10, 5, 5, 5, 17, 7, 12, 8, 10, 5, 15, 7, 20, 7, 23, 16, 4, 9, 12, 11, 14, 5, 11, 8, 4, 9, 5, 4, 11, 20, 5, 10, 10, 8, 6, 7, 15, 11, 4, 18, 17, 14, 5, 14, 5, 12, 5, 13, 4, 13, 

### Training

In [33]:
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

In [34]:
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

model = AutoModelForTokenClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=2)
training_args = TrainingArguments(
    output_dir=f"/content/drive/MyDrive/Deep learning models/data/trained_models/esm2_t6_8M-binding-sites_{timestamp}",
    seed = 42,
    num_train_epochs = 3,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_mcc",
    greater_is_better=True,
    logging_dir=f"/content/drive/MyDrive/Deep learning models/data/trained_models/esm2_t6_8M-binding-sites_{timestamp}",
    logging_strategy="steps",
    logging_steps=10,
    save_total_limit=5,
    bf16 = True,
    report_to="none" # Disable Wandb
)

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [35]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics_train,
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
    callbacks=[
        ProgressCallback()
    ]
)

  trainer = Trainer(


In [36]:
trainer.train()

  0%|          | 0/3426 [00:00<?, ?it/s]

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Auc,Mcc
1,0.1449,0.131011,0.963798,0.659679,0.114758,0.195505,0.556199,0.265155
2,0.1143,0.125774,0.964647,0.667625,0.154712,0.25121,0.575821,0.310294
3,0.1372,0.124849,0.964464,0.641975,0.164891,0.262388,0.580613,0.313513


{'loss': 0.6409, 'grad_norm': 2.7544972896575928, 'learning_rate': 4.988324576765908e-05, 'epoch': 0.01}
{'loss': 0.3762, 'grad_norm': 1.3896403312683105, 'learning_rate': 4.9737302977232924e-05, 'epoch': 0.02}
{'loss': 0.191, 'grad_norm': 0.09449131041765213, 'learning_rate': 4.959136018680677e-05, 'epoch': 0.03}
{'loss': 0.1853, 'grad_norm': 0.337075799703598, 'learning_rate': 4.944541739638062e-05, 'epoch': 0.04}
{'loss': 0.171, 'grad_norm': 0.5517352819442749, 'learning_rate': 4.929947460595447e-05, 'epoch': 0.04}
{'loss': 0.1582, 'grad_norm': 0.26244011521339417, 'learning_rate': 4.915353181552832e-05, 'epoch': 0.05}
{'loss': 0.1983, 'grad_norm': 0.15868711471557617, 'learning_rate': 4.9007589025102165e-05, 'epoch': 0.06}
{'loss': 0.1699, 'grad_norm': 0.3275112509727478, 'learning_rate': 4.886164623467601e-05, 'epoch': 0.07}
{'loss': 0.1766, 'grad_norm': 0.24391567707061768, 'learning_rate': 4.8715703444249855e-05, 'epoch': 0.08}
{'loss': 0.1659, 'grad_norm': 0.045629944652318954,

  0%|          | 0/284 [00:00<?, ?it/s]

{'eval_loss': 0.13101059198379517, 'eval_accuracy': 0.9637981770247751, 'eval_precision': 0.6596791833758658, 'eval_recall': 0.11475773718924404, 'eval_f1': 0.19550537518232403, 'eval_auc': 0.5561989928386358, 'eval_mcc': 0.26515478781315255, 'eval_runtime': 45.4412, 'eval_samples_per_second': 49.867, 'eval_steps_per_second': 6.25, 'epoch': 1.0}
{'loss': 0.1144, 'grad_norm': 0.21268026530742645, 'learning_rate': 3.324576765907764e-05, 'epoch': 1.01}
{'loss': 0.1384, 'grad_norm': 0.19536332786083221, 'learning_rate': 3.309982486865149e-05, 'epoch': 1.02}
{'loss': 0.148, 'grad_norm': 0.494547575712204, 'learning_rate': 3.295388207822534e-05, 'epoch': 1.02}
{'loss': 0.1312, 'grad_norm': 0.22485028207302094, 'learning_rate': 3.2807939287799185e-05, 'epoch': 1.03}
{'loss': 0.1194, 'grad_norm': 0.4368244409561157, 'learning_rate': 3.266199649737303e-05, 'epoch': 1.04}
{'loss': 0.1504, 'grad_norm': 0.2168777734041214, 'learning_rate': 3.251605370694688e-05, 'epoch': 1.05}
{'loss': 0.1292, 'gr

  0%|          | 0/284 [00:00<?, ?it/s]

{'eval_loss': 0.12577374279499054, 'eval_accuracy': 0.9646465812833269, 'eval_precision': 0.6676245210727969, 'eval_recall': 0.15471207508878743, 'eval_f1': 0.25120996807743795, 'eval_auc': 0.5758209983310618, 'eval_mcc': 0.3102942552490858, 'eval_runtime': 45.4023, 'eval_samples_per_second': 49.909, 'eval_steps_per_second': 6.255, 'epoch': 2.0}
{'loss': 0.1253, 'grad_norm': 0.26059338450431824, 'learning_rate': 1.6608289550496207e-05, 'epoch': 2.01}
{'loss': 0.1276, 'grad_norm': 0.1974610537290573, 'learning_rate': 1.6462346760070052e-05, 'epoch': 2.01}
{'loss': 0.1281, 'grad_norm': 0.2528744041919708, 'learning_rate': 1.63164039696439e-05, 'epoch': 2.02}
{'loss': 0.1205, 'grad_norm': 0.2423267960548401, 'learning_rate': 1.6170461179217745e-05, 'epoch': 2.03}
{'loss': 0.1357, 'grad_norm': 0.16660158336162567, 'learning_rate': 1.6024518388791597e-05, 'epoch': 2.04}
{'loss': 0.1062, 'grad_norm': 0.280439555644989, 'learning_rate': 1.5878575598365442e-05, 'epoch': 2.05}
{'loss': 0.1247, 

  0%|          | 0/284 [00:00<?, ?it/s]

{'eval_loss': 0.12484878301620483, 'eval_accuracy': 0.9644642594512599, 'eval_precision': 0.6419753086419753, 'eval_recall': 0.16489091831557584, 'eval_f1': 0.26238772832778284, 'eval_auc': 0.5806127651608086, 'eval_mcc': 0.3135131992861407, 'eval_runtime': 45.4529, 'eval_samples_per_second': 49.854, 'eval_steps_per_second': 6.248, 'epoch': 3.0}
{'train_runtime': 1481.8462, 'train_samples_per_second': 18.494, 'train_steps_per_second': 2.312, 'train_loss': 0.13616159779247455, 'epoch': 3.0}


TrainOutput(global_step=3426, training_loss=0.13616159779247455, metrics={'train_runtime': 1481.8462, 'train_samples_per_second': 18.494, 'train_steps_per_second': 2.312, 'total_flos': 1216644372090000.0, 'train_loss': 0.13616159779247455, 'epoch': 3.0})

### Evaluation

In [37]:
from accelerate import Accelerator

accelerator = Accelerator()
saved_model_path = f"/content/drive/MyDrive/Deep learning models/data/trained_models/esm2_t6_8M-binding-sites_2025-02-04_10-38-23/checkpoint-3426"

loaded_model = AutoModelForTokenClassification.from_pretrained(saved_model_path)
model = accelerator.prepare(loaded_model)

In [38]:
# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Create a data collator
data_collator = DataCollatorForTokenClassification(tokenizer)

In [39]:
def compute_metrics_dataset(dataset, data_collator):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Remove padding and special tokens
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Compute the MCC

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}  # Include the MCC in the returned dictionary

In [40]:
test_metrics = compute_metrics_dataset(test_dataset, data_collator)
print(test_metrics)

[34m[1mwandb[0m: Currently logged in as: [33mdangkhoa20006[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


{'accuracy': 0.9644642594512599, 'precision': 0.6419753086419753, 'recall': 0.16489091831557584, 'f1': 0.26238772832778284, 'auc': 0.5806127651608086, 'mcc': 0.3135131992861407}


### Inference

In [41]:
model.eval()

EsmForTokenClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, el

In [42]:
def infer_binding_sites(sequence, tokenizer, model, accelerator):
  tokenized_sequence = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

  tokenized_sequence = {key: value.to(accelerator.device) for key, value in tokenized_sequence.items()}
  tokens = tokenizer.convert_ids_to_tokens(tokenized_sequence["input_ids"][0])  # Convert input ids back to tokens (aminoacids in this case, but also special tokens generated by the tokenizer)

  with torch.no_grad():
      logits = model(**tokenized_sequence).logits

  predictions = torch.argmax(logits, dim=2)
  predictions = predictions[0].cpu().numpy()


  special_tokens = [tokenizer.pad_token, tokenizer.cls_token, tokenizer.sep_token, tokenizer.eos_token]
  special_tokens = [token for token in special_tokens if token is not None]  # Remove None tokens
  filtered_predictions =[]
  for token, pred in zip(tokens, predictions):
    if token not in special_tokens:
      filtered_predictions.append(pred)

  if len(filtered_predictions) != len(sequence):
      error_message = (
          f"Error: Length mismatch! Sequence length: {len(sequence)}, "
          f"Filtered predictions length: {len(filtered_predictions)}. "
          "This could be due to unaccounted special tokens in the tokenizer."
      )
      raise ValueError(error_message)
  return filtered_predictions

In [44]:
sample_sequence = test_df.iloc[100]['sequence']
print(f"Sequence: {sample_sequence}")

filtered_predictions = infer_binding_sites(sample_sequence, tokenizer, model, accelerator)
print(f"Filtered predictions: {filtered_predictions}")

Sequence: MAFVVTDNCIKCKYTDCVEVCPVDCFYEGPNFLVIHPDECIDCALCEPECPAQAIFSEDEVPEDMQEFIQLNAELAEVWPNITEKKDPLPDAEDWDGVKGKLQHLER
Filtered predictions: [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
