In [1]:
pip install fair-esm

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader
import torch
import esm
from typing import Sequence, Tuple, List, Union

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
class FastaBatchedDataset(object):
    def __init__(self, labels, sequences):
        self.labels = list(labels)
        self.sequences = list(sequences)


    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.labels[idx], self.sequences[idx]

In [5]:
class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.
    """

    def __init__(self, alphabet, truncation_seq_length: int = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        # RoBERTa uses an eos token, while ESM-1 does not.
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)
        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            seq_encoded_list = [seq_str[:self.truncation_seq_length] for seq_str in seq_encoded_list]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        tokens = torch.empty(
            (
                batch_size,
                max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []

        for i, (label, seq_str, seq_encoded) in enumerate(
            zip(batch_labels, seq_str_list, seq_encoded_list)
        ):
            labels.append(label)
            if self.alphabet.prepend_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[
                i,
                int(self.alphabet.prepend_bos) : len(seq_encoded)
                + int(self.alphabet.prepend_bos),
            ] = seq
            if self.alphabet.append_eos:
                tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx

        return tokens, labels

In [6]:
class Alphabet(object):
    def __init__(
        self,
        standard_toks: Sequence[str],
        prepend_toks: Sequence[str] = ("<pad>", "<eos>", "<unk>"),
        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
        prepend_bos: bool = True,
        append_eos: bool = True,
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")
        self.all_special_tokens = ['<eos>', '<pad>', '<mask>']
        self.unique_no_split_tokens = self.all_toks

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        return self.tok_to_idx.copy()

    def get_batch_converter(self):
        return BatchConverter(self)

    def _tokenize(self, text) -> str:
        return text.split()

    def tokenize(self, text: str, **kwargs) -> List[str]:
        tokens = []
        i = 0
        while i < len(text):
            if text[i] == '<':
                j = text.find('>', i)
                if j == -1:
                    raise ValueError(f"Unclosed special token starting at position {i}: {text[i:i+10]}")
                tokens.append(text[i:j+1])
                i = j + 1
            else:
                tokens.append(text[i])
                i += 1
        return tokens

    def encode(self, text):
        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]

In [7]:
import torch
import torch.nn as nn
from typing import Union

from esm.modules import ESM1bLayerNorm, TransformerLayer

class ESM2SI(nn.Module):
    def __init__(
        self,
        num_layers: int = 33,
        embed_dim: int = 1280,
        attention_heads: int = 20,
        alphabet: any = None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads

        self.alphabet = alphabet
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos

        self._init_submodules()

    def _init_submodules(self):
        self.embed_scale = 1
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.embed_dim,
                    4 * self.embed_dim,
                    self.attention_heads,
                    add_bias_kv=False,
                    use_esm1b_layer_norm=True,
                    use_rotary_embeddings=True,
                )
                for _ in range(self.num_layers)
            ]
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)

    def forward(self, tokens, repr_layers=[], return_representation=True):
        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)

        x = self.embed_scale * self.embed_tokens(tokens)

        # Apply padding mask
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        x = x.transpose(0, 1)  # (B, T, E) -> (T, B, E)

        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers):
            x, _ = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=False,
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) -> (B, T, E)

        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x

        if return_representation:
            return {"representations": hidden_representations}
        else:
            return {}

In [8]:
class ESM2_Linear(nn.Module):
    def __init__(self, embed_dim, nodes):
        super().__init__()
        self.esm2 = ESM2SI(num_layers = 6,
              embed_dim = 128,
              attention_heads = 16,
              alphabet = alphabet)

        self.nodes = 40
        self.dropout3 = 0.2

        self.fc = nn.Linear(embed_dim, self.nodes)
        self.dropout3 = nn.Dropout(self.dropout3)

        self.output = nn.Linear(self.nodes, 1)

        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def forward(self, tokens, repr_layer=6, return_representation=True):
        x = self.esm2(tokens, repr_layers=[repr_layer], return_representation=return_representation)

        # get representation
        x = x["representations"][repr_layer][:, 0]
        x_o = x.unsqueeze(2)
        x = self.flatten(x_o)

        # prediction head
        x = self.fc(x)
        x = self.relu(x)
        x = self.dropout3(x)
        x = self.output(x)

        return x


In [55]:
alphabet = Alphabet(standard_toks = 'AGCT')
model = ESM2_Linear(128, 40).to(device)
modelfile = '/content/ESM2SI_3.1_fiveSpeciesCao_6layers_16heads_128embedsize_4096batchToks_MLMLossMin.pkl'
state_dict = torch.load(modelfile, map_location=device)
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.esm2.load_state_dict(new_state_dict, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['contact_head.regression.weight', 'contact_head.regression.bias', 'lm_head.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'supervised_linear.weight', 'supervised_linear.bias'])

In [65]:
data = pd.read_csv('/content/Muscle_dataset.csv')
obj_col = 'te_log'
seq_type = 'utr'
batch_size = 32

In [57]:
from sklearn.model_selection import train_test_split

train_data, val_data = train_test_split(
    data,
    test_size=0.2,
    random_state=42,
    shuffle=True
)

In [68]:
train_dataset = FastaBatchedDataset(train_data.loc[:,obj_col], train_data[seq_type])
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                        batch_size= batch_size,
                                        shuffle = True,
                                        collate_fn=alphabet.get_batch_converter())

val_dataset = FastaBatchedDataset(val_data.loc[:,obj_col], val_data[seq_type])
val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                        batch_size= batch_size,
                                        shuffle = False,
                                        collate_fn=alphabet.get_batch_converter())

In [69]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

def setup_optimizer_and_scheduler(model, train_loader, epochs, batch_size, lr_for_esm=1e-5, lr_for_head=1e-3):
    """
    Set up optimizer and scheduler for ESM2 finetuning model.

    Args:
        model: The model containing esm2 encoder and regression head (like fc/output layers).
        train_loader: DataLoader for training set.
        epochs: Total number of epochs.
        batch_size: Training batch size.
        lr_for_esm: Learning rate for pretrained ESM2 encoder.
        lr_for_head: Learning rate for new prediction head.

    Returns:
        optimizer, scheduler
    """

    # Define scaled learning rates
    lr_esm = batch_size * lr_for_esm
    lr_head = batch_size * lr_for_head

    # Collect parameters
    esm2_params = list(model.esm2.parameters())
    head_params = []
    for name, param in model.named_parameters():
        if not name.startswith("esm2."):  # Everything not inside esm2 is considered head
            head_params.append(param)

    # Group parameters
    param_groups = [
        {'params': esm2_params, 'lr': lr_esm},
        {'params': head_params, 'lr': lr_head}
    ]

    # Initialize optimizer
    optimizer = AdamW(param_groups, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

    # Initialize scheduler
    scheduler = OneCycleLR(
        optimizer,
        max_lr=[lr_esm, lr_head],
        steps_per_epoch=len(train_loader),
        epochs=epochs,
    )

    return optimizer, scheduler


In [70]:
optimizer, scheduler = setup_optimizer_and_scheduler(model, train_dataloader, 10, 32)

In [71]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from torch.amp import autocast, GradScaler



scaler = GradScaler('cuda')


def train_one_epoch(model, loader, optimizer, criterion, device):

    model.train()
    total_loss = 0

    for inputs, labels in train_dataloader:
        inputs = inputs.to(device)
        labels = torch.tensor(labels, dtype=torch.float).to(device)

        with autocast('cuda'):
            optimizer.zero_grad()
            predictions = model(inputs)
            loss = criterion(predictions, labels.view(-1, 1))

            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()


        total_loss += loss.item()

    return total_loss / len(loader)


def evaluate(model, loader, criterion, device):
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        for inputs, labels in train_dataloader:
            inputs = inputs.to(device)
            labels = torch.tensor(labels, dtype=torch.float).to(device)

            with autocast('cuda'):
                predictions = model(inputs)
                loss = criterion(predictions, labels.view(-1, 1))

            total_loss += loss.item()

            all_preds.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.squeeze(all_preds)
    all_labels = np.squeeze(all_labels)


    metrics = {
        f"loss": total_loss / len(loader),
        f"spearman": spearmanr(all_labels, all_preds)[0],
    }

    return metrics

# def predict(model, dataframe, config,device):
#     model.eval()
#     dataset = ProteinDataset(dataframe)
#     loader = DataLoader(dataset, batch_size=8, num_workers=config['num_workers'],shuffle=False)


#     all_preds = []
#     with torch.no_grad():
#         for batch in loader:
#             batch = [b.to(device) if isinstance(b, torch.Tensor) else b for b in batch]
#             preds = model(*batch[:-1])  # Exclude labels
#             all_preds.extend(preds.cpu().numpy())

#     predicts = np.squeeze(all_preds)

#     return predicts


In [72]:
criterion = torch.nn.HuberLoss()
early_stop_patience = 5

In [73]:
best_val_metric = -float('inf')
epochs_since_improvement = 0

for epoch in range(10):
        train_loss = train_one_epoch(model, train_dataloader, optimizer, criterion, device)
        val_metric = evaluate(model,val_dataloader, criterion, device)
        print(f"[Epoch {epoch}] Train Loss: {train_loss:.4f} | Val loss:{val_metric['loss']:.4f} | Val Spearman: {val_metric['spearman']:.4f}")

        if val_metric['spearman'] > best_val_metric:
            best_val_metric = val_metric['spearman']
            # torch.save(model.state_dict(), os.path.join(save_dir, f"best_model_fold{fold + 1}.pt"))
            # print(f"New best Spearman: {best_val_metric:.4f} — Model saved.")
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1

        if epochs_since_improvement >= early_stop_patience:
            print(f"Early stopping triggered after {epoch + 1} epochs.")
            break



print("Best_Spearman: {best_val_metric:.4f}")

[Epoch 0] Train Loss: 0.5028 | Val loss:1.6197 | Val Spearman: 0.6351
[Epoch 1] Train Loss: 0.5002 | Val loss:1.5547 | Val Spearman: 0.6703
[Epoch 2] Train Loss: 0.4819 | Val loss:1.5486 | Val Spearman: 0.6828
[Epoch 3] Train Loss: 0.4961 | Val loss:1.4674 | Val Spearman: 0.7293
[Epoch 4] Train Loss: 0.4575 | Val loss:1.3269 | Val Spearman: 0.7929
[Epoch 5] Train Loss: 0.4037 | Val loss:1.0055 | Val Spearman: 0.8298
[Epoch 6] Train Loss: 0.3449 | Val loss:1.0599 | Val Spearman: 0.8530
[Epoch 7] Train Loss: 0.3048 | Val loss:0.8797 | Val Spearman: 0.8641
[Epoch 8] Train Loss: 0.2838 | Val loss:0.8158 | Val Spearman: 0.8758
[Epoch 9] Train Loss: 0.2686 | Val loss:0.8039 | Val Spearman: 0.8768
Best_Spearman: {best_val_metric:.4f}
