In [None]:
# Mount Google Drive at /content/drive; every data and checkpoint path used
# later in this notebook is rooted under /content/drive/MyDrive/Protein-binding.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install the project requirements from Drive, then `datasets` and `peft`.
# NOTE(review): `datasets` and `peft` are installed without version pins, so a
# fresh runtime may resolve different releases than the ones this run used.
%pip install -r "/content/drive/MyDrive/Protein-binding/requirements.txt"
%pip install datasets
%pip install peft



In [None]:
import ast
from datetime import datetime
from glob import glob
from pprint import pprint

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from Bio import SeqIO
from datasets import Dataset
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             matthews_corrcoef, roc_auc_score)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AutoModelForTokenClassification, AutoTokenizer,
                          AutoModelForMaskedLM, DataCollatorForTokenClassification,
                          EsmForMaskedLM, EsmForTokenClassification, EsmTokenizer,
                          TrainingArguments, Trainer, TrainerCallback
                        )
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.trainer_callback import ProgressCallback

### Preparing train-test dataset

In [None]:
# Load the grouped training table; both label columns are stored in the CSV as
# Python-literal strings and are parsed back into Python objects here.
needed_training_sites_df = pd.read_csv("/content/drive/MyDrive/Protein-binding/data/development_set/full_grouped_train_binding_sites_df.csv")
for label_column in ('binding_sites', 'any_ligand_binding_sites'):
    needed_training_sites_df[label_column] = needed_training_sites_df[label_column].apply(ast.literal_eval)

In [None]:
np.random.seed(42)

# Hold out 10% of the training table as the validation split (unstratified,
# fixed seed so the split is reproducible).
train_df, val_df = train_test_split(
    needed_training_sites_df,
    test_size=0.1,
    random_state=42,
    stratify=None,
)


In [None]:
# Load the grouped test table and parse its literal-encoded label columns.
test_df = pd.read_csv("/content/drive/MyDrive/Protein-binding/data/development_set/full_grouped_test_binding_sites_df.csv")
for label_column in ('binding_sites', 'any_ligand_binding_sites'):
    test_df[label_column] = test_df[label_column].apply(ast.literal_eval)

In [None]:
def split_into_chunks(sequences, labels, chunk_size = 1000):
    """Break every (sequence, labels) pair into pieces of at most ``chunk_size`` items.

    Pairs that already fit are passed through untouched; longer ones are cut
    into consecutive, non-overlapping windows, slicing the sequence and its
    labels at the same offsets.

    Returns:
        A ``(sequences, labels)`` tuple of two parallel lists.
    """
    chunked_sequences, chunked_labels = [], []
    for sequence, label in zip(sequences, labels):
        if len(sequence) <= chunk_size:
            # Short enough: keep the original objects as they are.
            chunked_sequences.append(sequence)
            chunked_labels.append(label)
            continue

        window_starts = range(0, len(sequence), chunk_size)
        chunked_sequences.extend(sequence[start:start + chunk_size] for start in window_starts)
        chunked_labels.extend(label[start:start + chunk_size] for start in window_starts)

    return chunked_sequences, chunked_labels

In [None]:
# Maximum number of residues per example; longer proteins are split into
# consecutive chunks, together with their labels.
chunk_size = 1000

train_seq, train_labels = split_into_chunks(
    train_df['sequence'].tolist(),
    train_df['any_ligand_binding_sites'].tolist(),
    chunk_size,
)
val_seq, val_labels = split_into_chunks(
    val_df['sequence'].tolist(),
    val_df['any_ligand_binding_sites'].tolist(),
    chunk_size,
)
test_seq, test_labels = split_into_chunks(
    test_df['sequence'].tolist(),
    test_df['any_ligand_binding_sites'].tolist(),
    chunk_size,
)

### Tokenization and preparing model

In [None]:
pretrained_model = "facebook/esm2_t33_650M_UR50D"

tokenizer = EsmTokenizer.from_pretrained(pretrained_model)

# The ESM tokenizer wraps every sequence as <cls> residues... <eos>, and
# `max_length` counts those two special tokens. Budget two extra positions so
# that a full chunk of `chunk_size` residues is not silently cut to
# `chunk_size - 2` residues by truncation.
max_sequence_length = chunk_size + 2


def _tokenize(sequences):
    """Tokenize a list of residue strings into padded PyTorch tensors."""
    return tokenizer(
        sequences,
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
        return_tensors="pt",
        is_split_into_words=False,
    )


train_tokenized = _tokenize(train_seq)
eval_tokenized = _tokenize(val_seq)
test_tokenized = _tokenize(test_seq)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
def truncate_labels(labels, max_length):
    """Return a new list in which each label sequence is cut to ``max_length`` items."""
    truncated = []
    for label in labels:
        truncated.append(label[:max_length])
    return truncated

