In [37]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np

from Bio import SeqIO
from glob import glob
from transformers import (AutoModelForTokenClassification, AutoTokenizer, 
                          AutoModelForMaskedLM,
                           EsmForMaskedLM, EsmTokenizer,
                           TrainingArguments
                        )
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             matthews_corrcoef, roc_auc_score)
from sklearn.model_selection import train_test_split
from preprocessing import convert_to_binary_list
from pprint import pprint
from datasets import Dataset
from datetime import datetime

In [2]:
dir_path = "data/development_set/"
csv_files = glob(dir_path + "/*.csv")
print(csv_files)

['data/development_set/all_binding_sites_batch_10.csv', 'data/development_set/all_binding_sites_batch_4.csv', 'data/development_set/all_binding_sites_batch_8.csv', 'data/development_set/all_binding_sites_batch_7.csv', 'data/development_set/all_binding_sites_batch_5.csv', 'data/development_set/all_binding_sites_batch_3.csv', 'data/development_set/all_binding_sites_batch_11.csv', 'data/development_set/all_binding_sites_batch_1.csv', 'data/development_set/all_binding_sites_batch_2.csv', 'data/development_set/all_binding_sites_batch_6.csv']


In [3]:
binding_sites_df = pd.DataFrame()

for f in csv_files:
    batch_df = pd.read_csv(f)
    binding_sites_df = pd.concat([binding_sites_df, batch_df])

