# Low Rank Adapter (LoRA) for Post-Training ESM model

### load data

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0,2,3"
from Bio import SeqIO
 
def is_common_aa_sequence(sequence):
    common_amino_acids = set('ACDEFGHIKLMNPQRSTVWY')
    return all(aa in common_amino_acids for aa in sequence)

fasta_file = "/data/human_uniprot-reviewed_yes+taxonomy_9606.fasta"
 
seqs = []
for record in SeqIO.parse(fasta_file, "fasta"):
    seq = str(record.seq)
    if is_common_aa_sequence(seq): # make sure that the sequence contains only common amino acids
        seqs.append(str(record.seq))
    

### ESM model

In [None]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3"
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmForMaskedLM
from peft import LoraConfig, get_peft_model, TaskType
from datasets import Dataset
from torch.utils.data import random_split
from transformers import DataCollatorForLanguageModeling, TrainingArguments, Trainer, TrainerCallback


# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")  # ESM model path, you can down load from https://huggingface.co/facebook/esm2_t33_650M_UR50D
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  

# ESM model
ESMmodel = EsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device) # # ESM model path, you can down load from https://huggingface.co/facebook/esm2_t33_650M_UR50D

  from .autonotebook import tqdm as notebook_tqdm
2025-04-07 23:27:03.825383: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


### dataset

In [3]:
sequences = seqs

# build dataset
dataset = Dataset.from_dict({"sequence": sequences})

def tokenize_function(examples):
    tokenized = tokenizer(
        examples["sequence"],
        truncation=True,
        max_length=1024,
        padding=True,  
    )
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

dataset = dataset.map(tokenize_function, batched=True, remove_columns=["sequence"])

# training set and validation set
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.3
)


Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 16850/16850 [00:22<00:00, 761.46 examples/s]


### Lora-ESM

In [4]:
# LoRA config
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  
    inference_mode=False,
    r=16, # adjust here to your desired rank
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["query", "value", "key"],  
)

# get lora-esm model
model = get_peft_model(ESMmodel, peft_config)
model.print_trainable_parameters()  

trainable params: 4,055,040 || all params: 656,411,574 || trainable%: 0.6178


### training

In [None]:
training_args = TrainingArguments(
    output_dir="./post_train_esm/", # checkpoint save path
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=32,
    logging_dir="./lora_logs", 
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=100,  
    dataloader_num_workers=6,
)

class BestModelSaver(TrainerCallback):
    def __init__(self, save_path):
        self.best_val_loss = float("inf")
        self.save_path = save_path

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is None or "eval_loss" not in metrics:
            return  

        val_loss = metrics["eval_loss"]
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            print(f"new best val_loss: {val_loss:.4f}, save model...")
            trainer.save_model(self.save_path)

# trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=[BestModelSaver("./step2_lora_post_train_ESM")],  # model save path
)

trainer.train()

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Epoch,Training Loss,Validation Loss
1,0.155,0.148787
2,0.1293,0.125569
3,0.1227,0.121066
4,0.1213,0.119086
5,0.1208,0.118392
6,0.1181,0.117492
7,0.1174,0.116245
8,0.118,0.114661
9,0.1161,0.114392
10,0.1154,0.114249


new best val_loss: 0.1488, save model...




new best val_loss: 0.1256, save model...




new best val_loss: 0.1211, save model...




new best val_loss: 0.1191, save model...




new best val_loss: 0.1184, save model...




new best val_loss: 0.1175, save model...




new best val_loss: 0.1162, save model...




new best val_loss: 0.1147, save model...




new best val_loss: 0.1144, save model...




new best val_loss: 0.1142, save model...


TrainOutput(global_step=14980, training_loss=0.24851770840276863, metrics={'train_runtime': 27271.5718, 'train_samples_per_second': 4.943, 'train_steps_per_second': 0.549, 'total_flos': 5.425247605506048e+17, 'train_loss': 0.24851770840276863, 'epoch': 10.0})