This notebook was adapted from [https://github.com/RSchmirler/data-repo_plm-finetune-eval/tree/main](https://github.com/RSchmirler/data-repo_plm-finetune-eval/tree/main).

The goal is to train a residue classifier using embeddings extracted from pretrained protein language models.

# Setup

In [None]:
import pandas as pd
from datasets import Dataset, load_from_disk
import numpy as np
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import random
from sklearn.metrics import (
    f1_score, precision_score, recall_score, accuracy_score,
    matthews_corrcoef, roc_auc_score)

Available checkpoints:

In [2]:
ESMs = [ "esm2_t6_8M_UR50D" ,
         "esm2_t12_35M_UR50D" ,
         "esm2_t30_150M_UR50D" ,
         "esm2_t33_650M_UR50D",
         "esm2_t36_3B_UR50D"]

ProtT5 = ["prot_t5_xl_uniref50"] 

20 features from uniprot:

In [3]:
all_features = ['Active site', 'Binding site', 'DNA binding', 
                'Topological domain', 'Transmembrane',
                'Disulfide bond', 'Modified residue', 'Propeptide', 'Signal peptide', 'Transit peptide',
                'Beta strand', 'Helix', 'Turn',
                'Coiled coil', 'Compositional bias', 'Domain [FT]', 'Motif', 'Region', 'Repeat', 'Zinc finger']

all_features_re = ['ACT_SITE', 'BINDING', 'DNA_BIND', 
                   'TOPO_DOM', 'TRANSMEM',
                   'DISULFID', 'MOD_RES',  'PROPEP', 'SIGNAL', 'TRANSIT',
                   'STRAND', 'HELIX', 'TURN',
                   'COILED', 'COMPBIAS', 'DOMAIN', 'MOTIF', 'REGION', 'REPEAT', 'ZN_FING']

Function to extract labels from uniprot text descriptions:

In [4]:
def build_labels_region(sequence, feature, feature_re):
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)
    region_re = f"{feature_re}\s(\d+)\.\.(\d+)\;"
    residue_re = f'{feature_re}\s(\d+);'

    found_region = re.findall(region_re, feature)

    for start, end in found_region:
        start = int(start) - 1
        end = int(end)
        assert end <= len(sequence)
        labels[start: end] = 1

    found_residue = re.findall(residue_re, feature)
    for pos in found_residue:
        pos = int(pos) -1
        assert pos <= len(sequence)
        labels[pos] = 1

    return labels


def build_labels_bonds(sequence, feature, feature_re):
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)

    region_re = f"{feature_re}\s(\d+)\.\.(\d+)\;"

    if isinstance(feature, float): # Indicates missing (NaN)
        found_feature = []
    else:
        found_feature = re.findall(region_re, feature)
    for start, end in found_feature:
        start = int(start) - 1
        end = int(end) -1
        assert end <= len(sequence)
        labels[start] = 1
        labels[end] = 1 
    return labels


# Model architecture

In [5]:
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