def compute_metrics_train(p):
    """Compute token-level binary classification metrics for the Trainer.

    Args:
        p: ``(predictions, labels)`` pair as supplied by ``Trainer``.
            ``predictions`` holds the per-token logits with the class axis
            last (or a tuple whose first element is that array); ``labels``
            uses ``-100`` for positions that must be ignored.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall``, ``f1``, ``auc`` and
        ``mcc``. ``auc`` is ``nan`` when only one class occurs in the labels,
        because ROC-AUC is undefined in that case.
    """
    predictions, labels = p
    if isinstance(predictions, tuple):
        # Models that return extra outputs deliver a tuple; logits come first.
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    # Drop padding / special-token positions (-100 labels).
    keep = labels != -100
    token_logits = predictions[keep]          # (n_tokens, n_classes)
    true_labels = labels[keep]
    predicted_labels = token_logits.argmax(axis=-1)

    accuracy = accuracy_score(true_labels, predicted_labels)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predicted_labels, average='binary'
    )

    # ROC-AUC is a ranking metric and needs a continuous score. The logit
    # margin (class 1 minus class 0) is a monotonic transform of the softmax
    # probability of class 1, so it yields the same AUC as the probability.
    # Feeding the hard argmax here would collapse the curve to one point.
    binding_scores = token_logits[:, 1] - token_logits[:, 0]
    if np.unique(true_labels).size > 1:
        auc = roc_auc_score(true_labels, binding_scores)
    else:
        auc = float("nan")

    mcc = matthews_corrcoef(true_labels, predicted_labels)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

In [None]:
def _align_labels_to_tokens(labels, max_length, ignore_index=-100):
    """Shift per-residue labels so they line up with the tokenizer output.

    The tokenizer emits ``<cls> r0 r1 ... <eos>``, so the label of residue
    ``k`` belongs at token position ``k + 1``. Both special tokens get
    ``ignore_index``. At most ``max_length - 2`` residue labels are kept,
    mirroring the tokenizer's truncation.

    Assumes each entry of ``labels`` holds one label per residue of the
    matching sequence (they are chunked in lock-step with the sequences
    earlier in the notebook) - TODO confirm against the CSV schema.
    """
    max_residues = max_length - 2
    return [
        [ignore_index] + list(label[:max_residues]) + [ignore_index]
        for label in labels
    ]


train_labels = truncate_labels(train_labels, max_sequence_length)
eval_labels = truncate_labels(val_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Attach token-aligned labels. Without the shift every label sits one
# position to the left of its residue (label 0 on <cls>), and the last
# residue of each sequence is never supervised.
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column(
    "labels", _align_labels_to_tokens(train_labels, max_sequence_length)
)
eval_dataset = Dataset.from_dict({k: v for k, v in eval_tokenized.items()}).add_column(
    "labels", _align_labels_to_tokens(eval_labels, max_sequence_length)
)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column(
    "labels", _align_labels_to_tokens(test_labels, max_sequence_length)
)

### Configure model

In [None]:
# One-letter codes of the residues treated as "key" amino acids.
KEY_AMINO_ACIDS = set("CHRWY")

# Hyperparameters shared by the LoRA and training configuration.
config = dict(
    lora_alpha=1,  # try 0.5, 1, 2, ..., 16
    lora_dropout=0.2,
    lr=5.5e-04,
    lr_scheduler_type="cosine",
    max_grad_norm=100,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    r=2,
    weight_decay=0.2,
    # Add other hyperparameters as needed
)

# LoRA adapter settings for token classification, targeting the attention
# query/key/value projections.
peft_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    inference_mode=False,
    r=config["r"],
    lora_alpha=config["lora_alpha"],
    target_modules=["query", "key", "value"],
    lora_dropout=config["lora_dropout"],
    bias="all",
)

In [None]:
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
# Checkpoints and logs for this run share one timestamped directory.
run_dir = f"/content/drive/MyDrive/Protein-binding/trained_models/esm2_t33_650M-binding-sites_{timestamp}"

# Scheduler type and batch size are now read from `config` instead of being
# hard-coded or left to library defaults, so the config dict is the single
# source of truth for them. Both values equal what was used before.
#
# NOTE(review): config["lr"] and config["weight_decay"] are still NOT passed
# here, so training runs with the TrainingArguments defaults (the logged
# learning rate starts near 5e-05). The config values look tuned for the LoRA
# setup, which is currently disabled - confirm before wiring them in.
training_args = TrainingArguments(
    output_dir = run_dir,
    seed = 42,
    num_train_epochs = config["num_train_epochs"],
    per_device_train_batch_size = config["per_device_train_batch_size"],
    eval_strategy = "epoch",
    save_strategy = "epoch",
    load_best_model_at_end = True,
    metric_for_best_model = "eval_f1",
    greater_is_better = True,
    logging_dir = run_dir,
    logging_strategy = "steps",
    logging_steps = 10,
    save_total_limit = 1,
    bf16 = True,
    report_to = "none",
    lr_scheduler_type = config["lr_scheduler_type"],
    max_grad_norm = config['max_grad_norm']
)

