# A Small Token Classifier with Three Classes

In [10]:
import torch
from transformers import EsmTokenizer, EsmForTokenClassification, Trainer, TrainingArguments, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [11]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, TrainingArguments, Trainer
from transformers import EsmForTokenClassification

class ProteinData:
    def __init__(self, input_ids, labels):
        self.input_ids = input_ids
        self.labels = labels

class ProteinDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Your dataset
dataset = {
    'sequence': [
        "AVILMFCCCYRRNDCQKSTWRNDKNDKNDKNDRRNDCQKSRRNDCQKSAVILMFYKNDKNDKN", # Protein 1
        "RRNDCQKSRRNDCQKSTWAVILMFCCCYRRNDCQKSTWRNDKNDKNDKNDRRNDCQKS", # Protein 2
        "AVILMFYKNDKNDKNDRRNDCQKSAVILMFYKNDKNDKNRRNDCQKSTWRRNDCQKSTW", # Protein 3
        "RRNDCQKSAVILMFYKNDKNDKNDRRNDCQKSRRNDCQKSTWAVILMFCCCYRRNDCQKSTW", # Protein 4
        "TWRRNDCQKSAVILMFYKNDKNDRRNDCQKSTWRRNDCQKSAVILMFYKNDKNDKN", # Protein 5
        "MVLSEGEWQLVLHVWAKVEADVAGHGQDILIRLFKSHPETLEKFDRVKHLKTEAEMKASEDLKKHGVTVLTALGAILKKKGHHEAELKPLAQSHATKHKIPIKYLEFISDAIIHVLHAKHPS",
        "MGSDKIHHHHHHENLYFQGADPKDLAHLLDYFEHKETDGLAKGFGTAKSVFKDATNFAEIISVLKKMRPILFLPLLCVILIFKIKFWT",
        "MKLAILVTIVALVAMYRINHRTQELIELSNKNQPYTINADIEEIELTNRYPALIEYVQQQDKPLPKN", 
        "MPRYKILNSKLTGEKMSLYEFLVTFISKIITVLLTVFLNRYHRRWYHG",
        "PMLKRRTYNLIYIAFLTVFSNKTYRIDGFIKNLPYLFRGVNNTGKPKL",
        "AMNRLIFPLIKILILLLSIPFFLGNLIDKSYLKQIGLKVTFLFMLRYH",
        "MFLKTLLIILWVAAIKNLQTVYQLIRFLKISRKRYEHGKNFVRIWLYK",
        "VIIGDRLVRIWLYKLFIKNLIKNLPYLFLISKNLVTFLMLLRIWLYKQ",
        "QIGLKVTFLMLLRGWKSVYFFLKNLPYLGFSKKNLVTFLMLLRIWLYK",
        "NLKNLPYLGFSKKNLVTFLMLLRGWKSVYFFLKNLPYLFLISKNLVTFL",
        "MLLRGWKSVYFFLKNLPYLFLISKNLVTFLMLLRGWKSVYFFLKNLPY",
        "LFLISKNLVTFLMLLRGWKSVYFFLKNLPYLFLISKNLVTFLMLLRGW",
        "KSVYFFLKNLPYLFLISKNLVTFLMLLRGWKSVYFFLKNLPYLFLISK",
        "NLVTFLMLLRGWKSVYFFLKNLPYLFLISKNLVTFLMLLRGWKSVYFF",
        "LKNLPYLFLISKNLVTFLMLLRGWKSVYFFLKNLPYLFLISKNLVTFLM",
        "LLRGWKSVYFFLKNLPYLFLISKNLVTFLMLLRGWKSVYFFLKNLPYLF",
        "LISKNLVTFLMLLRGWKSVYFFLKNLPYLFLISKNLVTFLMLLRGWKSV"
    ],
    'labels': [
        [0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], # Labels for Protein 1
        [2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1], # Labels for Protein 2
        [0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0], # Labels for Protein 3
        [2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1], # Labels for Protein 4
        [2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1],  # Labels for Protein 5
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2],
        [0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2],
        [2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 0, 2, 2, 2, 0, 0, 0],
        [0, 2, 1, 0, 1, 2, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 2, 0, 2, 1, 2, 1, 0],
        [1, 2, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 2, 0, 2, 1, 2, 0, 1, 2],
        [2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 0, 2, 1],
        [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 1, 0, 2],
        [2, 1, 0, 1, 2, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0, 1, 2, 0, 2, 1, 0],
        [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2],
        [0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 1, 2],
        [2, 0, 1, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 1, 2, 0],
        [1, 2, 0, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 0, 1, 2],
        [2, 1, 0, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1],
        [0, 2, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 1, 0, 2],
    ]
}
# 0: Exposed to Solvent
# 1: Binding Site
# 2: Transmembrane region

