If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers as well as some other libraries. Uncomment the following cells and run them.

In [87]:
# ! pip install transformers evaluate datasets requests pandas scikit-learn

In [88]:
# ! pip install -U accelerate
# ! pip install -U transformers

Derived from https://github.com/huggingface/notebooks/blob/main/examples/protein_language_modeling.ipynb

# Fine-Tuning Protein Language Models

In this notebook, we're going to do some transfer learning to fine-tune some large, pre-trained protein language models on tasks of interest. If that sentence feels a bit intimidating to you, don't panic - there's [a blog post](https://huggingface.co/blog/deep-learning-with-proteins) that explains the concepts here in much more detail.

The specific model we're going to use is ESM-2, which is the state-of-the-art protein language model at the time of writing (November 2022). The citation for this model is [Lin et al, 2022](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).

There are several ESM-2 checkpoints with differing model sizes. Larger models will generally have better accuracy, but they require more GPU memory and will take much longer to train. The available ESM-2 checkpoints (at time of writing) are:

| Checkpoint name | Num layers | Num parameters |
|------------------------------|----|----------|
| `esm2_t48_15B_UR50D`         | 48 | 15B     |
| `esm2_t36_3B_UR50D`          | 36 | 3B      | 
| `esm2_t33_650M_UR50D`        | 33 | 650M    | 
| `esm2_t30_150M_UR50D`        | 30 | 150M    | 
| `esm2_t12_35M_UR50D`         | 12 | 35M     | 
| `esm2_t6_8M_UR50D`           | 6  | 8M      | 

Note that the larger checkpoints may be very difficult to train without a large cloud GPU like an A100 or H100, and the largest 15B parameter checkpoint will probably be impossible to train on **any** single GPU! Also, note that memory usage for attention during training will scale as `O(batch_size * num_layers * seq_len^2)`, so larger models on long sequences will use quite a lot of memory! We will use the `esm2_t12_35M_UR50D` checkpoint for this notebook, which should train on any Colab instance or modern GPU.

In [89]:
model_checkpoint = "facebook/esm2_t12_35M_UR50D"
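To get a feel for the attention-memory scaling rule quoted above, here is a rough sketch that compares two configurations using only the proportionality `batch_size * num_layers * seq_len^2`. The function name and the two example configurations are illustrative assumptions. The result is a relative factor, not a byte count, since constants such as the number of attention heads and the numeric precision are left out.

```python
def relative_attention_cost(batch_size: int, num_layers: int, seq_len: int) -> int:
    """Unitless cost proportional to batch_size * num_layers * seq_len^2."""
    return batch_size * num_layers * seq_len ** 2


# Hypothetical comparison: the 12-layer checkpoint on 500-residue sequences
# versus the 33-layer checkpoint on 1000-residue sequences, same batch size.
small = relative_attention_cost(batch_size=8, num_layers=12, seq_len=500)
large = relative_attention_cost(batch_size=8, num_layers=33, seq_len=1000)

print(f"Relative attention memory: {large / small:.1f}x")
```

Doubling the sequence length alone quadruples this term, which is why long proteins become expensive much faster than extra layers do.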

# Sequence classification

One of the most common tasks you can perform with a language model is **sequence classification**. In sequence classification, we classify an entire protein into a category, from a list of two or more possibilities. There's no limit on the number of categories you can use, or the specific problem you choose, as long as it's something the model could in theory infer from the raw protein sequence. To keep things simple for this example, though, let's try classifying proteins by their cellular localization - given their sequence, can we predict if they're going to be found in the cytosol (the fluid inside the cell) or embedded in the cell membrane?

## Data preparation

In this section, we're going to gather some training data from UniProt. Our goal is to create a pair of lists: `sequences` and `labels`. `sequences` will be a list of protein sequences, which will just be strings like "MNKL...", where each letter represents a single amino acid in the complete protein. `labels` will be a list of the category for each sequence. The categories will just be integers, with 0 representing the first category, 1 representing the second and so on. In other words, if `sequences[i]` is a protein sequence then `labels[i]` should be its corresponding category. These will form the **training data** we're going to use to teach the model the task we want it to do.
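As a minimal sketch of the target format, the two lists could look like this. The sequences below are made-up toy strings, not real UniProt entries, and the `toy_` names are placeholders so they do not clash with the real `sequences` and `labels` built later.

```python
# Toy data only: these short strings are made up for illustration and are
# not real UniProt entries.
toy_sequences = ["MNKLAV", "MSTPLLW", "MKKAI"]
toy_labels = [0, 1, 0]  # 0 = first category, 1 = second category

# The two lists must stay aligned: toy_labels[i] belongs to toy_sequences[i].
assert len(toy_sequences) == len(toy_labels)

for seq, label in zip(toy_sequences, toy_labels):
    print(f"{seq} -> category {label}")
```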

If you're adapting this notebook for your own use, this will probably be the main section you want to change! You can do whatever you want here, as long as you create those two lists by the end of it. If you want to follow along with this example, though, first we'll need to `import requests` and set up our query to UniProt.

In [90]:
import requests

query_url ="https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Ccc_subcellular_location&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"

This query URL might seem mysterious, but it isn't! To get it, we searched for `(organism_id:9606) AND (reviewed:true) AND (length:[80 TO 500])` on UniProt to get a list of reasonably-sized human proteins,
then selected 'Download', and set the format to TSV and the columns to `Sequence` and `Subcellular location [CC]`, since those contain the data we care about for this task.

Once that's done, selecting `Generate URL for API` gives you a URL you can pass to Requests. Alternatively, if you're not on Colab you can just download the data through the web interface and open the file locally.
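If you would rather build the URL in code than copy it from the web interface, the sketch below assembles the same kind of query string with the standard library. The variable names are illustrative, and the parameter values are simply read off the URL shown above. With `quote_via=quote`, the result should match that URL character for character.

```python
from urllib.parse import quote, urlencode

base_url = "https://rest.uniprot.org/uniprotkb/stream"
params = {
    "compressed": "true",
    "fields": "accession,sequence,cc_subcellular_location",
    "format": "tsv",
    "query": "((organism_id:9606) AND (reviewed:true) AND (length:[80 TO 500]))",
}

# quote_via=quote encodes spaces as %20 (the default, quote_plus, would use "+").
rebuilt_url = f"{base_url}?{urlencode(params, quote_via=quote)}"
print(rebuilt_url)
```

Keeping the query in readable form like this makes it easier to change the organism, the length range, or the returned fields when adapting the notebook.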

In [91]:
uniprot_request = requests.get(query_url)

To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file. If you downloaded the data as a file you can skip this bit and just pass the filepath directly to `read_csv`.
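To see what the `BytesIO` wrapper is doing without touching the network, here is a small sketch that uses only the standard library in place of Pandas. The two-row table is made up and stands in for the gzip-compressed response body. `gzip` and `csv` play the role that `read_csv(..., compression='gzip', sep='\t')` plays in the notebook.

```python
import csv
import gzip
from io import BytesIO

# Made-up two-row table standing in for the UniProt response body.
raw_tsv = "Entry\tSequence\nP00001\tMNKL\nP00002\tMSTP\n"
compressed_bytes = gzip.compress(raw_tsv.encode("utf-8"))

# BytesIO gives the in-memory bytes a file-like interface.
buffer = BytesIO(compressed_bytes)
with gzip.open(buffer, mode="rt", encoding="utf-8", newline="") as handle:
    rows = list(csv.DictReader(handle, delimiter="\t"))

print(rows[0]["Entry"], rows[0]["Sequence"])
```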

In [92]:
from io import BytesIO
import pandas

bio = BytesIO(uniprot_request.content)

df = pandas.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1,A0AVI4,MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...,SUBCELLULAR LOCATION: Endoplasmic reticulum me...
2,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,SUBCELLULAR LOCATION: Nucleus {ECO:0000305}.
3,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
4,A0PJY2,MDSSCHNATTKMLATAPARGNMMSTSKPLAFSIERIMARTPEPKAL...,SUBCELLULAR LOCATION: Nucleus {ECO:0000269|Pub...
...,...,...,...
11975,Q9H8V8,MKPDWPRRGAAGTRVRSRGEGDGTYFARRGAGRRRREIKAPIRAAW...,
11976,Q9HAA7,MLFGIRILVNTPSPLVTGLHHYNPSIHRDQGECANQWRKGPGSAHL...,
11977,Q9NZ38,MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...,
11978,Q9UFV3,MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...,


Nice! Now we have some proteins and their subcellular locations. Let's start filtering this down. First, let's drop the rows without subcellular location information.

In [93]:
df = df.dropna()  # Drop proteins with missing columns

Now we'll make one dataframe of proteins whose subcellular location column contains `Cytosol` or `Cytoplasm`, and a second of proteins whose column contains `Membrane` or `Cell membrane`. The matching is case-sensitive. To avoid overlap, each dataframe only keeps proteins that don't match the other group's search terms.

In [94]:
cytosolic = df['Subcellular location [CC]'].str.contains("Cytosol") | df['Subcellular location [CC]'].str.contains("Cytoplasm")
membrane = df['Subcellular location [CC]'].str.contains("Membrane") | df['Subcellular location [CC]'].str.contains("Cell membrane")
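Because `str.contains` is case-sensitive by default, it is worth checking which annotations the search terms actually catch. The sketch below mimics the two masks with plain Python substring tests on a few shortened, made-up annotation strings. The helper names are hypothetical, and for literal patterns like these the `in` operator behaves the same way as `str.contains`.

```python
# Shortened, made-up annotation strings in the style of the UniProt column.
annotations = [
    "SUBCELLULAR LOCATION: Cytoplasm.",
    "SUBCELLULAR LOCATION: Cell membrane.",
    "SUBCELLULAR LOCATION: Endoplasmic reticulum membrane.",
    "SUBCELLULAR LOCATION: Cytoplasm. Membrane.",
]


def is_cytosolic(text: str) -> bool:
    return "Cytosol" in text or "Cytoplasm" in text


def is_membrane(text: str) -> bool:
    # Capitalised "Membrane" does not match "... reticulum membrane".
    return "Membrane" in text or "Cell membrane" in text


def assign_group(text: str) -> str:
    cyto, memb = is_cytosolic(text), is_membrane(text)
    if cyto and not memb:
        return "cytosolic"
    if memb and not cyto:
        return "membrane"
    return "excluded"  # matches both groups or neither


groups = [assign_group(text) for text in annotations]
for group, text in zip(groups, annotations):
    print(f"{group:10s} {text}")
```

Note how the endoplasmic reticulum entry falls into neither group with these terms, and the entry that mentions both locations is excluded, which is the intended way of keeping the two classes disjoint.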

In [95]:
cytosolic_df = df[cytosolic & ~membrane]
cytosolic_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
9,A1E959,MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL...,SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un...
14,A1XBS5,MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
18,A2RU49,MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQ...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
20,A2RUH7,MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQLLPP...,"SUBCELLULAR LOCATION: Cytoplasm, myofibril, sa..."
21,A4D126,MEAGPPGSARPAEPGPCLSGQRGADHTASASLQSVAGTEPGRHPQA...,"SUBCELLULAR LOCATION: Cytoplasm, cytosol {ECO:..."
...,...,...,...
11563,Q8WWF8,MAGTARHDREMAIQAKKKLTTATDPIERLRLQCLARGSAGIKGLGR...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
11685,Q9NUJ7,MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
11693,Q9P2W6,MGRTWCGMWRRRRPGRRSAVPRWPHLSSQSGVEPPDRWTGTPGWPS...,SUBCELLULAR LOCATION: Cytoplasm.
11709,X6R8D5,MCKDSQKPSVPSHGPKTPSCKGVKAPHSSRPRAWKQDLEQSLAAAY...,"SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton,..."


In [96]:
membrane_df = df[membrane & ~cytosolic]
membrane_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
3,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
17,A2RU14,MAGTVLGVGAGVFILALLWVAVLLLCVLLSRASGAARFSVIFLFFG...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
33,A5X5Y0,MEGSWFHRKRFSFYLLLGFLLQGRGVTFTINCSGFGQHGADPTALN...,SUBCELLULAR LOCATION: Postsynaptic cell membra...
36,A6ND01,MACWWPLLLELWTVMPTWAGDELLNICMNAKHHKRVPSPEDKLYEE...,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...
...,...,...,...
11901,Q86UQ5,MQSDIYHPGHSFPSWVLCWVHSCGHEGHLRETAEIRKTHQNGDLQI...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11924,Q8N8V8,MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11960,Q96N68,MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11968,Q9H0A3,MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT...,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...


We're almost done! Now, let's make a list of sequences from each df and generate the associated labels. We'll use `0` as the label for cytosolic proteins and `1` as the label for membrane proteins.

In [97]:
cytosolic_sequences = cytosolic_df["Sequence"].tolist()
cytosolic_labels = [0 for protein in cytosolic_sequences]

In [98]:
membrane_sequences = membrane_df["Sequence"].tolist()
membrane_labels = [1 for protein in membrane_sequences]

Now we can concatenate these lists together to get the `sequences` and `labels` lists that will form our final training data. Don't worry - they'll get shuffled during training!

In [112]:
sequences = cytosolic_sequences + membrane_sequences
labels = cytosolic_labels + membrane_labels

# Quick check to make sure we got it right
len(sequences) == len(labels)

True

In [113]:
sequences

['MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQLQGPLNSWIPPFSGILQQQQQAQIPGLSQFSLSALDQFAGLLPNQIPLTGEASFAQGAQAGQVDPLQLQTPPQTQPGPSHVMPYVFSFKMPQEQGQMFQYYPVYMVLPWEQPQQTVPRSPQQTRQQQYEEQIPFYAQFGYIPQLAEPAISGGQQQLAFDPQLGTAPEIAVMSTGEEIPYLQKEAINFRHDSAGVFMPSTSPKPSTTNVFTSAVDQTITPELPEEKDKTDSLREP',
 'MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRDKADLLVNEINAYAATETPHLKLGLMNFADEFAKLQDYRQAEVERLEAKVVEPLKTYGTIVKMKRDDLKATLTARNREAKQLTQLERTRQRNPSDRHVISQAETELQRAAMDASRTSRHLEETINNFERQKMKDIKTIFSEFITIEMLFHGKALEVYTAAYQNIQNIDEDEDLEVFRNSLYAPDYSSRLDIVRANSKSPLQRSLSAKCVSGTGQVSTCRLRKDQQAEDDEDDELDVTEEENFLK',
 'MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQNFHVYVSKTKDGPTEYVLKISNTKASKNPDLIEVQNHIIMFLKAAGFPTASVCHTKGDNTASLVSVDSGSEIKSYLVRLLTYLPGRPIAELPVSPQLLYEIGKLAAKLDKTLQRFHHPKLSSLHRENFIWNLKNVPLLEKYLYALGQNRNREIVEHVIHLFKEEVMTKLSHFRECINHGDLNDHNILIESSKSASGNAEYQVSGILDFGDMSYGYYVFEVAITIMYMMIESKSPIQVGGHVLAGFESITPLTAVEKGALFLLVCSRFCQSLVMAAYSCQLYPENKDYLMVTAKTGWKHLQQMFDMGQKAVEEIWFETAKSYESGISM',
 'MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQ

In [114]:
labels

[0,
 0,
 0,
 ...


# Logistic regression with one hot encoding average

In [102]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
import numpy as np

def one_hot_encoding_average(input_sequence_str):
    # Split the input string into individual amino acids
    amino_acids = list(input_sequence_str)

    # one-hot encode the amino acids
    amino_acids_alphabet = "A C D E F G H I K L M N P Q R S T V W Y U".split(" ")
    one_hot_encoding_dict = {
        aa: np.array([1 if aa == amino_acid else 0 for amino_acid in amino_acids_alphabet])
        for aa in amino_acids_alphabet
    }

    # Encode the sequence 
    output_one_hot_encoder = np.array([0 for i in range(len(amino_acids_alphabet))])
    for i in amino_acids:
        one_hot_encode = one_hot_encoding_dict[i]
        output_one_hot_encoder += one_hot_encode 
    
    return output_one_hot_encoder / len(amino_acids)


one_hot_encoding_avg_sequences = [one_hot_encoding_average(seq) for seq in sequences]

# Split the one-hot encoded data into train, validation, and test sets
X_train_oh, X_test_oh, y_train_oh, y_test_oh = train_test_split(
    one_hot_encoding_avg_sequences, labels, test_size=0.2, random_state=42, stratify=labels
)
X_train_oh, X_val_oh, y_train_oh, y_val_oh = train_test_split(
    X_train_oh, y_train_oh, test_size=0.25, random_state=42, stratify=y_train_oh
)

print(f"Logistic Regression - Train size: {len(X_train_oh)}")
print(f"Logistic Regression - Validation size: {len(X_val_oh)}")
print(f"Logistic Regression - Test size: {len(X_test_oh)}")

# Train logistic regression
lr_model = LogisticRegression(random_state=42, max_iter=1000)
lr_model.fit(X_train_oh, y_train_oh)

# Predictions
y_train_pred = lr_model.predict(X_train_oh)
y_val_pred = lr_model.predict(X_val_oh)
y_test_pred = lr_model.predict(X_test_oh)

# Calculate accuracies
train_accuracy_lr = accuracy_score(y_train_oh, y_train_pred)
val_accuracy_lr = accuracy_score(y_val_oh, y_val_pred)
test_accuracy_lr = accuracy_score(y_test_oh, y_test_pred)

print(f"\nLogistic Regression Results:")
print(f"Train Accuracy: {train_accuracy_lr:.4f}")
print(f"Validation Accuracy: {val_accuracy_lr:.4f}")
print(f"Test Accuracy: {test_accuracy_lr:.4f}")

print(f"\nClassification Report (Test Set):")
print(classification_report(y_test_oh, y_test_pred))

Logistic Regression - Train size: 3116
Logistic Regression - Validation size: 1039
Logistic Regression - Test size: 1039

Logistic Regression Results:
Train Accuracy: 0.8338
Validation Accuracy: 0.8383
Test Accuracy: 0.8335

Classification Report (Test Set):
              precision    recall  f1-score   support

           0       0.80      0.89      0.85       532
           1       0.87      0.77      0.82       507

    accuracy                           0.83      1039
   macro avg       0.84      0.83      0.83      1039
weighted avg       0.84      0.83      0.83      1039
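The "one-hot encoding average" features used above amount to the amino acid composition of each protein: averaging one-hot vectors over positions gives the fraction of each letter. As a hedged sketch, the same numbers can be obtained with `collections.Counter`. The `composition` helper is a hypothetical name, and letters outside the alphabet are simply not counted in this version.

```python
from collections import Counter

ALPHABET = "ACDEFGHIKLMNPQRSTVWYU"  # same 21 letters, same order, as the cell above


def composition(sequence: str) -> list[float]:
    """Fraction of each alphabet letter in the sequence."""
    counts = Counter(sequence)
    return [counts[aa] / len(sequence) for aa in ALPHABET]


features = composition("MKKLLA")  # made-up toy sequence
print(dict(zip(ALPHABET, features)))
```

Seen this way, the logistic regression is learning from residue frequencies alone, with no information about residue order.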



# SVM

In [103]:
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report

# Train an SVM with an RBF kernel and otherwise default hyperparameters
svm_model = SVC(kernel='rbf', random_state=42)
svm_model.fit(X_train_oh, y_train_oh)

# Predictions
y_train_pred_svm = svm_model.predict(X_train_oh)
y_val_pred_svm = svm_model.predict(X_val_oh)
y_test_pred_svm = svm_model.predict(X_test_oh)

# Calculate accuracies
train_accuracy_svm = accuracy_score(y_train_oh, y_train_pred_svm)
val_accuracy_svm = accuracy_score(y_val_oh, y_val_pred_svm)
test_accuracy_svm = accuracy_score(y_test_oh, y_test_pred_svm)

print(f"SVM Results:")
print(f"Train Accuracy: {train_accuracy_svm:.4f}")
print(f"Validation Accuracy: {val_accuracy_svm:.4f}")
print(f"Test Accuracy: {test_accuracy_svm:.4f}")

print(f"\nSVM Classification Report (Test Set):")
print(classification_report(y_test_oh, y_test_pred_svm))

SVM Results:
Train Accuracy: 0.8752
Validation Accuracy: 0.8576
Test Accuracy: 0.8585

SVM Classification Report (Test Set):
              precision    recall  f1-score   support

           0       0.83      0.90      0.87       532
           1       0.89      0.81      0.85       507

    accuracy                           0.86      1039
   macro avg       0.86      0.86      0.86      1039
weighted avg       0.86      0.86      0.86      1039
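To put these accuracies in context, it helps to know what a trivial classifier would score. The sketch below computes the majority-class baseline from the test-set support counts printed in the classification reports above (532 and 507). The variable names are illustrative.

```python
# Support counts taken from the classification reports above.
support = {0: 532, 1: 507}

total = sum(support.values())
majority_class = max(support, key=support.get)
baseline_accuracy = support[majority_class] / total

print(f"Always predicting class {majority_class}: {baseline_accuracy:.4f}")
```

Since the classes are nearly balanced, always guessing the larger class gets only about 51%, so both composition-based models are well above chance.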



# Simple RNN (LSTM)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import os

import torch.nn as nn
import torch.optim as optim

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Create amino acid vocabulary
amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
aa_to_idx = {aa: idx for idx, aa in enumerate(amino_acids)}
aa_to_idx['<PAD>'] = len(amino_acids)  # padding token
vocab_size = len(aa_to_idx)

# Define Dataset class
class ProteinDataset(Dataset):
    def __init__(self, sequences, labels, max_length=500):
        self.sequences = sequences
        self.labels = labels
        self.max_length = max_length
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]
        
        # Convert sequence to indices
        seq_indices = [aa_to_idx.get(aa, aa_to_idx['<PAD>']) for aa in sequence[:self.max_length]]
        
        # Pad sequence
        if len(seq_indices) < self.max_length:
            seq_indices += [aa_to_idx['<PAD>']] * (self.max_length - len(seq_indices))
        
        return torch.tensor(seq_indices, dtype=torch.long), torch.tensor(label, dtype=torch.long)

# Define RNN model
class ProteinRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=128, num_layers=2, num_classes=2, dropout=0.3):
        super(ProteinRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=aa_to_idx['<PAD>'])
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, num_classes)
        
    def forward(self, x):
        embedded = self.embedding(x)
        rnn_out, (hidden, _) = self.rnn(embedded)
        # Use the last hidden state
        output = self.dropout(hidden[-1])
        output = self.fc(output)
        return output

# Split data: 60% train, 20% validation, 20% test
X_temp, X_test, y_temp, y_test = train_test_split(sequences, labels, test_size=0.2, random_state=42, stratify=labels)
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.25, random_state=42, stratify=y_temp)

print(f"Train size: {len(X_train)}")
print(f"Validation size: {len(X_val)}")
print(f"Test size: {len(X_test)}")

# Create datasets and dataloaders
train_dataset = ProteinDataset(X_train, y_train)
val_dataset = ProteinDataset(X_val, y_val)
test_dataset = ProteinDataset(X_test, y_test)

batch_size = 100
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize model, loss function, and optimizer
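# Prefer CUDA when available, then Apple MPS, then fall back to CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')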
model = ProteinRNN(vocab_size=vocab_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print(f"Using device: {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

# Training loop
num_epochs = 150
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

def calculate_accuracy(loader, model, device):
    """Return the accuracy (in percent) of `model` over all batches in `loader`."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for sequence, label in loader:
            sequence, label = sequence.to(device), label.to(device)
            outputs = model(sequence)
            predicted = outputs.argmax(dim=1)  # class with the highest logit
            total += label.size(0)
            correct += (predicted == label).sum().item()
    return 100 * correct / total

print("Starting training...")
for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_loss = 0.0
    for batch_idx, (sequence, label) in enumerate(train_loader):
        sequence, label = sequence.to(device), label.to(device)

        optimizer.zero_grad()
        outputs = model(sequence)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    # Calculate average training loss
    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for sequence, label in val_loader:
            sequence, label = sequence.to(device), label.to(device)
            outputs = model(sequence)
            loss = criterion(outputs, label)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    
    # Calculate accuracies
    train_acc = calculate_accuracy(train_loader, model, device)
    val_acc = calculate_accuracy(val_loader, model, device)
    
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)
    
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    print('-' * 50)

# Final test accuracy
test_acc = calculate_accuracy(test_loader, model, device)
print(f'\nFinal Test Accuracy: {test_acc:.2f}%')

# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)

# Save the trained model
torch.save(model.state_dict(), 'models/DeepLog_RNN.pth')
print("Model saved to models/DeepLog_RNN.pth")

# Optionally, save additional training information
training_info = {
    'vocab_size': vocab_size,
    'embedding_dim': 64,
    'hidden_dim': 128,
    'num_layers': 2,
    'num_classes': 2,
    'dropout': 0.3,
    'aa_to_idx': aa_to_idx,
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_accuracies': train_accuracies,
    'val_accuracies': val_accuracies,
    'final_test_accuracy': test_acc
}

torch.save(training_info, 'models/DeepLog_RNN_info.pth')
print("Training information saved to models/DeepLog_RNN_info.pth")

Train size: 3116
Validation size: 1039
Test size: 1039
Using device: mps
Model parameters: 233026
Starting training...
Epoch [1/150]
  Train Loss: 0.6947, Train Acc: 50.26%
  Val Loss: 0.6928, Val Acc: 49.66%
--------------------------------------------------
Epoch [2/150]
  Train Loss: 0.6938, Train Acc: 49.13%
  Val Loss: 0.6933, Val Acc: 48.99%
--------------------------------------------------
Epoch [3/150]
  Train Loss: 0.6935, Train Acc: 51.28%
  Val Loss: 0.6927, Val Acc: 51.59%
--------------------------------------------------
Epoch [4/150]
  Train Loss: 0.6935, Train Acc: 51.25%
  Val Loss: 0.6927, Val Acc: 51.30%
--------------------------------------------------
Epoch [5/150]
  Train Loss: 0.6933, Train Acc: 50.77%
  Val Loss: 0.6943, Val Acc: 51.01%
--------------------------------------------------
Epoch [6/150]
  Train Loss: 0.6918, Train Acc: 51.35%
  Val Loss: 0.6932, Val Acc: 51.59%
--------------------------------------------------
Epoch [7/150]
  Train Loss: 0.6912,

In [107]:
# Final test accuracy
test_acc = calculate_accuracy(test_loader, model, device)
print(f'\nFinal Test Accuracy: {test_acc:.2f}%')


Final Test Accuracy: 91.92%
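One detail of the LSTM above is that every sequence is right-padded to 500 positions, so `hidden[-1]` is the state after the trailing padding tokens rather than after the last real residue. A common alternative, sketched here under stated assumptions, is to pack the batch with `pack_padded_sequence` so that the LSTM stops at each sequence's true length. This is not the approach used in the cell above. The class name `PackedLSTMClassifier` is hypothetical, and the sketch assumes the padding index appears only at the end of a sequence (in the dataset above, unknown residues are also mapped to the padding index, so they would need their own index for the lengths to be exact).

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class PackedLSTMClassifier(nn.Module):
    """Hypothetical variant that ignores trailing padding via packing."""

    def __init__(self, vocab_size, pad_idx, embedding_dim=64, hidden_dim=128, num_classes=2):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # True lengths = number of non-padding tokens (assumes padding only at the end).
        lengths = (x != self.pad_idx).sum(dim=1).clamp(min=1).cpu()
        packed = pack_padded_sequence(
            self.embedding(x), lengths, batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)
        # hidden[-1] is now the state after the last real residue of each sequence.
        return self.fc(hidden[-1])


pad_idx = 20
packed_model = PackedLSTMClassifier(vocab_size=21, pad_idx=pad_idx).eval()
short = torch.tensor([[1, 2, 3]])
padded = torch.tensor([[1, 2, 3, pad_idx, pad_idx]])
with torch.no_grad():
    padding_invariant = torch.allclose(packed_model(short), packed_model(padded), atol=1e-6)
print(padding_invariant)
```

The final check compares the logits for the same toy sequence with and without extra padding, which is the property that packing is meant to provide.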


## Splitting the data

Since the data we're loading isn't prepared for us as a machine learning dataset, we'll have to split the data into train and test sets ourselves! We can use sklearn's `train_test_split` function for that:

In [115]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

## Tokenizing the data

All inputs to neural nets must be numerical. The process of converting strings into numerical indices suitable for a neural net is called **tokenization**. For natural language this can be quite complex, as usually the network's vocabulary will not contain every possible word, which means the tokenizer must handle splitting rarer words into pieces, as well as all the complexities of capitalization and unicode characters and so on.

With proteins, however, things are very easy. In protein language models, each amino acid is converted to a single token. Every model in `transformers` comes with an associated `tokenizer` that handles tokenization for it, and protein language models are no different. Let's get our tokenizer!

In [116]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Let's try a single sequence to see what the outputs from our tokenizer look like:

In [117]:
tokenizer(train_sequences[0])

{'input_ids': [0, 20, 9, 14, 6, 17, 13, 11, 16, 12, 8, 9, 18, 4, 4, 4, 6, 18, 8, 16, 9, 14, 6, 4, 16, 14, 18, 4, 18, 6, 4, 18, 4, 8, 20, 19, 4, 7, 11, 7, 4, 6, 17, 4, 4, 12, 12, 4, 5, 11, 12, 8, 13, 8, 21, 4, 21, 11, 14, 20, 19, 18, 18, 4, 8, 17, 4, 8, 18, 5, 13, 12, 23, 7, 11, 8, 11, 11, 12, 14, 15, 20, 4, 20, 17, 12, 16, 11, 16, 17, 15, 7, 12, 11, 19, 12, 5, 23, 4, 20, 16, 20, 19, 18, 18, 12, 4, 18, 5, 6, 18, 9, 17, 18, 4, 4, 8, 7, 20, 5, 19, 13, 10, 18, 7, 5, 12, 23, 21, 14, 4, 21, 19, 20, 7, 12, 20, 17, 14, 21, 4, 23, 6, 4, 4, 7, 4, 5, 8, 22, 11, 20, 8, 5, 4, 19, 8, 4, 4, 16, 12, 4, 20, 7, 7, 10, 4, 8, 18, 23, 11, 5, 4, 9, 12, 14, 21, 18, 18, 23, 9, 4, 17, 16, 7, 12, 16, 4, 5, 23, 8, 13, 8, 18, 4, 17, 21, 20, 7, 12, 19, 18, 11, 7, 5, 4, 4, 6, 6, 6, 14, 4, 11, 6, 12, 4, 19, 8, 19, 8, 15, 12, 12, 8, 8, 12, 21, 5, 12, 8, 8, 5, 16, 6, 15, 19, 15, 5, 18, 8, 11, 23, 5, 8, 21, 4, 8, 7, 7, 8, 4, 18, 19, 6, 5, 12, 4, 6, 7, 19, 4, 8, 8, 5, 5, 11, 10, 17, 8, 21, 8, 8, 5, 11, 5, 8, 7, 20, 19, 

This looks good! We can see that our sequence has been converted into `input_ids`, which is the tokenized sequence, and an `attention_mask`. The attention mask handles the case when we have sequences of variable length - in those cases, the shorter sequences are padded with blank "padding" tokens, and the attention mask is padded with 0s to indicate that those tokens should be ignored by the model.
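To make the padding and attention-mask behaviour concrete, here is a toy tokenizer written from scratch. Everything in it is an assumption made for illustration: the vocabulary ids are not the real ESM-2 ids, and `toy_tokenize` is a hypothetical helper, not part of `transformers`. It maps one residue to one token, wraps each sequence in start and end tokens, and pads to the longest sequence in the batch.

```python
# Hypothetical toy vocabulary: the ids are NOT the real ESM-2 ids.
toy_vocab = {"<cls>": 0, "<pad>": 1, "<eos>": 2}
for residue in "ACDEFGHIKLMNPQRSTVWY":
    toy_vocab[residue] = len(toy_vocab)


def toy_tokenize(sequences: list[str]) -> dict[str, list[list[int]]]:
    """One token per residue, wrapped in <cls>/<eos>, padded to the longest sequence."""
    encoded = [
        [toy_vocab["<cls>"]] + [toy_vocab[aa] for aa in seq] + [toy_vocab["<eos>"]]
        for seq in sequences
    ]
    max_len = max(len(ids) for ids in encoded)
    input_ids, attention_mask = [], []
    for ids in encoded:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [toy_vocab["<pad>"]] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 0 marks padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = toy_tokenize(["MKL", "MKLAV"])
print(batch["input_ids"])
print(batch["attention_mask"])
```

The shorter sequence ends up with two padding tokens and two zeros at the end of its mask, while the longer one has a mask of all ones.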

So now, let's tokenize our whole dataset. Note that we don't need to do anything with the labels, as they're already in the format we need.

In [118]:
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

## Dataset creation

Now we want to turn this data into a dataset that PyTorch can load samples from. We can use the `Dataset` class from the Hugging Face `datasets` library for this, although if you prefer you can also use `torch.utils.data.Dataset`, at the cost of some more boilerplate code.

In [125]:
from datasets import Dataset
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 3895
})

This looks good, but we're missing our labels! Let's add those on as an extra column to the datasets.

In [126]:
train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 3895
})

Looks good! We're ready for training.

## Model loading

Next, we want to load our model. Make sure to use exactly the same model as you used when loading the tokenizer, or your model might not understand the tokenization scheme you're using!

In [127]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

num_labels = max(train_labels + test_labels) + 1  # Add 1 since 0 can be a label
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


This warning is telling us that the weights for sequence classification (the `classifier`) are newly initialized, because the checkpoint was trained for language modelling and its language-modelling weights (the `lm_head`) are not used here. This is exactly what we expect when we want to fine-tune a language model on a sequence classification task!

Next, we initialize our `TrainingArguments`. These control the various training hyperparameters, and will be passed to our `Trainer`.

In [128]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 8

args = TrainingArguments(
    f"{model_name}-finetuned-localization",
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

Next, we define the metric we will use to evaluate our models and write a `compute_metrics` function. We can load this from the `evaluate` library.

In [129]:
from evaluate import load
import numpy as np

metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)
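In plain Python, what `compute_metrics` does can be sketched as follows: take the index of the largest logit in each row, then count how often it matches the reference label. The logits and labels below are made up, and the `toy_` names are placeholders.

```python
# Made-up logits for four proteins and two classes.
toy_logits = [[2.1, -0.3], [0.2, 1.5], [-1.0, 0.4], [0.9, 0.1]]
toy_references = [0, 1, 0, 0]

# argmax over the class axis: index of the largest logit in each row.
toy_predictions = [row.index(max(row)) for row in toy_logits]
toy_accuracy = sum(
    pred == ref for pred, ref in zip(toy_predictions, toy_references)
) / len(toy_references)

print(toy_predictions, toy_accuracy)
```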

And at last we're ready to initialize our `Trainer`:

In [130]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)


You might wonder why we pass along the `tokenizer` when we already preprocessed our data. This is because we will use it one last time to make all the samples we gather the same length by applying padding, which requires knowing the model's preferences regarding padding (to the left or right? with which token?). The `tokenizer` has a pad method that will do all of this right for us, and the `Trainer` will use it. You can customize this part by defining and passing your own `data_collator` which will receive samples like the dictionaries seen above and will need to return a dictionary of tensors.
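As a rough idea of what a hand-written `data_collator` could look like, the sketch below pads a list of sample dictionaries to a common length and returns a dictionary of tensors. The function name `pad_collate` is hypothetical, the default `pad_token_id=1` is a placeholder assumption (in practice it would come from `tokenizer.pad_token_id`), and right-padding is assumed.

```python
import torch


def pad_collate(samples, pad_token_id=1):
    """Pad a list of {'input_ids', 'attention_mask', 'labels'} dicts to a common length.

    pad_token_id=1 is a placeholder assumption; in practice read it from
    tokenizer.pad_token_id.
    """
    max_len = max(len(s["input_ids"]) for s in samples)
    input_ids, attention_mask, labels = [], [], []
    for s in samples:
        n_pad = max_len - len(s["input_ids"])
        input_ids.append(s["input_ids"] + [pad_token_id] * n_pad)  # right padding (assumed)
        attention_mask.append(s["attention_mask"] + [0] * n_pad)
        labels.append(s["labels"])
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
    }


# Made-up samples of different lengths.
toy_samples = [
    {"input_ids": [0, 5, 6, 2], "attention_mask": [1, 1, 1, 1], "labels": 0},
    {"input_ids": [0, 5, 2], "attention_mask": [1, 1, 1], "labels": 1},
]
toy_batch = pad_collate(toy_samples)
print(toy_batch["input_ids"].shape)
```

A function like this would be passed to the `Trainer` through its `data_collator` argument. In most cases the padding that the `Trainer` sets up from the tokenizer is enough, and `transformers` also ships a ready-made `DataCollatorWithPadding` class for the same purpose.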

We can now finetune our model by just calling the `train` method:

In [131]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.186131,0.942263
2,0.246400,0.188802,0.943033
3,0.156100,0.187268,0.946882




TrainOutput(global_step=1461, training_loss=0.17514701806707467, metrics={'train_runtime': 825.5789, 'train_samples_per_second': 14.154, 'train_steps_per_second': 1.77, 'total_flos': 1047062636036622.0, 'train_loss': 0.17514701806707467, 'epoch': 3.0})

Nice! After three epochs we have a model accuracy of about 94.7% on the held-out split. Note that we didn't do a lot of work to filter the training data or tune hyperparameters for this experiment, and also that we used one of the smallest ESM-2 models. With a larger starting model and more effort to ensure that the training data categories were cleanly separable, accuracy could almost certainly go a lot higher!
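The step count in the training output can be checked with a little arithmetic. The sketch below assumes a single device and no gradient accumulation, and it assumes that `logging_steps` was left at the `Trainer` default of 500. The row count and batch size are taken from the cells above.

```python
import math

num_train_rows = 3895   # num_rows of train_dataset above
batch_size = 8          # per_device_train_batch_size
num_epochs = 3          # num_train_epochs
logging_steps = 500     # Trainer default, assumed unchanged here

steps_per_epoch = math.ceil(num_train_rows / batch_size)
total_steps = steps_per_epoch * num_epochs

print(steps_per_epoch, total_steps)
print(steps_per_epoch < logging_steps)
```

Under these assumptions the total agrees with the reported `global_step=1461`. The first epoch also ends before the first logging step is reached, which would explain the `No log` entry in the training-loss column for epoch 1.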

***
# Token classification

Another common language model task is **token classification**. In this task, instead of classifying the whole sequence into a single category, we categorize each token (amino acid, in this case!) into one or more categories. This kind of model could be useful for:

- Predicting secondary structure
- Predicting buried vs. exposed residues
- Predicting residues that will receive post-translational modifications
- Predicting residues involved in binding pockets or active sites
- Probably several other things, it's been a while since I was a postdoc

## Data preparation

In this section, we're going to gather some training data from UniProt. As in the sequence classification example, we aim to create two lists: `sequences` and `labels`. Unlike in that example, however, the `labels` are more than just single integers. Instead, the label for each sample will be **one integer per token in the input**. This should make sense - when we do token classification, different tokens in the input may have different categories!

To demonstrate token classification, we're going to go back to UniProt and get some data on protein secondary structures. As above, this will probably be the main section you want to change when adapting this code to your own problems.

In [None]:
import requests

query_url = "https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Cft_strand%2Cft_helix&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"

This time, our UniProt search was `(organism_id:9606) AND (reviewed:true) AND (length:[80 TO 500])` as it was in the first example, but instead of `Subcellular location [CC]` we take the `Helix` and `Beta strand` columns, as they contain the secondary structure information we want.
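
If you want to confirm exactly what a URL like this asks for, you can decode it with the standard library. This sketch only parses the string and makes no network request; `params`, `search_query` and `fields` are names chosen here for illustration.

```python
from urllib.parse import urlsplit, parse_qs

query_url = "https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Cft_strand%2Cft_helix&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"

# parse_qs percent-decodes each parameter and returns a list of values per key
params = parse_qs(urlsplit(query_url).query)
search_query = params["query"][0]
fields = params["fields"][0].split(",")

print(search_query)
print(fields)
```

When adapting the notebook to a different organism or length range, decoding the URL this way is a quick check that the request matches the search you intended.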

In [None]:
uniprot_request = requests.get(query_url)

To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file. If you downloaded the data as a file you can skip this bit and just pass the filepath directly to `read_csv`.

In [None]:
from io import BytesIO
import pandas

bio = BytesIO(uniprot_request.content)

df = pandas.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Sequence,Beta strand,Helix
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,,
1,A0AVI4,MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...,,
2,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,"STRAND 79..81; /evidence=""ECO:0007829|PDB:7EMF""","HELIX 83..86; /evidence=""ECO:0007829|PDB:7EMF""..."
3,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,,
4,A0PJY2,MDSSCHNATTKMLATAPARGNMMSTSKPLAFSIERIMARTPEPKAL...,,
...,...,...,...,...
11975,Q9H8V8,MKPDWPRRGAAGTRVRSRGEGDGTYFARRGAGRRRREIKAPIRAAW...,,
11976,Q9HAA7,MLFGIRILVNTPSPLVTGLHHYNPSIHRDQGECANQWRKGPGSAHL...,,
11977,Q9NZ38,MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...,,
11978,Q9UFV3,MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...,,
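
The `BytesIO` step is not specific to Pandas: any reader that expects an open file will accept it. As a small offline sketch of the same idea, the block below builds a tiny gzip-compressed TSV in memory and reads it back using only the standard library. The two entries and their sequences are invented stand-ins for `uniprot_request.content`, not real UniProt records.

```python
import csv
import gzip
from io import BytesIO

# A tiny gzip-compressed TSV, standing in for `uniprot_request.content`
toy_tsv = "Entry\tSequence\nP00001\tMKT\nP00002\tGAV\n"
toy_bytes = gzip.compress(toy_tsv.encode("utf-8"))

# BytesIO wraps the raw bytes in a file-like object, so anything that
# expects an open file (gzip here, Pandas in the notebook) can read from it
with gzip.open(BytesIO(toy_bytes), mode="rt", encoding="utf-8", newline="") as handle:
    rows = list(csv.DictReader(handle, delimiter="\t"))

print(rows)
```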


Since not all proteins have this structural information, we discard proteins that have no annotated beta strands or alpha helices.

In [None]:
no_structure_rows = df["Beta strand"].isna() & df["Helix"].isna()
df = df[~no_structure_rows]
df

Unnamed: 0,Entry,Sequence,Beta strand,Helix
2,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,"STRAND 79..81; /evidence=""ECO:0007829|PDB:7EMF""","HELIX 83..86; /evidence=""ECO:0007829|PDB:7EMF""..."
13,A1L3X0,MAFSDLTSRTVHLYDNWIKDADPRVEDWLLMSSPLPQTILLGFYVY...,"STRAND 97..99; /evidence=""ECO:0007829|PDB:6Y7F""","HELIX 17..20; /evidence=""ECO:0007829|PDB:6Y7F""..."
14,A1XBS5,MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD...,,"HELIX 2..6; /evidence=""ECO:0007829|PDB:8CEG""; ..."
15,A1Z1Q3,MYPSNKKKKVWREEKERLLKMTLEERRKEYLRDYIPLNSILSWKEE...,"STRAND 71..77; /evidence=""ECO:0007829|PDB:4IQY...","HELIX 11..19; /evidence=""ECO:0007829|PDB:4IQY""..."
19,A2RUC4,MAGQHLPVPRLEGVSREQFMQHLYPQRKPLVLEGIDLGPCTSKWTV...,"STRAND 10..13; /evidence=""ECO:0007829|PDB:3AL5...","HELIX 16..22; /evidence=""ECO:0007829|PDB:3AL5""..."
...,...,...,...,...
11595,Q96I45,MVNLGLSRVDDAVAAKHPGLGEYAACQSHAFMKGVFTFVTGTGMAF...,"STRAND 3..5; /evidence=""ECO:0007829|PDB:2LOR"";...","HELIX 6..16; /evidence=""ECO:0007829|PDB:2LOR"";..."
11651,Q9H0W7,MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP...,"STRAND 7..9; /evidence=""ECO:0007829|PDB:2D8R"";...","HELIX 29..38; /evidence=""ECO:0007829|PDB:2D8R"""
11690,Q9P1F3,MNVDHEVNLLVEEIHRLGSKNADGKLSVKFGVLFRDDKCANLFEAL...,"STRAND 24..29; /evidence=""ECO:0007829|PDB:2L2O...","HELIX 3..17; /evidence=""ECO:0007829|PDB:2L2O"";..."
11692,Q9P298,MSANRRWWVPPDDEDCVSEKLLRKTRESPLVPIGLGGCLVVAAYRI...,"STRAND 11..14; /evidence=""ECO:0007829|PDB:2LON...","HELIX 18..24; /evidence=""ECO:0007829|PDB:2LON""..."


Well, this works, but that data still isn't in a clean format that we can use to build our labels. Let's take a look at one sample to see what exactly we're dealing with:

In [None]:
df.iloc[0]["Helix"]

'HELIX 83..86; /evidence="ECO:0007829|PDB:7EMF"; HELIX 90..96; /evidence="ECO:0007829|PDB:7EMF"; HELIX 112..116; /evidence="ECO:0007829|PDB:7EMF"; HELIX 128..138; /evidence="ECO:0007829|PDB:7EMF"; HELIX 147..152; /evidence="ECO:0007829|PDB:7EMF"'

We'll need to use a [regex](https://docs.python.org/3/howto/regex.html) to pull out each segment that's marked as being a STRAND or HELIX. What we're asking for is a list of everywhere we see the word STRAND or HELIX followed by two numbers separated by two dots. In each case where this pattern is found, we tell the regex to extract the two numbers as a tuple for us.

In [None]:
import re

strand_re = r"STRAND\s(\d+)\.\.(\d+)\;"
helix_re = r"HELIX\s(\d+)\.\.(\d+)\;"

re.findall(helix_re, df.iloc[0]["Helix"])

[('83', '86'), ('90', '96'), ('112', '116'), ('128', '138'), ('147', '152')]

Looks good! We can use this to build our training data. Recall that the **labels** need to be a list or array of integers that's the same length as the input sequence. We're going to use 0 to indicate residues without any annotated structure, 1 for residues in an alpha helix, and 2 for residues in a beta strand. To build that, we'll start with an array of all 0s, and then fill in values based on the positions that our regex pulls out of the UniProt results.

We'll use NumPy arrays rather than lists here, since these allow [slice assignment](https://numpy.org/doc/stable/user/basics.indexing.html#assigning-values-to-indexed-arrays), which will be a lot simpler than editing a list of integers. Note also that UniProt annotates residues starting from 1 (unlike Python, which starts from 0), and region annotations are inclusive (so 1..3 means residues 1, 2 and 3). To turn these into Python slices, we subtract 1 from the start of each annotation, but not the end.

In [None]:
import numpy as np

def build_labels(sequence, strands, helices):
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)
    
    if isinstance(helices, float): # Indicates missing (NaN)
        found_helices = []
    else:
        found_helices = re.findall(helix_re, helices)
    for helix_start, helix_end in found_helices:
        helix_start = int(helix_start) - 1
        helix_end = int(helix_end)
        assert helix_end <= len(sequence)
        labels[helix_start: helix_end] = 1  # Helix category
    
    if isinstance(strands, float): # Indicates missing (NaN)
        found_strands = []
    else:
        found_strands = re.findall(strand_re, strands)
    for strand_start, strand_end in found_strands:
        strand_start = int(strand_start) - 1
        strand_end = int(strand_end)
        assert strand_end <= len(sequence)
        labels[strand_start: strand_end] = 2  # Strand category
    return labels
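
The index conversion used in `build_labels` is easy to get wrong by one, so it can be worth checking in isolation. This is a minimal sketch on a plain string; `annotation_to_slice` and `toy_sequence` are hypothetical names and values introduced only for this check.

```python
def annotation_to_slice(start, end):
    """Convert a 1-based, inclusive UniProt range to a 0-based Python slice."""
    # Subtract 1 from the start only: Python's slice end is already exclusive,
    # so the inclusive 1-based end maps onto it unchanged
    return slice(start - 1, end)


toy_sequence = "MKTAYIAK"  # toy sequence, invented for illustration
region = annotation_to_slice(2, 4)  # residues 2, 3 and 4
print(toy_sequence[region])
print(len(toy_sequence[region]))
```

An annotation of `2..4` should cover exactly three residues, which is what the slice returns.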

Now we've defined a helper function, let's build our lists of sequences and labels:

In [None]:
sequences = []
labels = []

for row_idx, row in df.iterrows():
    row_labels = build_labels(row["Sequence"], row["Beta strand"], row["Helix"])
    sequences.append(row["Sequence"])
    labels.append(row_labels)
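
Before moving on, it can be worth checking how the three classes are distributed, since a skewed distribution makes raw accuracy harder to interpret. The sketch below does this on toy label lists (the values in `toy_labels` are invented for illustration); in the notebook, the same counting could be applied to `labels`.

```python
from collections import Counter

# Toy per-residue label lists, invented for illustration
# (0 = no annotation, 1 = helix, 2 = strand)
toy_labels = [
    [0, 0, 1, 1, 1, 0],
    [2, 2, 0, 0, 1, 1, 1, 1],
]

# Flatten all residues into a single count per class
counts = Counter(label for row in toy_labels for label in row)
total = sum(counts.values())

# Always predicting the most common class would score this accuracy
majority_class, majority_count = counts.most_common(1)[0]
majority_baseline = majority_count / total

print(dict(counts))
print(majority_class, round(majority_baseline, 3))
```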

## Creating our dataset

Nice! Now we'll split and tokenize the data, and then create datasets - I'll go through this quite quickly here, since it's identical to how we did it in the sequence classification example above.

In [None]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

In [None]:
from datasets import Dataset

train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
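
One detail worth checking when adapting this code is whether the labels line up with the tokens. If the tokenizer adds special tokens around each sequence (ESM tokenizers are expected to add a `<cls>` token at the start and an `<eos>` token at the end by default), each `input_ids` list will be two entries longer than its label array. Comparing `len(train_tokenized["input_ids"][0])` with `len(train_labels[0])` will show whether that is the case. The sketch below is one hedged way to realign them, assuming exactly one special token on each side; `align_labels`, `IGNORE_INDEX` and the toy labels are names and values introduced here for illustration.

```python
IGNORE_INDEX = -100  # label value that the loss and the metrics here skip


def align_labels(residue_labels, n_prefix=1, n_suffix=1):
    """Pad per-residue labels so they line up with tokens that include
    special tokens at the start and end of the sequence.

    n_prefix / n_suffix are assumptions about the tokenizer: one token
    before the residues and one after. Check them against your tokenizer.
    """
    return [IGNORE_INDEX] * n_prefix + list(residue_labels) + [IGNORE_INDEX] * n_suffix


toy_residue_labels = [0, 1, 1, 2]  # one label per amino acid (toy values)
aligned = align_labels(toy_residue_labels)
print(aligned)
```

Using `-100` for the special-token positions keeps them out of both the loss and the accuracy calculation, in the same way as padded positions.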

## Model loading

The key difference from the example above is that we use `AutoModelForTokenClassification` instead of `AutoModelForSequenceClassification`. We will also need a `data_collator` this time, as we're in the slightly more complex case where both inputs and labels must be padded in each batch.

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

num_labels = 3
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)

Now we set up our `TrainingArguments` as before.

In [None]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 8

args = TrainingArguments(
    f"{model_name}-finetuned-secondary-structure",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

Our `compute_metrics` function is a bit more complex than in the sequence classification task, as we need to ignore padding tokens (those where the label is `-100`).

In [None]:
from evaluate import load
import numpy as np

metric = load("accuracy")

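def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Flatten labels to 1D: (batch, seq_len) -> (batch * seq_len,)
    labels = labels.reshape((-1,))
    # Pick the highest-scoring class per token, then flatten the same way
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions.reshape((-1,))
    # Drop padded positions (label -100) so they don't count towards accuracy
    predictions = predictions[labels!=-100]
    labels = labels[labels!=-100]
    return metric.compute(predictions=predictions, references=labels)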
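
To see what the masking does, here is a small self-contained sketch that applies the same steps to toy arrays. The logits and labels are invented for illustration, and the accuracy is computed directly with NumPy rather than through `evaluate`.

```python
import numpy as np

# Toy logits for a batch of 2 sequences, 3 tokens each, 3 classes
toy_logits = np.array([
    [[2.0, 0.1, 0.1], [0.1, 3.0, 0.2], [0.3, 0.2, 0.1]],
    [[0.1, 0.2, 4.0], [1.5, 0.2, 0.1], [0.1, 0.1, 0.1]],
])
# -100 marks padded positions that must not count towards accuracy
toy_labels = np.array([
    [0, 1, -100],
    [2, 1, -100],
])

flat_labels = toy_labels.reshape(-1)
flat_preds = np.argmax(toy_logits, axis=2).reshape(-1)

keep = flat_labels != -100  # boolean mask of real (non-padded) tokens
accuracy = float((flat_preds[keep] == flat_labels[keep]).mean())

print(flat_preds[keep], flat_labels[keep], accuracy)
```

Only the four non-padded tokens are scored here. Without the mask, the two padded positions would be counted as errors, because a prediction can never equal `-100`.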

And now we're ready to train our model! 

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1 | No log | 0.456698 | 0.813887 |
| 2 | 0.494900 | 0.441967 | 0.820196 |
| 3 | 0.386400 | 0.447364 | 0.822411 |




TrainOutput(global_step=1224, training_loss=0.42138489243251825, metrics={'train_runtime': 672.0095, 'train_samples_per_second': 14.558, 'train_steps_per_second': 1.821, 'total_flos': 890104254615072.0, 'train_loss': 0.42138489243251825, 'epoch': 3.0})

This definitely seems harder than the first task, but we still attain a respectable per-residue accuracy of about 82%. Remember that to keep this demo lightweight, we used one of the smallest ESM models, focused on human proteins only and didn't put a lot of work into making sure we only included completely-annotated proteins in our training set. With a bigger model and a cleaner, broader training set, accuracy on this task could definitely go a lot higher!