In [None]:
class ESM2WithBiasedAttention(EsmForTokenClassification):
    """ESM-2 token classifier with a learnable hidden-state offset on key residues.

    The model is a regular ``EsmForTokenClassification`` (encoder in
    ``self.esm``, then dropout, then ``self.classifier``) with one extra
    parameter, ``key_attention_bias``. That vector is added to the final
    hidden state of every token whose amino acid is in ``KEY_AMINO_ACIDS``
    before classification.

    Why it derives from ``EsmForTokenClassification``: ``Auto*`` classes are
    factories, so calling ``from_pretrained`` on a subclass of
    ``AutoModelForTokenClassification`` returns a plain
    ``EsmForTokenClassification`` and none of the custom code would ever run.

    The bias starts at zero, so a freshly loaded model (including checkpoints
    saved from a plain ``EsmForTokenClassification``) initially produces the
    same logits as the plain model.

    Relies on the notebook-level ``tokenizer`` and ``KEY_AMINO_ACIDS``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Learnable offset for key residues, one value per hidden dimension.
        self.key_attention_bias = nn.Parameter(torch.zeros(config.hidden_size))
        # Amino-acid letter -> vocabulary id, resolved once instead of
        # decoding every token of every batch in Python.
        self.key_residues = {
            aa: tokenizer.convert_tokens_to_ids(aa) for aa in sorted(KEY_AMINO_ACIDS)
        }

    def _init_weights(self, module):
        """Initialise weights; additionally reset the extra bias to zero.

        ``from_pretrained`` may materialise parameters that are missing from
        the checkpoint without running ``__init__``'s initial values, so the
        zero start is enforced here as well.
        """
        super()._init_weights(module)
        # The attribute does not exist yet while the parent constructor runs.
        bias = getattr(module, "key_attention_bias", None)
        if module is self and bias is not None:
            with torch.no_grad():
                bias.zero_()

    def forward(self, input_ids, attention_mask=None, output_attentions=False,
                labels=None, output_hidden_states=None, **kwargs):
        """Run the encoder, bias key residues, and classify every token.

        Args:
            input_ids: token ids, ``[batch_size, seq_len]``.
            attention_mask: optional padding mask with the same shape.
            output_attentions: also return the encoder attention weights.
            labels: optional per-token labels (``-100`` = ignore); when given
                an unweighted cross-entropy loss is returned as well.
            output_hidden_states: also return all encoder hidden states.
            **kwargs: accepted for call compatibility and ignored.

        Returns:
            ``TokenClassifierOutput`` with ``loss`` (or ``None``), ``logits``
            of shape ``[batch_size, seq_len, num_labels]``, and the optional
            ``hidden_states`` / ``attentions``.
        """
        encoder_outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = encoder_outputs.last_hidden_state  # [batch, seq_len, hidden]

        # Boolean mask of key-residue positions, computed on-device.
        key_ids = torch.tensor(
            list(self.key_residues.values()),
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        residue_mask = torch.isin(input_ids, key_ids)  # [batch, seq_len]

        # Add the bias only where the mask is set (out-of-place, autograd-safe).
        sequence_output = sequence_output + (
            residue_mask.unsqueeze(-1).to(sequence_output.dtype) * self.key_attention_bias
        )

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.num_labels),
                labels.to(logits.device).view(-1),
            )

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

In [None]:
# Load the pretrained checkpoint named by `pretrained_model`, configured for
# two token-level labels.
model = ESM2WithBiasedAttention.from_pretrained(pretrained_model, num_labels=2)
# LoRA wrapping is currently disabled. NOTE(review): `base_model` is not
# defined in the visible cells, so the line below cannot be re-enabled as-is.
# model = get_peft_model(base_model, peft_config)

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Training

#### Define loss functions

In [None]:
class WeightedCrossEntropyLoss(nn.Module):
    """Two-class cross-entropy whose positive class carries ``pos_weight``.

    Positions where ``attention_mask`` is not 1 have their label replaced by
    the ignore index, so they do not contribute to the loss.
    """

    def __init__(self, pos_weight):
        super().__init__()
        # Weight applied to class 1; class 0 keeps weight 1.
        self.pos_weight = pos_weight

    def forward(self, logits, labels, attention_mask):
        """Return the weighted cross-entropy over attended tokens."""
        class_weights = torch.ones(2).to(logits.device)
        class_weights[1] = self.pos_weight
        criterion = nn.CrossEntropyLoss(weight = class_weights, ignore_index=-100)

        # Flatten batch and sequence axes so every token is one sample.
        flat_logits = logits.view(-1, logits.shape[-1])
        flat_labels = labels.view(-1)
        is_attended = attention_mask.view(-1) == 1

        ignore_fill = torch.tensor(criterion.ignore_index).type_as(labels)
        masked_labels = torch.where(is_attended, flat_labels, ignore_fill)

        return criterion(flat_logits, masked_labels)

class FocalLoss(nn.Module):
    """Binary focal loss over the two-class logits of a token classifier.

    For each scored token with true-class probability ``p_t`` the loss is
    ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``, where ``alpha_t`` is
    ``alpha`` for binding tokens (label 1) and ``1 - alpha`` otherwise. The
    result is the mean over scored tokens.

    Differences from the previous implementation:
      * non-binding tokens now contribute too (before, only positives were
        penalised, so predicting "binding" everywhere minimised the loss);
      * probabilities come from a softmax over both logits, matching the
        two-way head, instead of a sigmoid on the class-1 logit alone;
      * the all-ignored case returns a zero that stays attached to the graph.

    Args:
        alpha: weight of the positive class (label 1).
        gamma: focusing parameter; larger values down-weight easy tokens.
        ignore_index: label value excluded from the loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, ignore_index=-100):
        super().__init__()
        self.alpha = alpha  # Weight for positive class
        self.gamma = gamma  # Focusing parameter
        self.ignore_index = ignore_index

    def forward(self, logits, labels, attention_mask):
        """Return the mean focal loss over attended, non-ignored tokens.

        Raises:
            ValueError: if ``labels`` or ``attention_mask`` do not match the
                leading ``[batch_size, seq_len]`` shape of ``logits``.
        """
        # Ensure shapes are correct
        if logits.shape[:2] != labels.shape or logits.shape[:2] != attention_mask.shape:
            raise ValueError(f"Shape mismatch: logits {logits.shape}, labels {labels.shape}, attention_mask {attention_mask.shape}")

        # Ensure all tensors are on the same device
        device = logits.device
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # Only attended positions with a real label are scored.
        active_mask = attention_mask.bool() & (labels != self.ignore_index)
        if not active_mask.any():
            # Zero loss that still depends on `logits`, so backward() works.
            return logits.sum() * 0.0

        log_probs = torch.log_softmax(logits, dim=-1)[active_mask]  # [n_active, n_classes]
        targets = labels[active_mask].long()

        # Log-probability and probability assigned to the true class.
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        is_positive = (targets == 1).to(pt.dtype)
        alpha_t = self.alpha * is_positive + (1.0 - self.alpha) * (1.0 - is_positive)

        loss = -alpha_t * (1.0 - pt) ** self.gamma * log_pt
        return loss.mean()

