In [1]:
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch
import numpy as np
import random



# Path to the saved LoRA model
model_path = "esm2_t6_8M-finetuned-lora_2023-08-05_13-46-01"
# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)

# New unseen protein sequence
new_protein_sequence = "FDLNDFLEQKVLVRMEAIINSMTMKERAKPEIIKGSRKRRIAAGSGMQVQDVNRLLKQFDDMQRMMKKM"

# Tokenize the new sequence
inputs = loaded_tokenizer(new_protein_sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt")

# Make predictions
with torch.no_grad():
    outputs = loaded_model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=2)

# Print logits for debugging
print("Logits:", logits)

# Convert predictions to a list
predicted_labels = predictions.squeeze().tolist()

# Get input IDs to identify padding and special tokens
input_ids = inputs['input_ids'].squeeze().tolist()

# Define a set of token IDs that correspond to special tokens
special_tokens_ids = {loaded_tokenizer.cls_token_id, loaded_tokenizer.pad_token_id, loaded_tokenizer.eos_token_id}

# Filter the predicted labels using the special_tokens_ids to remove predictions for special tokens
binding_sites = [label for label, token_id in zip(predicted_labels, input_ids) if token_id not in special_tokens_ids]

print("Predicted binding sites:", binding_sites)


Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Logits: tensor([[[ 1.3544, -1.8621],
         [ 1.5548, -2.0320],
         [ 1.5699, -1.9424],
         ...,
         [ 1.3844, -1.7764],
         [ 1.4068, -1.7842],
         [ 1.4042, -1.7909]]])
Predicted binding sites: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [4]:
import torch
import numpy as np
import xml.etree.ElementTree as ET
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, EsmForTokenClassification
from peft import LoraConfig, get_peft_model, PeftModel
from sklearn.metrics import accuracy_score

# Load XML file
tree = ET.parse('binding_sites.xml')
root = tree.getroot()

# Define lists to hold sequences and binding site labels
sequences = []
binding_sites = []

# Iterate through the XML and extract the sequences and binding site labels
for partner in root.findall('partner'):
    for bind_partner in partner.findall('BindPartner'):
        sequence = bind_partner.find('proSeq').text
        pro_bnd = bind_partner.find('proBnd').text
        sites = [1 if char == '+' else 0 for char in pro_bnd]
        sequences.append(sequence)
        binding_sites.append(sites)

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

class ProteinDataset(torch.utils.data.Dataset):
    def __init__(self, sequences, binding_sites, tokenizer, max_length=512):
        self.sequences = sequences
        self.binding_sites = binding_sites
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx][:self.max_length]
        binding_site = self.binding_sites[idx][:self.max_length]
        encoding = self.tokenizer(sequence, truncation=True, padding='max_length', max_length=self.max_length)
        encoding['labels'] = binding_site + [-100] * (self.max_length - len(binding_site))
        return {key: torch.tensor(val) for key, val in encoding.items()}

# Assuming validation dataset was split as done previously
val_size = int(0.15 * len(sequences))
val_dataset = ProteinDataset(sequences[-val_size:], binding_sites[-val_size:], tokenizer)
val_dataloader = DataLoader(val_dataset, batch_size=8)

# Model directory where the model is saved
model_dir = "esm2_t6_8M-finetuned-lora_2023-08-05_13-46-01"

# Load the base model
base_model = EsmForTokenClassification.from_pretrained("facebook/esm2_t6_8M_UR50D")

loaded_model = PeftModel.from_pretrained(base_model, model_dir)

# Get the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the device
loaded_model.to(device)

# Evaluate the loaded model on the validation dataset
loaded_model.eval()
all_preds = []
all_labels = []
for batch in val_dataloader:
    inputs = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    with torch.no_grad():
        outputs = loaded_model(input_ids=inputs, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=-1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

# Flatten the lists and remove -100 (padding value)
all_preds_flat = [item for sublist in all_preds for item in sublist]
all_labels_flat = [item for sublist in all_labels for item in sublist]
all_preds_flat = [p for p, l in zip(all_preds_flat, all_labels_flat) if l != -100]
all_labels_flat = [l for l in all_labels_flat if l != -100]

# Calculate the accuracy
accuracy = accuracy_score(all_labels_flat, all_preds_flat)
print(f"Validation Accuracy: {accuracy}")


Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Validation Accuracy: 0.9586577181208054


In [6]:
# Create the entire dataset using the sequences and binding sites
entire_dataset = ProteinDataset(sequences, binding_sites, tokenizer)
entire_dataloader = DataLoader(entire_dataset, batch_size=8)

# Evaluate the loaded model on the entire dataset
loaded_model.eval()  # Ensure the model is in evaluation mode
all_preds = []
all_labels = []
for batch in entire_dataloader:
    inputs = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    with torch.no_grad():
        outputs = loaded_model(input_ids=inputs, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=-1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

# Flatten the lists and remove -100 (padding value)
all_preds_flat = [item for sublist in all_preds for item in sublist]
all_labels_flat = [item for sublist in all_labels for item in sublist]
all_preds_flat = [p for p, l in zip(all_preds_flat, all_labels_flat) if l != -100]
all_labels_flat = [l for l in all_labels_flat if l != -100]

# Calculate the accuracy
accuracy = accuracy_score(all_labels_flat, all_preds_flat)
print(f"Entire Dataset Accuracy: {accuracy}")


Entire Dataset Accuracy: 0.9484533852476746