# New concated binding sites df
display(binding_sites_df)

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
0,P59857,"[60, 65, 71, 76, 77, 80, 81, 83, 84, 87, 88, 9...",small,MQDAITSVINAADVQGKYLDDSSVEKLRGYFQTGELRVRAAATIAA...,161,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,P26514,"[384, 388, 408, 409, 410, 417, 419, 421, 424, ...",small,,0,[]
2,P12111,"[3128, 3153, 3154, 3155]",small,MRKHRHLPLVAVFCLFLSGFPTTHAQQQQADVKNGAAADIIFLVDS...,3177,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,P05057,"[141, 142, 145, 149, 37, 38, 39, 42, 49, 50, 5...",small,MNGPIIMTREERMKIVHEIKERILDKYGDDVKAIGVYGSLGRQTDG...,253,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,P46859,"[160, 159, 163, 40, 17, 18, 19, 20, 21, 22, 23...",small,MSTTNHDHHIYVLMGVSGSGKSAVASEVAHQLHAAFLDGDFLHPRR...,175,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...
995,Q14114,"[64, 67, 69, 71, 103, 106, 108, 77, 110, 78, 1...",metal,,0,[]
996,O32553,"[32, 205, 143, 145, 219, 221]",metal,MSKDIKQVIEIAKKHNLFLKEETIQFNESGLDFQAVFAQDNNGIDW...,302,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
997,A8IYS5,"[146, 124]",metal,MSSTYQKFAASLREQEGPSGSLPTNGPSTTTPFANATNRYLNNHSG...,481,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
998,H2EMX8,"[80, 113, 110]",metal,MEEAKVEAKDGTISVATAFAGHQQAVLDSDHKFLTQAVEEAYKGVD...,185,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [4]:
needed_binding_sites_df = binding_sites_df[binding_sites_df['sequence_length'] > 0]
needed_binding_sites_df['binary_binding_sites'] = [
            convert_to_binary_list(row['binding_sites'], row['sequence_length']) 
            for _, row in needed_binding_sites_df.iterrows()
        ]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  needed_binding_sites_df['binary_binding_sites'] = [


In [5]:
needed_binding_sites_df.ligand_type.value_counts()

ligand_type
metal      4163
small      1291
nuclear     440
Name: count, dtype: int64

In [6]:
# Data verification

sample_idx = 10
seq_label = needed_binding_sites_df.iloc[sample_idx]['binary_binding_sites']
seq_len = needed_binding_sites_df.iloc[sample_idx]['sequence_length']
seq = needed_binding_sites_df.iloc[sample_idx]['sequence']
seq_binding_sites = needed_binding_sites_df.iloc[sample_idx]['binding_sites']

print(seq_label)
print(seq)
print(len(seq_label))
print(seq_len)
print(seq_binding_sites)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [7]:
np.random.seed(42)

train_df, test_df = train_test_split(
    needed_binding_sites_df,
    test_size=0.2,  # 80% train, 20% test
    stratify=needed_binding_sites_df['ligand_type'],
    random_state=42
)

# Verify the split ratios
print("Training set ligand distribution:")
print(train_df['ligand_type'].value_counts())
print("\nTest set ligand distribution:")
print(test_df['ligand_type'].value_counts())

# Calculate percentages to verify similar ratios
print("\nTraining set percentages:")
print(train_df['ligand_type'].value_counts(normalize=True))
print("\nTest set percentages:")
print(test_df['ligand_type'].value_counts(normalize=True))

Training set ligand distribution:
ligand_type
metal      3330
small      1033
nuclear     352
Name: count, dtype: int64

Test set ligand distribution:
ligand_type
metal      833
small      258
nuclear     88
Name: count, dtype: int64

Training set percentages:
ligand_type
metal      0.706257
small      0.219088
nuclear    0.074655
Name: proportion, dtype: float64

Test set percentages:
ligand_type
metal      0.706531
small      0.218830
nuclear    0.074640
Name: proportion, dtype: float64


In [8]:
test_df.sequence_length.describe()

count    1179.000000
mean      373.797286
std       285.582743
min        37.000000
25%       195.000000
50%       301.000000
75%       464.500000
max      3391.000000
Name: sequence_length, dtype: float64

In [9]:
train_df.head()

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
633,Q7CPA2,"[99, 12, 13, 14, 84, 85, 87, 88, 89, 90]",metal,MLDVKSQDISIPEAVVVLCTAPDEATAQDLAAKVLAEKLAACATLL...,115,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
507,P0AB60,"[360, 371, 357, 374]",metal,MLELLFLLLPVAAAYGWYMGRRSAQQNKQDEANRLSRDYVAGVNFL...,389,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
980,Q16611,"[160, 164, 103, 104, 73, 170, 76, 46, 176, 177...",metal,MASGQGPGPPRQECGEPALPSASEEQVAQDTEEVFRSYVFYRHQQE...,211,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
648,Q9HYC5,"[384, 258, 382, 262, 275, 181, 375, 280, 378, ...",metal,MTATSDLIESLISYSWDDWQVTRQEARRVIAAIRNDNVPDATIAAL...,408,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
626,P44007,"[128, 58, 54]",metal,MPLLDSFKVDHTKMNAPAVRIAKTMLTPKGDNITVFDLRFCIPNKE...,167,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


### Split sequences into chunks

In [10]:
def split_into_chunks(sequences, labels, chunk_size = 1000):
    """Split sequences and labels into chunks of size "chunk_size" or less."""
    new_sequences = []
    new_labels = []
    for seq, lbl in zip(sequences, labels):
        if len(seq) > chunk_size:
            # Split the sequence and labels into chunks of size "chunk_size" or less
            for i in range(0, len(seq), chunk_size):
                new_sequences.append(seq[i:i+chunk_size])
                new_labels.append(lbl[i:i+chunk_size])
        else:
            new_sequences.append(seq)
            new_labels.append(lbl)

    return new_sequences, new_labels

### Prepare training and testing sequences, labels

In [11]:
# Initial sequences
test_seq = test_df['sequence'].tolist()
test_labels = test_df['binary_binding_sites'].tolist()
train_seq = train_df['sequence'].tolist()
train_labels = train_df['binary_binding_sites'].tolist()

In [12]:
# Apply new sequences by chunking
chunk_size = 1000
test_seq, test_labels = split_into_chunks(test_seq, test_labels, chunk_size)
train_seq, train_labels = split_into_chunks(train_seq, train_labels, chunk_size)

In [13]:
print(len(train_seq[10]))
print(len(train_labels[10]))

473
473


### Tokenization and helper functions

In [15]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
max_sequence_length = 1000 # here note that data was processed with chunks (context window) of 1000 residues - adapt accordingly

train_tokenized = tokenizer(train_seq, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_seq, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

In [16]:
train_tokenized

{'input_ids': tensor([[ 0, 20,  4,  ...,  1,  1,  1],
        [ 0, 20,  4,  ...,  1,  1,  1],
        [ 0, 20,  5,  ...,  1,  1,  1],
        ...,
        [ 0, 20,  9,  ...,  1,  1,  1],
        [ 0, 20, 17,  ...,  1,  1,  1],
        [ 0, 20, 20,  ...,  1,  1,  1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [17]:
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

def compute_metrics_train(p):
    """Compute metrics for evaluation."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove padding (-100 labels)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()

    # Compute accuracy
    accuracy = accuracy_score(labels, predictions)

    # Compute precision, recall, F1 score, and AUC
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)

    # Compute MCC
    mcc = matthews_corrcoef(labels, predictions)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def compute_loss(model, inputs):
    """Custom compute_loss function."""
    logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss()
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss

In [27]:
# Verify some encoded and dedcode input, output
idx = 100
print(len(train_seq[idx]))
print(train_seq[idx])
encoded_input = tokenizer(train_seq[idx])
print(encoded_input)
decoded_output = tokenizer.decode(encoded_input['input_ids'])
print(decoded_output)

print(len(train_seq[idx]), len(encoded_input["input_ids"]))

190
MRGETLKLKKDKRREAIRQQIDSNPFITDHELSDLFQVSIQTIRLDRTYLNIPELRKRIKLVAEKNYDQISSIEEQEFIGDLIQVNPNVKAQSILDITSDSVFHKTGIARGHVLFAQANSLCVALIKQPTVLTHESSIQFIEKVKLNDTVRAEARVVNQTAKHYYVEVKSYVKHTLVFKGNFKMFYDKRG
{'input_ids': [0, 20, 10, 6, 9, 11, 4, 15, 4, 15, 15, 13, 15, 10, 10, 9, 5, 12, 10, 16, 16, 12, 13, 8, 17, 14, 18, 12, 11, 13, 21, 9, 4, 8, 13, 4, 18, 16, 7, 8, 12, 16, 11, 12, 10, 4, 13, 10, 11, 19, 4, 17, 12, 14, 9, 4, 10, 15, 10, 12, 15, 4, 7, 5, 9, 15, 17, 19, 13, 16, 12, 8, 8, 12, 9, 9, 16, 9, 18, 12, 6, 13, 4, 12, 16, 7, 17, 14, 17, 7, 15, 5, 16, 8, 12, 4, 13, 12, 11, 8, 13, 8, 7, 18, 21, 15, 11, 6, 12, 5, 10, 6, 21, 7, 4, 18, 5, 16, 5, 17, 8, 4, 23, 7, 5, 4, 12, 15, 16, 14, 11, 7, 4, 11, 21, 9, 8, 8, 12, 16, 18, 12, 9, 15, 7, 15, 4, 17, 13, 11, 7, 10, 5, 9, 5, 10, 7, 7, 17, 16, 11, 5, 15, 21, 19, 19, 7, 9, 7, 15, 8, 19, 7, 15, 21, 11, 4, 7, 18, 15, 6, 17, 18, 15, 20, 18, 19, 13, 15, 10, 6, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

### Training

In [34]:
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

In [39]:
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

model = AutoModelForTokenClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=2)
training_args = TrainingArguments(
    output_dir=f"data/trained_models/esm2_t6_8M-binding-sites_{timestamp}",
    seed=42,
    num_train_epochs = 3,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    logging_dir=f"data/trained_models/esm2_t6_8M-binding-sites_{timestamp}",
    logging_strategy="steps",
    logging_steps=10,
    save_total_limit=5,
    fp16=True,
    report_to="none"  # Disable wandb
)

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.26.0`: Please run `pip install transformers[torch]` or `pip install 'accelerate>=0.26.0'`

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics_train,
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer)

)