# Add position-aware loss
class PositionAwareLoss(nn.Module):
    """Weighted cross-entropy plus a penalty on predictions around binding sites.

    The penalty looks at every interior position (first and last position of
    the padded sequence are skipped) that is a binding site or directly next
    to one, and adds ``|P(binding) - is_binding|`` for it. The penalty is
    summed (not averaged) over the batch and scaled by ``position_weight``.

    Compared with the previous version the penalty is computed with tensor
    operations instead of a Python loop over every token, and positions that
    are padding or carry the ignore label (-100) no longer contribute - before,
    an ignored position next to a binding site was penalised as "not binding".

    Args:
        pos_weight: class-1 weight passed to ``WeightedCrossEntropyLoss``.
        position_weight: multiplier of the position penalty.
    """

    def __init__(self, pos_weight, position_weight):
        super().__init__()
        self.weighted_ce = WeightedCrossEntropyLoss(pos_weight)
        self.focal_loss = FocalLoss() # Using default parameters' values (kept for experiments; unused in forward)
        self.position_weight = position_weight

    def forward(self, logits, labels, attention_mask):
        """Return ``weighted_ce + position_weight * position_penalty``."""
        base_loss = self.weighted_ce(logits, labels, attention_mask)

        # Position-aware component
        probs = torch.softmax(logits, dim=-1)[:, :, 1]  # Binding probabilities, [batch, seq_len]
        is_binding = labels == 1

        # A position is "near" a site if it, its left or its right neighbour binds.
        left_neighbour = torch.zeros_like(is_binding)
        left_neighbour[:, 1:] = is_binding[:, :-1]
        right_neighbour = torch.zeros_like(is_binding)
        right_neighbour[:, :-1] = is_binding[:, 1:]
        near_site = is_binding | left_neighbour | right_neighbour

        # Only interior positions are considered, as in the original loop.
        interior = torch.zeros_like(is_binding)
        interior[:, 1:-1] = True

        # Skip padding and ignored labels: they carry no ground truth.
        scored = (labels != -100) & attention_mask.to(labels.device).bool()

        selected = near_site & interior & scored
        position_loss = (probs - is_binding.to(probs.dtype)).abs()[selected].sum()

        return base_loss + self.position_weight * position_loss

In [None]:
# Encoder layers [0, num_layers_to_freeze) start training frozen.
#
# The freezing itself is done by GradualUnfreezeCallback.on_train_begin, not
# here: the Trainer builds its optimizer only from parameters that have
# requires_grad=True at that moment. Layers frozen before training were
# therefore never handed to the optimizer, and "unfreezing" them later had no
# effect on their weights.
num_layers_to_freeze = 12

# epoch -> (first_layer, last_layer_exclusive) to unfreeze at the start of that epoch
unfreeze_schedule = {
    4: (10, 12),  # Unfreeze layers 10-11
    6: (6, 10),   # Unfreeze layers 6-9
    8: (0, 6),    # Unfreeze layers 0-5
}

class GradualUnfreezeCallback(TrainerCallback):
    """Freeze encoder layers when training starts and release them on a schedule.

    Args:
        model: model exposing its encoder layers as ``model.esm.encoder.layer``.
        total_epochs: planned number of epochs (stored for reference only).
        unfreeze_schedule: ``{epoch: (start_idx, end_idx)}``; at the start of
            ``epoch`` layers ``start_idx .. end_idx - 1`` become trainable.
        num_layers_to_freeze: if given, layers ``[0, num_layers_to_freeze)``
            are frozen at train begin; by default exactly the layers named in
            ``unfreeze_schedule`` are frozen.
    """

    def __init__(self, model, total_epochs, unfreeze_schedule, num_layers_to_freeze=None):
        self.model = model
        self.total_epochs = total_epochs
        self.unfreeze_schedule = unfreeze_schedule
        self.num_layers_to_freeze = num_layers_to_freeze

    def _set_trainable(self, start_idx, end_idx, trainable):
        """Set ``requires_grad`` for every parameter of layers ``[start_idx, end_idx)``."""
        for layer_idx, layer in enumerate(self.model.esm.encoder.layer):
            if start_idx <= layer_idx < end_idx:
                for param in layer.parameters():
                    param.requires_grad = trainable

    def on_train_begin(self, args, state, control, **kwargs):
        # Runs after the Trainer has created its optimizer, so the frozen
        # parameters are already registered with it. While frozen they get no
        # gradient and the optimizer skips them.
        # NOTE(review): relies on Trainer creating the optimizer before firing
        # on_train_begin - confirm for the installed transformers version.
        if self.num_layers_to_freeze is not None:
            frozen_ranges = [(0, self.num_layers_to_freeze)]
        else:
            frozen_ranges = list(self.unfreeze_schedule.values())
        for start_idx, end_idx in frozen_ranges:
            self._set_trainable(start_idx, end_idx, False)

    def on_epoch_begin(self, args, state, control, **kwargs):
        epoch = state.epoch
        if epoch is None:
            return
        # state.epoch is a float; match it against the integer schedule keys
        # with a small tolerance instead of exact float equality.
        epoch_key = int(round(epoch))
        if abs(epoch - epoch_key) > 1e-6 or epoch_key not in self.unfreeze_schedule:
            return
        start_idx, end_idx = self.unfreeze_schedule[epoch_key]
        print(f"Unfreezing layers {start_idx} to {end_idx-1} at epoch {epoch}")
        self._set_trainable(start_idx, end_idx, True)

