In [1]:
pip install fair-esm



In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader
import torch
import esm
from typing import Sequence, Tuple, List, Union
import os

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Model

In [5]:
class FastaBatchedDataset(object):
    def __init__(self, labels, sequences):
        self.labels = list(labels)
        self.sequences = list(sequences)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.labels[idx], self.sequences[idx]

In [6]:
class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.
    """

    def __init__(self, alphabet, truncation_seq_length: int = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        # RoBERTa uses an eos token, while ESM-1 does not.
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)
        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            seq_encoded_list = [seq_str[:self.truncation_seq_length] for seq_str in seq_encoded_list]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        tokens = torch.empty(
            (
                batch_size,
                max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []

        for i, (label, seq_str, seq_encoded) in enumerate(
            zip(batch_labels, seq_str_list, seq_encoded_list)
        ):
            labels.append(label)
            if self.alphabet.prepend_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[
                i,
                int(self.alphabet.prepend_bos) : len(seq_encoded)
                + int(self.alphabet.prepend_bos),
            ] = seq
            if self.alphabet.append_eos:
                tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx

        return tokens, labels

In [7]:
class Alphabet(object):
    def __init__(
        self,
        standard_toks: Sequence[str],
        prepend_toks: Sequence[str] = ("<pad>", "<eos>", "<unk>"),
        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
        prepend_bos: bool = True,
        append_eos: bool = True,
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")
        self.all_special_tokens = ['<eos>', '<pad>', '<mask>']
        self.unique_no_split_tokens = self.all_toks

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        return self.tok_to_idx.copy()

    def get_batch_converter(self):
        return BatchConverter(self)

    def _tokenize(self, text) -> str:
        return text.split()

    def tokenize(self, text: str, **kwargs) -> List[str]:
        tokens = []
        i = 0
        while i < len(text):
            if text[i] == '<':
                j = text.find('>', i)
                if j == -1:
                    raise ValueError(f"Unclosed special token starting at position {i}: {text[i:i+10]}")
                tokens.append(text[i:j+1])
                i = j + 1
            else:
                tokens.append(text[i])
                i += 1
        return tokens

    def encode(self, text):
        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]

In [8]:
import torch
import torch.nn as nn
from typing import Union

from esm.modules import ESM1bLayerNorm, TransformerLayer, RobertaLMHead

class ESM2SI(nn.Module):
    def __init__(
        self,
        num_layers: int = 33,
        embed_dim: int = 1280,
        attention_heads: int = 20,
        alphabet: any = None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads

        self.alphabet = alphabet
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos

        self._init_submodules()

    def _init_submodules(self):
        self.embed_scale = 1
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.embed_dim,
                    4 * self.embed_dim,
                    self.attention_heads,
                    add_bias_kv=False,
                    use_esm1b_layer_norm=True,
                    use_rotary_embeddings=True,
                )
                for _ in range(self.num_layers)
            ]
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)

        self.lm_head = RobertaLMHead(
            embed_dim=self.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def forward(self, tokens, repr_layers=[], return_representation=True):
        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)

        x = self.embed_scale * self.embed_tokens(tokens)

        # Apply padding mask
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        x = x.transpose(0, 1)  # (B, T, E) -> (T, B, E)

        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers):
            x, _ = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=False,
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) -> (B, T, E)

        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x
        x = self.lm_head(x)

        if return_representation:
            return {"logits": x,"representations": hidden_representations}
        else:
            return {}

In [9]:
class ESM2_Linear(nn.Module):
    def __init__(self, embed_dim, nodes, projection_dim=64):
        super().__init__()
        self.esm2 = ESM2SI(num_layers = 6,
              embed_dim = 128,
              attention_heads = 16,
              alphabet = alphabet)

        self.nodes = nodes
        self.dropout3 = 0.2

        self.fc = nn.Linear(embed_dim, self.nodes)
        self.dropout3 = nn.Dropout(self.dropout3)

        self.output = nn.Linear(self.nodes, 1)

        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def forward(self, tokens, repr_layer=6, return_representation=True):
        x = self.esm2(tokens, repr_layers=[repr_layer], return_representation=return_representation)
        logits = x["logits"]

        # get representation
        x = x["representations"][repr_layer][:, 0]
        x_o = x.unsqueeze(2)
        x = self.flatten(x_o)


        # prediction head
        x = self.fc(x)
        x = self.relu(x)
        x = self.dropout3(x)
        x = self.output(x)

        return x, logits


## Loss function

In [10]:
import torch
import torch.nn.functional as F
from scipy.stats import spearmanr
import numpy as np

def con_fit_loss(pred, labels, logits_new, logits_old=None, lambda_kl=0.0):

    # 1. Pairwise Ranking Loss
    pairwise_loss = 0.0
    n = pred.size(0)

    for i in range(n):
        for j in range(n):
            if labels[i] > labels[j]:
                diff = pred[i] - pred[j]
                pairwise_loss += F.softplus(-diff)  # log(1+exp(-diff))

    pairwise_loss = pairwise_loss / (n * n)

    # 2. KL Divergence Loss
    if logits_old is not None and lambda_kl > 0.0:
        logits_new_cls = logits_new[:, 0, :]  # (batch, vocab_size)
        logits_old_cls = logits_old[:, 0, :]

        p_new = F.log_softmax(logits_new_cls, dim=-1)
        p_old = F.softmax(logits_old_cls, dim=-1)
        kl_loss = F.kl_div(p_new, p_old, reduction='batchmean')
    else:
        kl_loss = torch.tensor(0.0, device=pred.device)

    total_loss = pairwise_loss + lambda_kl * kl_loss

    return total_loss


def pairwise_ranking_loss(preds, targets, margin=1.0):
    """
    preds: Tensor of shape (B, 1) – predicted DMS scores
    targets: Tensor of shape (B, 1) – true DMS scores
    """
    preds = preds.squeeze()
    targets = targets.squeeze()
    diff_target = targets.unsqueeze(0) - targets.unsqueeze(1)
    diff_pred = preds.unsqueeze(0) - preds.unsqueeze(1)
    target_sign = torch.sign(diff_target)
    loss_matrix = F.relu(-target_sign * diff_pred + margin)
    mask = torch.eye(len(targets), dtype=torch.bool, device=targets.device)
    loss_matrix = loss_matrix[~mask].view(len(targets), -1)
    return loss_matrix.mean()

## Model Config

In [71]:
alphabet = Alphabet(standard_toks = 'AGCT')
modelfile = '/content/ESM2SI_3.1_fiveSpeciesCao_6layers_16heads_128embedsize_4096batchToks_MLMLossMin.pkl'
obj_col = 'te_log'
seq_type = 'utr'
config = {
    "num_epochs": 20,
    "early_stop_patience": 5,
    "batch_size": 32,
    "learning_rate_esm": 1e-5,
    "learning_rate_head": 1e-3,
    "nodes": 40,
    "save_dir": "/content/saved_models"
}

## Read in Data

In [64]:
data = pd.read_csv('/content/pc3_sequence.csv')
dataset = FastaBatchedDataset(data.loc[:,obj_col], data[seq_type])

## Setting optimizer and scheduler

In [74]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

def setup_optimizer_and_scheduler(model, train_loader, epochs, batch_size, lr_for_esm=1e-6, lr_for_head=1e-3):
    """
    Set up optimizer and scheduler for ESM2 finetuning model.

    Args:
        model: The model containing esm2 encoder and regression head (like fc/output layers).
        train_loader: DataLoader for training set.
        epochs: Total number of epochs.
        batch_size: Training batch size.
        lr_for_esm: Learning rate for pretrained ESM2 encoder.
        lr_for_head: Learning rate for new prediction head.

    Returns:
        optimizer, scheduler
    """

    # Define scaled learning rates
    lr_esm = batch_size * lr_for_esm
    lr_head = batch_size * lr_for_head

    # Collect parameters
    esm2_params = list(model.esm2.parameters())
    head_params = []
    for name, param in model.named_parameters():
        if not name.startswith("esm2."):  # Everything not inside esm2 is considered head
            head_params.append(param)

    # Group parameters
    param_groups = [
        {'params': esm2_params, 'lr': lr_esm},
        {'params': head_params, 'lr': lr_head}
    ]

    # Initialize optimizer
    optimizer = AdamW(param_groups, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

    # Initialize scheduler
    scheduler = OneCycleLR(
        optimizer,
        max_lr=[lr_esm, lr_head],
        steps_per_epoch=len(train_loader),
        epochs=epochs,
    )

    return optimizer, scheduler


In [75]:
old_model = ESM2_Linear(128, config["nodes"]).to(device)
state_dict = torch.load(modelfile, map_location=device)
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
old_model.esm2.load_state_dict(new_state_dict, strict=False)
for param in old_model.parameters():
    param.requires_grad = False

## Training

In [76]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from torch.amp import autocast, GradScaler



scaler = GradScaler('cuda')


def train_one_epoch(model, loader, optimizer, scheduler, criterion, device):

    model.train()
    total_loss = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = torch.tensor(labels, dtype=torch.float).to(device)

        with autocast('cuda'):
            optimizer.zero_grad()
            predictions, logits = model(inputs)
            if criterion == "btr":
                with torch.no_grad():
                    _, logits_old = old_model(inputs)
                loss = con_fit_loss(predictions, labels, logits, logits_old, lambda_kl=0.1)
            elif criterion == 'pairwise':
                loss = pairwise_ranking_loss(predictions, labels)
            else:
                loss = criterion(predictions, labels.view(-1, 1))

            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()


        total_loss += loss.item()

    return total_loss / len(loader)


def evaluate(model, loader, criterion, device):
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = torch.tensor(labels, dtype=torch.float).to(device)

            with autocast('cuda'):
                predictions, logits = model(inputs)
                if criterion == "btr":
                    with torch.no_grad():
                        _, logits_old = old_model(inputs)
                    loss = con_fit_loss(predictions, labels, logits, logits_old, lambda_kl=0.1)
                elif criterion == 'pairwise':
                    loss = pairwise_ranking_loss(predictions, labels)
                else:
                    loss = criterion(predictions, labels.view(-1, 1))

            total_loss += loss.item()

            all_preds.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.squeeze(all_preds)
    all_labels = np.squeeze(all_labels)


    metrics = {
        f"loss": total_loss / len(loader),
        f"spearman": spearmanr(all_labels, all_preds)[0],
    }

    return metrics


In [78]:
from sklearn.model_selection import KFold
from sklearn.model_selection import GroupKFold

n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
gkf = GroupKFold(n_splits=n_splits)
group_labels = data['external_gene_id']
fold_spearman_scores = []


# for fold, (train_idx, val_idx) in enumerate(gkf.split(dataset, groups=group_labels)):
for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
    print(f"\n=== Fold {fold+1}/{n_splits} ===")

    # new dataloader
    train_subset = torch.utils.data.Subset(dataset, train_idx)
    val_subset = torch.utils.data.Subset(dataset, val_idx)
    train_dataloader = DataLoader(train_subset,
                                        batch_size= config["batch_size"],
                                        shuffle = True,
                                        collate_fn=alphabet.get_batch_converter())
    val_dataloader = DataLoader(val_subset,
                                        batch_size= config["batch_size"],
                                        shuffle = False,
                                        collate_fn=alphabet.get_batch_converter())

    # initialize model/optimizer
    model = ESM2_Linear(128, config["nodes"]).to(device)
    state_dict = torch.load(modelfile, map_location=device)
    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    model.esm2.load_state_dict(new_state_dict, strict=False)
    optimizer, scheduler = setup_optimizer_and_scheduler(model, train_dataloader, config["num_epochs"], config["batch_size"], config["learning_rate_esm"], config["learning_rate_head"])
    # criterion = nn.MSELoss()
    # criterion = torch.nn.HuberLoss()
    criterion = "pairwise"
    # criterion = "btr"

    best_val_metric = -float('inf')
    epochs_since_improvement = 0

    for epoch in range(config["num_epochs"]):
        train_loss = train_one_epoch(model, train_dataloader, optimizer, scheduler, criterion, device)
        val_metric = evaluate(model, val_dataloader, criterion, device)
        print(f"[Fold {fold+1} | Epoch {epoch}] Train Loss: {train_loss:.4f} | Val Loss: {val_metric['loss']:.4f} | Val Spearman: {val_metric['spearman']:.4f}")

        if val_metric['spearman'] > best_val_metric:
            best_val_metric = val_metric['spearman']
            epochs_since_improvement = 0
            # save model
            torch.save(model.state_dict(), os.path.join(config["save_dir"], f"best_model_fold{fold+1}.pt"))
            print(f"New best Spearman: {best_val_metric:.4f} — Model saved.")
        else:
            epochs_since_improvement += 1

        if epochs_since_improvement >= config["early_stop_patience"]:
            print(f"Early stopping triggered in fold {fold+1} after {epoch+1} epochs.")
            break

    print(f"Best Spearman for fold {fold+1}: {best_val_metric:.4f}")
    fold_spearman_scores.append(best_val_metric)

# avergae spearman
mean_spearman = sum(fold_spearman_scores) / len(fold_spearman_scores)
print(f"\n5-Fold Cross-Validation Mean Spearman: {mean_spearman:.4f}")


=== Fold 1/5 ===
[Fold 1 | Epoch 0] Train Loss: 0.9317 | Val Loss: 1.0317 | Val Spearman: 0.2062
New best Spearman: 0.2062 — Model saved.
[Fold 1 | Epoch 1] Train Loss: 0.8863 | Val Loss: 1.0078 | Val Spearman: 0.2739
New best Spearman: 0.2739 — Model saved.
[Fold 1 | Epoch 2] Train Loss: 0.8750 | Val Loss: 1.0148 | Val Spearman: 0.3000
New best Spearman: 0.3000 — Model saved.
[Fold 1 | Epoch 3] Train Loss: 0.8472 | Val Loss: 1.0180 | Val Spearman: 0.3360
New best Spearman: 0.3360 — Model saved.
[Fold 1 | Epoch 4] Train Loss: 0.8139 | Val Loss: 1.0236 | Val Spearman: 0.3796
New best Spearman: 0.3796 — Model saved.
[Fold 1 | Epoch 5] Train Loss: 0.7692 | Val Loss: 1.0206 | Val Spearman: 0.4190
New best Spearman: 0.4190 — Model saved.
[Fold 1 | Epoch 6] Train Loss: 0.7574 | Val Loss: 1.0300 | Val Spearman: 0.4327
New best Spearman: 0.4327 — Model saved.
[Fold 1 | Epoch 7] Train Loss: 0.6801 | Val Loss: 1.0513 | Val Spearman: 0.4743
New best Spearman: 0.4743 — Model saved.
[Fold 1 | Epoc

## Testing

In [70]:
from scipy.stats import rankdata
from scipy.stats import spearmanr

# test_data = pd.read_csv('/content/Experimental_data_revised_label.csv')
# obj_col = 'label'
# seq_type = 'utr_100'

test_data = pd.read_csv('/content/HEK_sequence.csv')
obj_col = 'te_log'
seq_type = 'utr'

test_dataset = FastaBatchedDataset(test_data.loc[:, obj_col], test_data[seq_type])
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config["batch_size"],
    shuffle=False,
    collate_fn=alphabet.get_batch_converter()
)

all_preds = []

models = []
for fold in range(5):
    model =  ESM2_Linear(128, config["nodes"]).to(device)
    model_path = os.path.join(config["save_dir"], f"best_model_fold{fold+1}.pt")
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    model.eval()
    models.append(model)

with torch.no_grad():
    for fold_model in models:
        preds = []
        for inputs, labels in test_dataloader:
            inputs = inputs.to(device)
            labels = torch.tensor(labels, dtype=torch.float).to(device)

            with autocast('cuda'):
                outputs = model(inputs)[0]

            preds.append(outputs.cpu())
        preds = torch.cat(preds, dim=0)
        all_preds.append(preds)



ensemble_preds = torch.stack(all_preds, dim=0).mean(dim=0)

y_true = test_data[obj_col].values
y_pred = ensemble_preds.numpy().squeeze()

from scipy.stats import spearmanr
spearman_corr = spearmanr(y_true, y_pred).correlation

print(f"Test Spearman of Ensemble Model: {spearman_corr:.4f}")


Test Spearman of Ensemble Model: 0.3517
Test Spearman of Ensemble Model (ranked): 0.3517