# Pad the sequences and labels
for i in range(len(dataset['sequence'])):
    while len(dataset['sequence'][i]) < 80:
        dataset['sequence'][i] += 'X' # padding token for sequence
        dataset['labels'][i].append(-100) # padding token for labels

# Create the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
# Create the model and specify the number of labels
model = EsmForTokenClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=3)


# Convert sequences to input IDs and labels to tensors
inputs = [tokenizer(seq, truncation=True, padding='max_length', max_length=50)["input_ids"] for seq in dataset["sequence"]]
labels = [label[:50] + [-100]*(50-len(label)) for label in dataset["labels"]] # Truncate/pad labels to match input IDs


# Create ProteinData objects and then ProteinDataset
data = [ProteinData(input_id, label) for input_id, label in zip(inputs, labels)]
protein_dataset = ProteinDataset(data)

# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,   # increase the number of epochs
    per_device_train_batch_size=32,  # increase if you have enough memory
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
)


# Define the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=protein_dataset,
)

# Train the model
trainer.train()

# Save the model
model.save_pretrained("trained_token_classifier")


Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/10 [00:00<?, ?it/s]

{'train_runtime': 14.9385, 'train_samples_per_second': 14.727, 'train_steps_per_second': 0.669, 'train_loss': 1.098867130279541, 'epoch': 10.0}


In [12]:
# Let's say this is your new sequence
new_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"

# Convert sequence to input IDs
inputs = tokenizer(new_sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt")["input_ids"]

# Apply model to get the logits
with torch.no_grad():
    outputs = model(inputs)

# Get the predictions by picking the label (class) with the highest logit
predictions = torch.argmax(outputs.logits, dim=-1)

# print(predictions)

# Here we map each label number to its corresponding class
class_mapping = {
    0: 'Exposed to Solvent',
    1: 'Binding Site',
    2: 'Transmembrane region',
}

# Then, we convert the tensor to a list of integers
prediction_list = predictions.tolist()[0]

# Now, we get a list of class labels using the mapping
class_labels = [class_mapping[pred] for pred in prediction_list]

# Create a list that matches each amino acid in the sequence to its predicted class label
residue_to_label = list(zip(list(new_sequence), class_labels))

# Print out the list
for i, (residue, label) in enumerate(residue_to_label):
    print(f"Position {i+1} - {residue}: {label}")


Position 1 - M: Binding Site
Position 2 - A: Binding Site
Position 3 - P: Binding Site
Position 4 - L: Binding Site
Position 5 - R: Binding Site
Position 6 - K: Binding Site
Position 7 - T: Binding Site
Position 8 - Y: Exposed to Solvent
Position 9 - V: Exposed to Solvent
Position 10 - L: Binding Site
Position 11 - K: Exposed to Solvent
Position 12 - L: Exposed to Solvent
Position 13 - Y: Exposed to Solvent
Position 14 - V: Exposed to Solvent
Position 15 - A: Binding Site
Position 16 - G: Binding Site
Position 17 - N: Binding Site
Position 18 - T: Exposed to Solvent
Position 19 - P: Exposed to Solvent
Position 20 - N: Exposed to Solvent
Position 21 - S: Exposed to Solvent
Position 22 - V: Exposed to Solvent
Position 23 - R: Exposed to Solvent
Position 24 - A: Exposed to Solvent
Position 25 - L: Binding Site
Position 26 - K: Exposed to Solvent
Position 27 - T: Exposed to Solvent
Position 28 - L: Exposed to Solvent
Position 29 - N: Exposed to Solvent
Position 30 - N: Exposed to Solvent
P