In [None]:
class CustomTrainer(Trainer):
    """Trainer that swaps the model's built-in loss for ``PositionAwareLoss``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = PositionAwareLoss(pos_weight = 8.0, position_weight = 0.5)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Score one batch with the position-aware loss.

        Returns the loss alone, or ``(loss, {"logits": logits})`` when
        ``return_outputs`` is true. ``num_items_in_batch`` is accepted for
        signature compatibility and not used.
        """
        model_outputs = model(**inputs)
        token_logits = model_outputs.logits

        loss = self.loss_fct(token_logits, inputs["labels"], inputs["attention_mask"])

        if return_outputs:
            return (loss, {"logits": token_logits})
        return loss

In [None]:
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # Per-epoch evaluation drives checkpoint selection (load_best_model_at_end
    # with eval_f1), so it must use the validation split. Evaluating on
    # test_dataset here would leak the test set into model selection and make
    # the final test metrics optimistic.
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics_train,
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
    callbacks=[
        ProgressCallback(),
        GradualUnfreezeCallback(
            model = model,
            total_epochs = config["num_train_epochs"],
            unfreeze_schedule = unfreeze_schedule,
        ),
    ],
)

  super().__init__(*args, **kwargs)


In [None]:
# Run fine-tuning; evaluation and checkpointing happen once per epoch as
# configured in `training_args`.
trainer.train()

  0%|          | 0/1140 [00:00<?, ?it/s]

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Auc,Mcc
1,38.5343,49.581497,0.672431,0.168849,0.637076,0.266947,0.656579,0.190854
2,26.3085,46.992176,0.731404,0.204642,0.64747,0.310991,0.693772,0.247253
3,14.1031,45.592865,0.712294,0.194864,0.661953,0.301093,0.689723,0.237342
4,9.9518,46.116379,0.725055,0.201732,0.654967,0.308458,0.693631,0.245256
5,6.9175,46.489552,0.772017,0.231495,0.618674,0.336921,0.703265,0.273411
6,5.2952,46.292305,0.760708,0.228311,0.653774,0.338434,0.712763,0.279835
7,6.1205,47.72583,0.822888,0.278108,0.558869,0.371398,0.704514,0.304871
8,4.6984,48.225636,0.831167,0.287861,0.545067,0.376752,0.702893,0.309521
9,2.6324,48.604328,0.839637,0.299078,0.530584,0.382532,0.701072,0.314769
10,3.399,48.762722,0.842237,0.302757,0.525814,0.384261,0.700367,0.316355


