In [7]:
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
import numpy as np
import random

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)

# Path to the saved model
model_path = "AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites"

# Load the model
loaded_model = AutoModelForTokenClassification.from_pretrained(model_path)
loaded_model.eval()  # Set the model to evaluation mode

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)

# New unseen protein sequence
new_protein_sequence = "FDLNDFLEQKVLVRMEAIINSMTMKERAKPEIIKGSRKRRIAAGSGMQVQDVNRLLKQFDDMQRMMKKM"

# Tokenize the new sequence
inputs = loaded_tokenizer(new_protein_sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt")

# Make predictions
with torch.no_grad():
    outputs = loaded_model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=2)

# Print logits for debugging
print("Logits:", logits)

# Convert predictions to a list
predicted_labels = predictions.squeeze().tolist()

# Get input IDs to identify padding and special tokens
input_ids = inputs['input_ids'].squeeze().tolist()

# Define a set of token IDs that correspond to special tokens
special_tokens_ids = {loaded_tokenizer.cls_token_id, loaded_tokenizer.pad_token_id, loaded_tokenizer.eos_token_id}

# Filter the predicted labels using the special_tokens_ids to remove predictions for special tokens
binding_sites = [label for label, token_id in zip(predicted_labels, input_ids) if token_id not in special_tokens_ids]

print("Predicted binding sites:", binding_sites)


Some weights of the model checkpoint at AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites were not used when initializing EsmForTokenClassification: ['base_model.model.esm.encoder.layer.2.attention.self.query.lora_B.default.weight', 'base_model.model.esm.encoder.layer.0.output.dense.weight', 'base_model.model.esm.contact_head.regression.weight', 'base_model.model.esm.embeddings.position_ids', 'base_model.model.esm.encoder.layer.1.attention.self.value.lora_B.default.weight', 'base_model.model.esm.encoder.layer.0.attention.self.query.bias', 'base_model.model.esm.encoder.layer.1.attention.self.key.lora_B.default.weight', 'base_model.model.esm.encoder.layer.4.attention.self.key.bias', 'base_model.model.esm.encoder.layer.1.intermediate.dense.weight', 'base_model.model.esm.encoder.layer.4.output.dense.bias', 'base_model.model.esm.encoder.layer.3.attention.output.dense.weight', 'base_model.model.esm.encoder.layer.5.attention.self.rotary_embeddings.inv_freq', 'base_model.model.esm.encode

Logits: tensor([[[ 0.0193, -0.6626],
         [-0.1956, -0.7331],
         [-0.1698, -0.7026],
         ...,
         [-0.1562, -0.7934],
         [-0.1561, -0.7931],
         [-0.1559, -0.7933]]])
Predicted binding sites: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