# Implementation of the linear layer
class EmbPredictor(nn.Module):      
    def __init__(self, input_dim, dense, dropout):
        super().__init__()
        self.normalizer = nn.BatchNorm1d(input_dim)
        self.fc1 = nn.Linear(input_dim, dense)
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(dense, 2)  # 2 classes

    def forward(self, x):
        x = self.normalizer(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.output(x)  # logits


# function to train the model
def train_predictor(train_embeds, train_labels,
                    val_embeds, val_labels,
                    test_embeds, test_labels,
                    epochs=10, lr=3e-4, epsilon=1e-7,
                    batch=64, dropout=0.2, dense=32, seed=99,
                    save_path="best_model.pt", metric_path="test_metrics.tsv"):

    set_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Convert numpy arrays to tensors
    X_train = torch.tensor(train_embeds, dtype=torch.float32).to(device)
    y_train = torch.tensor(train_labels, dtype=torch.long).squeeze().to(device)

    X_val = torch.tensor(val_embeds, dtype=torch.float32).to(device)
    y_val = torch.tensor(val_labels, dtype=torch.long).squeeze().to(device)

    X_test = torch.tensor(test_embeds, dtype=torch.float32).to(device)
    y_test = torch.tensor(test_labels, dtype=torch.long).squeeze().to(device)

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch, shuffle=True)

    model = EmbPredictor(X_train.shape[1], dense, dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=epsilon)
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            val_logits = model(X_val)
            val_loss = criterion(val_logits, y_val).item()

        print(f"Epoch {epoch+1}/{epochs} - Validation loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict()

    # Save best model
    if best_model_state is not None:
        torch.save(best_model_state, save_path)
        print(f"Best model saved to {save_path} with val loss: {best_val_loss:.4f}")

    # Load best model and evaluate on test
    model.load_state_dict(torch.load(save_path))
    model.to(device)
    model.eval()

    with torch.no_grad():
        test_logits = model(X_test)
        test_probs = torch.softmax(test_logits, dim=1)
        test_preds = test_probs.argmax(dim=1).cpu().numpy()
        y_test_np = y_test.cpu().numpy()
        test_probs_np = test_probs[:, 1].cpu().numpy()  # positive class probs

    # Compute metrics
    metrics = {
        "f1": f1_score(y_test_np, test_preds, average='macro'),
        "precision": precision_score(y_test_np, test_preds, average='macro'),
        "recall": recall_score(y_test_np, test_preds, average='macro'),
        "mcc": matthews_corrcoef(y_test_np, test_preds),
        "auroc": roc_auc_score(y_test_np, test_probs_np, average='macro'),
        "accuracy": accuracy_score(y_test_np, test_preds)
    }

    # Save to TSV
    df_metrics = pd.DataFrame([metrics])
    df_metrics.to_csv(metric_path, sep="\t", index=False)
    print(f"Test metrics saved to {metric_path}")

    return model, metrics


In [6]:
# read all uniprot protein
uniprot_df = pd.read_csv("../data/uniprot_all_human_proteins.txt.gz", sep='\t')
long_proteins = uniprot_df[uniprot_df['Length']>1000]['Entry'].tolist()

# Train and evaluate the model

In [None]:

#for checkpoint in ESMs + ProtT5:
for checkpoint in ProtT5:
    for i in range(len(all_features)):
        print(checkpoint)
        print(all_features[i])
        
        #read_train_test_val embeds
        train_df = pd.read_csv(f"../data/splits/df/{all_features[i]}_train.tsv.gz", sep='\t')[['Entry', 'Sequence', all_features[i]]]
        val_df = pd.read_csv(f"../data/splits/df/{all_features[i]}_val.tsv.gz", sep='\t')[['Entry', 'Sequence', all_features[i]]]
        test_df = pd.read_csv(f"../data/splits/df/{all_features[i]}_test.tsv.gz", sep='\t')[['Entry', 'Sequence', all_features[i]]]
        
        train_embeds = [np.load(f'../data/embeddings/{f}_{checkpoint}.npy') for f in train_df.Entry.tolist() if f not in long_proteins]
        train_embeds = np.concatenate(train_embeds, axis=0)  

        val_embeds = [np.load(f'../data/embeddings/{f}_{checkpoint}.npy') for f in val_df.Entry.tolist() if f not in long_proteins]
        val_embeds = np.concatenate(val_embeds, axis=0)  

        test_embeds = [np.load(f'../data/embeddings/{f}_{checkpoint}.npy') for f in test_df.Entry.tolist() if f not in long_proteins]
        test_embeds = np.concatenate(test_embeds, axis=0)  

        # create labels
        if all_features[i] == 'Disulfide bond':
            labeler_func = build_labels_bonds
        else:
            labeler_func = build_labels_region
            
        train_labels = []
        for row_idx, row in train_df.iterrows():
            if row['Entry'] not in long_proteins:
                row_labels = labeler_func(row["Sequence"], row[all_features[i]], all_features_re[i])
                train_labels.append(row_labels)
        train_labels = np.concatenate(train_labels, axis=0)
        
        test_labels = []
        for row_idx, row in test_df.iterrows():
            if row['Entry'] not in long_proteins:
                row_labels = labeler_func(row["Sequence"], row[all_features[i]], all_features_re[i])
                test_labels.append(row_labels)
        test_labels = np.concatenate(test_labels, axis=0)
        
        val_labels = []
        for row_idx, row in val_df.iterrows():
            if row['Entry'] not in long_proteins:
                row_labels = labeler_func(row["Sequence"], row[all_features[i]], all_features_re[i])
                val_labels.append(row_labels)
        val_labels = np.concatenate(val_labels, axis=0)
        
        #train and evaluate the model
        model, metrics = train_predictor(
                    train_embeds, train_labels,
                    val_embeds, val_labels,
                    test_embeds, test_labels,
                    save_path=f"../res/models/lp_{all_features_re[i]}_{checkpoint}.pt",
                    metric_path=f"../res/metrics/lp_{all_features_re[i]}_{checkpoint}.tsv"
                )
        
        torch.cuda.empty_cache()
        print("***************")
    

prot_t5_xl_uniref50
Active site
Using device: cuda
Epoch 1/10 - Validation loss: 0.0018
Epoch 2/10 - Validation loss: 0.0019
Epoch 3/10 - Validation loss: 0.0020
Epoch 4/10 - Validation loss: 0.0021
Epoch 5/10 - Validation loss: 0.0029
Epoch 6/10 - Validation loss: 0.0025
Epoch 7/10 - Validation loss: 0.0024
Epoch 8/10 - Validation loss: 0.0028
Epoch 9/10 - Validation loss: 0.0026
Epoch 10/10 - Validation loss: 0.0041
Best model saved to ../res/models/lp_ACT_SITE_prot_t5_xl_uniref50.pt with val loss: 0.0018


  model.load_state_dict(torch.load(save_path))


Test metrics saved to ../res/metrics/lp_ACT_SITE_prot_t5_xl_uniref50.tsv
***************
prot_t5_xl_uniref50
Binding site
Using device: cuda
Epoch 1/10 - Validation loss: 0.0401
Epoch 2/10 - Validation loss: 0.0405
Epoch 3/10 - Validation loss: 0.0408
Epoch 4/10 - Validation loss: 0.0417
Epoch 5/10 - Validation loss: 0.0424
Epoch 6/10 - Validation loss: 0.0423
Epoch 7/10 - Validation loss: 0.0432
Epoch 8/10 - Validation loss: 0.0434
Epoch 9/10 - Validation loss: 0.0438
Epoch 10/10 - Validation loss: 0.0444
Best model saved to ../res/models/lp_BINDING_prot_t5_xl_uniref50.pt with val loss: 0.0401


  model.load_state_dict(torch.load(save_path))


Test metrics saved to ../res/metrics/lp_BINDING_prot_t5_xl_uniref50.tsv
***************
prot_t5_xl_uniref50
DNA binding
Using device: cuda
Epoch 1/10 - Validation loss: 0.1702
Epoch 2/10 - Validation loss: 0.1885
Epoch 3/10 - Validation loss: 0.1919
Epoch 4/10 - Validation loss: 0.2089
Epoch 5/10 - Validation loss: 0.2123
Epoch 6/10 - Validation loss: 0.2245
Epoch 7/10 - Validation loss: 0.2440
Epoch 8/10 - Validation loss: 0.2328
Epoch 9/10 - Validation loss: 0.2641
Epoch 10/10 - Validation loss: 0.2594
Best model saved to ../res/models/lp_DNA_BIND_prot_t5_xl_uniref50.pt with val loss: 0.1702


  model.load_state_dict(torch.load(save_path))


Test metrics saved to ../res/metrics/lp_DNA_BIND_prot_t5_xl_uniref50.tsv
***************
prot_t5_xl_uniref50
Topological domain
Using device: cuda
Epoch 1/10 - Validation loss: 0.1837
Epoch 2/10 - Validation loss: 0.1838
Epoch 3/10 - Validation loss: 0.1934
Epoch 4/10 - Validation loss: 0.1880
Epoch 5/10 - Validation loss: 0.1881
Epoch 6/10 - Validation loss: 0.1873
Epoch 7/10 - Validation loss: 0.1842
Epoch 8/10 - Validation loss: 0.1909
Epoch 9/10 - Validation loss: 0.1903
Epoch 10/10 - Validation loss: 0.1872
Best model saved to ../res/models/lp_TOPO_DOM_prot_t5_xl_uniref50.pt with val loss: 0.1837
Test metrics saved to ../res/metrics/lp_TOPO_DOM_prot_t5_xl_uniref50.tsv
***************
prot_t5_xl_uniref50
Transmembrane


  model.load_state_dict(torch.load(save_path))


Using device: cuda
Epoch 1/10 - Validation loss: 0.1367