{'loss': 57.1741, 'grad_norm': 24.93232536315918, 'learning_rate': 4.999050767562379e-05, 'epoch': 0.09}
{'loss': 50.7771, 'grad_norm': 29.12437629699707, 'learning_rate': 4.996203791083291e-05, 'epoch': 0.18}
{'loss': 62.4631, 'grad_norm': 30.102428436279297, 'learning_rate': 4.991461232516675e-05, 'epoch': 0.26}
{'loss': 54.978, 'grad_norm': 30.36556625366211, 'learning_rate': 4.984826693294874e-05, 'epoch': 0.35}
{'loss': 50.0495, 'grad_norm': 29.2490234375, 'learning_rate': 4.976305211593758e-05, 'epoch': 0.44}
{'loss': 54.7899, 'grad_norm': 57.500370025634766, 'learning_rate': 4.965903258506806e-05, 'epoch': 0.53}
{'loss': 52.0325, 'grad_norm': 39.0261344909668, 'learning_rate': 4.953628733131045e-05, 'epoch': 0.61}
{'loss': 47.6829, 'grad_norm': 43.170223236083984, 'learning_rate': 4.9394909565685894e-05, 'epoch': 0.7}
{'loss': 44.2027, 'grad_norm': 59.30971908569336, 'learning_rate': 4.923500664848326e-05, 'epoch': 0.79}
{'loss': 35.9737, 'grad_norm': 46.03225326538086, 'learnin

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 49.58149719238281, 'eval_accuracy': 0.6724305699564517, 'eval_precision': 0.16884934971098267, 'eval_recall': 0.6370761628897598, 'eval_f1': 0.2669474886659765, 'eval_auc': 0.656579264127034, 'eval_mcc': 0.19085354187704667, 'eval_runtime': 32.9996, 'eval_samples_per_second': 9.091, 'eval_steps_per_second': 1.152, 'epoch': 1.0}
{'loss': 24.8648, 'grad_norm': 89.06075286865234, 'learning_rate': 4.864543104251587e-05, 'epoch': 1.05}
{'loss': 24.9986, 'grad_norm': 36.03846740722656, 'learning_rate': 4.841278102992637e-05, 'epoch': 1.14}
{'loss': 24.3842, 'grad_norm': 66.42259216308594, 'learning_rate': 4.8162351680370044e-05, 'epoch': 1.23}
{'loss': 26.8688, 'grad_norm': 67.50228118896484, 'learning_rate': 4.789433316637644e-05, 'epoch': 1.32}
{'loss': 22.216, 'grad_norm': 69.79364776611328, 'learning_rate': 4.760892901743944e-05, 'epoch': 1.4}
{'loss': 26.3032, 'grad_norm': 121.55693054199219, 'learning_rate': 4.730635596545985e-05, 'epoch': 1.49}
{'loss': 19.3264, 'grad_no

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 46.9921760559082, 'eval_accuracy': 0.7314042335976009, 'eval_precision': 0.2046421455113361, 'eval_recall': 0.6474697563469075, 'eval_f1': 0.31099107946640475, 'eval_auc': 0.6937718369907716, 'eval_mcc': 0.24725285303124378, 'eval_runtime': 32.9625, 'eval_samples_per_second': 9.101, 'eval_steps_per_second': 1.153, 'epoch': 2.0}
{'loss': 19.9436, 'grad_norm': 112.7906723022461, 'learning_rate': 4.514412764152446e-05, 'epoch': 2.02}
{'loss': 14.3289, 'grad_norm': 87.37691497802734, 'learning_rate': 4.4728512734909844e-05, 'epoch': 2.11}
{'loss': 16.5702, 'grad_norm': 144.21176147460938, 'learning_rate': 4.4297916272908024e-05, 'epoch': 2.19}
{'loss': 14.4181, 'grad_norm': 54.26575469970703, 'learning_rate': 4.385266524442241e-05, 'epoch': 2.28}
{'loss': 13.0211, 'grad_norm': 142.22775268554688, 'learning_rate': 4.3393097766828293e-05, 'epoch': 2.37}
{'loss': 13.6944, 'grad_norm': 219.2491455078125, 'learning_rate': 4.2919562829211283e-05, 'epoch': 2.46}
{'loss': 11.825, 'gr

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 45.592864990234375, 'eval_accuracy': 0.7122940228748266, 'eval_precision': 0.19486382103626423, 'eval_recall': 0.6619526324757199, 'eval_f1': 0.30109276912345967, 'eval_auc': 0.6897232363364167, 'eval_mcc': 0.23734225635200532, 'eval_runtime': 32.861, 'eval_samples_per_second': 9.129, 'eval_steps_per_second': 1.156, 'epoch': 3.0}
{'loss': 12.2953, 'grad_norm': 115.70751190185547, 'learning_rate': 3.9245201437756654e-05, 'epoch': 3.07}
{'loss': 11.2718, 'grad_norm': 43.62197494506836, 'learning_rate': 3.867370395306068e-05, 'epoch': 3.16}
{'loss': 9.5909, 'grad_norm': 67.68357849121094, 'learning_rate': 3.8091822849696954e-05, 'epoch': 3.25}
{'loss': 10.4735, 'grad_norm': 33.231876373291016, 'learning_rate': 3.7500000000000003e-05, 'epoch': 3.33}
{'loss': 8.9016, 'grad_norm': 18.987834930419922, 'learning_rate': 3.689868482592684e-05, 'epoch': 3.42}
{'loss': 10.736, 'grad_norm': 139.0265655517578, 'learning_rate': 3.628833395777224e-05, 'epoch': 3.51}
{'loss': 7.8383, 'gra

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 46.11637878417969, 'eval_accuracy': 0.7250554323725056, 'eval_precision': 0.20173182891629493, 'eval_recall': 0.6549667745782928, 'eval_f1': 0.30845771144278605, 'eval_auc': 0.6936308705696814, 'eval_mcc': 0.24525635696740986, 'eval_runtime': 32.9794, 'eval_samples_per_second': 9.097, 'eval_steps_per_second': 1.152, 'epoch': 4.0}
Unfreezing layers 10 to 11 at epoch 4.0
{'loss': 12.5517, 'grad_norm': 70.51602172851562, 'learning_rate': 3.246287027504237e-05, 'epoch': 4.04}
{'loss': 6.5206, 'grad_norm': 65.14412689208984, 'learning_rate': 3.180258662113338e-05, 'epoch': 4.12}
{'loss': 7.8464, 'grad_norm': 105.302001953125, 'learning_rate': 3.1137137178519985e-05, 'epoch': 4.21}
{'loss': 8.692, 'grad_norm': 218.39413452148438, 'learning_rate': 3.04670272801594e-05, 'epoch': 4.3}
{'loss': 7.8122, 'grad_norm': 4.0827250480651855, 'learning_rate': 2.9792765798093465e-05, 'epoch': 4.39}
{'loss': 7.841, 'grad_norm': 23.746862411499023, 'learning_rate': 2.9114864757018352e-05, 'ep

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 46.48955154418945, 'eval_accuracy': 0.7720174193239643, 'eval_precision': 0.2314950589735416, 'eval_recall': 0.6186743908672687, 'eval_f1': 0.3369212211190498, 'eval_auc': 0.7032653897314168, 'eval_mcc': 0.2734112823602604, 'eval_runtime': 33.1196, 'eval_samples_per_second': 9.058, 'eval_steps_per_second': 1.147, 'epoch': 5.0}
{'loss': 7.6828, 'grad_norm': 36.11544418334961, 'learning_rate': 2.4311141440795953e-05, 'epoch': 5.09}
{'loss': 7.6878, 'grad_norm': 59.20942306518555, 'learning_rate': 2.3622805991103362e-05, 'epoch': 5.18}
{'loss': 6.3058, 'grad_norm': 75.95262908935547, 'learning_rate': 2.2935516363191693e-05, 'epoch': 5.26}
{'loss': 4.8919, 'grad_norm': 12.741786003112793, 'learning_rate': 2.224979447514802e-05, 'epoch': 5.35}
{'loss': 5.6805, 'grad_norm': 36.316898345947266, 'learning_rate': 2.1566161054539798e-05, 'epoch': 5.44}
{'loss': 6.2881, 'grad_norm': 68.0743408203125, 'learning_rate': 2.088513524298165e-05, 'epoch': 5.53}
{'loss': 8.2877, 'grad_norm'

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 46.29230499267578, 'eval_accuracy': 0.7607076201566463, 'eval_precision': 0.22831131738664762, 'eval_recall': 0.6537740671323905, 'eval_f1': 0.33843439911797135, 'eval_auc': 0.7127634855197328, 'eval_mcc': 0.27983477225511644, 'eval_runtime': 32.9386, 'eval_samples_per_second': 9.108, 'eval_steps_per_second': 1.154, 'epoch': 6.0}
Unfreezing layers 6 to 9 at epoch 6.0
{'loss': 3.3458, 'grad_norm': 5.499903678894043, 'learning_rate': 1.6882513269882917e-05, 'epoch': 6.05}
{'loss': 7.0819, 'grad_norm': 55.14468002319336, 'learning_rate': 1.6234061120181142e-05, 'epoch': 6.14}
{'loss': 5.0881, 'grad_norm': 2.896456241607666, 'learning_rate': 1.5592265701304114e-05, 'epoch': 6.23}
{'loss': 4.577, 'grad_norm': 32.65731430053711, 'learning_rate': 1.495761438367577e-05, 'epoch': 6.32}
{'loss': 4.1656, 'grad_norm': 14.732405662536621, 'learning_rate': 1.433058911258991e-05, 'epoch': 6.4}
{'loss': 4.5672, 'grad_norm': 37.43701171875, 'learning_rate': 1.3711666042227772e-05, 'epoch'

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 47.725830078125, 'eval_accuracy': 0.8228875879340873, 'eval_precision': 0.27810751229438696, 'eval_recall': 0.5588686317941728, 'eval_f1': 0.3713978372869841, 'eval_auc': 0.7045135133627676, 'eval_mcc': 0.30487130771076804, 'eval_runtime': 32.9306, 'eval_samples_per_second': 9.11, 'eval_steps_per_second': 1.154, 'epoch': 7.0}
{'loss': 6.2347, 'grad_norm': 3.1285386085510254, 'learning_rate': 1.0194118683375503e-05, 'epoch': 7.02}
{'loss': 3.5813, 'grad_norm': 19.460174560546875, 'learning_rate': 9.644682182758306e-06, 'epoch': 7.11}
{'loss': 4.5314, 'grad_norm': 37.87614440917969, 'learning_rate': 9.106906294750805e-06, 'epoch': 7.19}
{'loss': 5.069, 'grad_norm': 6.061787128448486, 'learning_rate': 8.581199398806641e-06, 'epoch': 7.28}
{'loss': 3.9919, 'grad_norm': 25.08302116394043, 'learning_rate': 8.067960709356478e-06, 'epoch': 7.37}
{'loss': 5.1791, 'grad_norm': 4.019735813140869, 'learning_rate': 7.5675799726501155e-06, 'epoch': 7.46}
{'loss': 3.9549, 'grad_norm': 5

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 48.22563552856445, 'eval_accuracy': 0.8311665523457066, 'eval_precision': 0.28786106361918473, 'eval_recall': 0.5450673027773045, 'eval_f1': 0.37675185490519375, 'eval_auc': 0.7028926798997399, 'eval_mcc': 0.3095205994988983, 'eval_runtime': 32.9441, 'eval_samples_per_second': 9.106, 'eval_steps_per_second': 1.153, 'epoch': 8.0}
Unfreezing layers 0 to 5 at epoch 8.0
{'loss': 4.9236, 'grad_norm': 32.065303802490234, 'learning_rate': 4.4555546193688735e-06, 'epoch': 8.07}
{'loss': 4.9191, 'grad_norm': 4.153754711151123, 'learning_rate': 4.070838043436786e-06, 'epoch': 8.16}
{'loss': 3.5994, 'grad_norm': 4.528178691864014, 'learning_rate': 3.7020147790418263e-06, 'epoch': 8.25}
{'loss': 4.9768, 'grad_norm': 3.970609664916992, 'learning_rate': 3.3493649053890326e-06, 'epoch': 8.33}
{'loss': 4.6053, 'grad_norm': 2.9108381271362305, 'learning_rate': 3.013156219837776e-06, 'epoch': 8.42}
{'loss': 5.0159, 'grad_norm': 15.0165433883667, 'learning_rate': 2.6936440345401493e-06, 'ep

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 48.60432815551758, 'eval_accuracy': 0.839636937899791, 'eval_precision': 0.2990779869381483, 'eval_recall': 0.530584426648492, 'eval_f1': 0.3825317855168601, 'eval_auc': 0.7010718683752845, 'eval_mcc': 0.3147687217085311, 'eval_runtime': 33.0895, 'eval_samples_per_second': 9.066, 'eval_steps_per_second': 1.148, 'epoch': 9.0}
{'loss': 3.7563, 'grad_norm': 20.240684509277344, 'learning_rate': 1.1398749530123127e-06, 'epoch': 9.04}
{'loss': 3.0389, 'grad_norm': 4.250957012176514, 'learning_rate': 9.432999922687396e-07, 'epoch': 9.12}
{'loss': 3.8361, 'grad_norm': 19.873981475830078, 'learning_rate': 7.649933515167407e-07, 'epoch': 9.21}
{'loss': 4.5322, 'grad_norm': 4.913013458251953, 'learning_rate': 6.050904343141095e-07, 'epoch': 9.3}
{'loss': 4.9322, 'grad_norm': 36.87843704223633, 'learning_rate': 4.637126686895532e-07, 'epoch': 9.39}
{'loss': 4.1248, 'grad_norm': 39.85763168334961, 'learning_rate': 3.4096741493194197e-07, 'epoch': 9.47}
{'loss': 3.2217, 'grad_norm': 4.

  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 48.76272201538086, 'eval_accuracy': 0.8422370750849432, 'eval_precision': 0.3027567938781517, 'eval_recall': 0.5258135968648833, 'eval_f1': 0.38426098866890795, 'eval_auc': 0.700367199699601, 'eval_mcc': 0.3163553439642934, 'eval_runtime': 37.8803, 'eval_samples_per_second': 7.92, 'eval_steps_per_second': 1.003, 'epoch': 10.0}
{'train_runtime': 1544.4725, 'train_samples_per_second': 5.905, 'train_steps_per_second': 0.738, 'train_loss': 12.74169400365729, 'epoch': 10.0}


TrainOutput(global_step=1140, training_loss=12.74169400365729, metrics={'train_runtime': 1544.4725, 'train_samples_per_second': 5.905, 'train_steps_per_second': 0.738, 'total_flos': 1.982744217908928e+16, 'train_loss': 12.74169400365729, 'epoch': 10.0})

### Evaluation

In [None]:
from accelerate import Accelerator

# Vocabulary ids of the key amino acids.
key_token_ids = torch.tensor(
    [tokenizer.convert_tokens_to_ids(aa) for aa in KEY_AMINO_ACIDS],
    dtype=torch.long,
)
accelerator = Accelerator()

# Checkpoint written by an earlier training run.
saved_model_path = "/content/drive/MyDrive/Protein-binding/trained_models/esm2_t33_650M-binding-sites_2025-03-04_16-23-29/checkpoint-1140"

loaded_model = ESM2WithBiasedAttention.from_pretrained(saved_model_path, num_labels = 2)

# Hand the reloaded model to Accelerate; `model` now refers to the prepared
# evaluation model.
model = accelerator.prepare(loaded_model)

In [None]:
# Human-readable class names and the reverse lookup.
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {name: class_id for class_id, name in id2label.items()}

# Collator used to batch examples for evaluation.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
def compute_metrics_dataset(dataset, data_collator):
    """Predict on ``dataset`` with the notebook-level ``model`` and score it.

    Args:
        dataset: tokenized dataset with a ``labels`` column (``-100`` marks
            positions to ignore).
        data_collator: collator used to batch the examples.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall``, ``f1``, ``auc`` and
        ``mcc`` over all non-ignored tokens. ``auc`` is ``nan`` when only one
        class occurs in the labels.
    """
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)
    if isinstance(predictions, tuple):
        # Models that return extra outputs deliver a tuple; logits come first.
        predictions = predictions[0]

    # Remove padding and special tokens
    mask = labels != -100
    true_labels = labels[mask]
    token_logits = predictions[mask]                # (n_tokens, n_classes)
    flat_predictions = token_logits.argmax(axis=-1)

    # Print a summary instead of every label: the full lists run to hundreds
    # of thousands of entries and swamp the notebook output.
    print(
        f"Scored {true_labels.size} residues: "
        f"{int((true_labels == 1).sum())} true binding sites, "
        f"{int((flat_predictions == 1).sum())} predicted binding sites"
    )

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')

    # ROC-AUC needs a continuous ranking score, not the hard argmax. The logit
    # margin is a monotonic transform of the class-1 softmax probability.
    binding_scores = token_logits[:, 1] - token_logits[:, 0]
    if np.unique(true_labels).size > 1:
        auc = roc_auc_score(true_labels, binding_scores)
    else:
        auc = float("nan")

    mcc = matthews_corrcoef(true_labels, flat_predictions)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}

In [None]:
# import wandb
# wandb.login()  # reads WANDB_API_KEY from the environment; never paste the key into the notebook (the key previously committed here should be revoked)
# trainer = Trainer(model=model, data_collator=data_collator)

In [None]:
# Final evaluation on the held-out test split.
# (A hard-coded WandB API key was removed from this cell; credentials belong
# in environment variables or Colab secrets, and the exposed key should be
# revoked.)
test_metrics = compute_metrics_dataset(test_dataset, data_collator)
print(test_metrics)