# LoRA the ESM model

In [None]:
import pandas as pd
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
df = pd.read_csv("/data/human_virus_600k_seq_label_20aa_flow.csv")
seq_20aa = df['sequence'].to_list()
label_seq = df['label'].to_list()
label_20aa = [1 if v == 'human' else 0 for v in label_seq]


sequences = seq_20aa[:300]+seq_20aa[-300:] # downsample
labels = label_20aa[:300]+label_20aa[-300:]

In [None]:
from transformers import AutoTokenizer
import numpy as np  
from sklearn.model_selection import train_test_split  
from datasets import Dataset

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")  # ESM model path, you can down load from https://huggingface.co/facebook/esm2_t33_650M_UR50D

train_pep, X_temp, train_labels, y_temp = train_test_split(sequences, labels, test_size=0.3, random_state=42)  
valid_pep, test_pep, valid_labels, test_labels = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)  

def tokenize_function(examples):
    return tokenizer(examples["sequence"], truncation=True, padding="max_length", max_length=128)


train_dataset = Dataset.from_dict({"sequence": train_pep, "label": train_labels})
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["sequence"])
val_dataset = Dataset.from_dict({"sequence": valid_pep, "label": valid_labels})
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["sequence"])
test_dataset = Dataset.from_dict({"sequence": test_pep, "label": test_labels})
test_dataset = test_dataset.map(tokenize_function, batched=True, remove_columns=["sequence"])

  from .autonotebook import tqdm as notebook_tqdm
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 420/420 [00:00<00:00, 7679.00 examples/s]
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 90/90 [00:00<00:00, 7896.23 examples/s]
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 90/90 [00:00<00:00, 810.16 examples/s]


In [None]:
import torch
from transformers import AutoTokenizer, EsmForSequenceClassification, TrainingArguments, Trainer, TrainerCallback, EsmModel
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
import numpy as np
from sklearn.model_selection import train_test_split
import torch.nn as nn


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ESMClassificationModel(nn.Module):
    def __init__(self, model_name, num_labels=2):
        super(ESMClassificationModel, self).__init__()
        self.esm = EsmModel.from_pretrained(model_name)
        hidden_size = self.esm.config.hidden_size  
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.BatchNorm1d(hidden_size // 2),
            nn.ReLU(),
            # nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.BatchNorm1d(hidden_size // 4),
            nn.ReLU(),
            # nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_labels),
        )
        self.config = self.esm.config
        self.loss_fn = nn.CrossEntropyLoss()  

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, **kwargs):
        if inputs_embeds is not None:
            outputs = self.esm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        else:
            outputs = self.esm(input_ids=input_ids, attention_mask=attention_mask)
        
        pooled_output = torch.mean(outputs.last_hidden_state, dim=1) 
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)

        return {"loss": loss, "logits": logits}

num_labels = 2
ESMmodel = ESMClassificationModel("./post_train_esm/checkpoint-14980", num_labels).to(device) # load lora-post train model, you can make a lora model from step2_LoRA_Post_train_ESM_model.ipynb or just download our model


lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,  
    r=16,  
    lora_alpha=16,  
    lora_dropout=0.1,  
    target_modules=["query", "value", 'key'],  
)

ESMmodel = get_peft_model(ESMmodel, lora_config)
for param in ESMmodel.classifier.parameters():
    param.requires_grad = True

ESMmodel.print_trainable_parameters()


Some weights of EsmModel were not initialized from the model checkpoint at /home/zhangxin/TCR/esm/ and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading adapter weights from ./post_train_esm/checkpoint-14980 led to unexpected keys not found in the model: esm.encoder.layer.0.attention.self.key.lora_A.default.weight, esm.encoder.layer.0.attention.self.key.lora_B.default.weight, esm.encoder.layer.0.attention.self.query.lora_A.default.weight, esm.encoder.layer.0.attention.self.query.lora_B.default.weight, esm.encoder.layer.0.attention.self.value.lora_A.default.weight, esm.encoder.layer.0.attention.self.value.lora_B.default.weight, esm.encoder.layer.1.attention.self.key.lora_A.default.weight, esm.encoder.layer.1.attention.self.key.lora_B.default.weight, esm.encoder.layer.1.attention.self.query.lora_A.default.weight, esm.encoder.layer.1.attention.self.query

trainable params: 6,110,084 || all params: 658,464,025 || trainable%: 0.9279




In [None]:
from sklearn.metrics import roc_auc_score

training_args = TrainingArguments(
    output_dir="./lora_esm_classification",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=64,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=30,
    learning_rate=2e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    push_to_hub=False,
    fp16=True if torch.cuda.is_available() else False,
)


class SaveBestModelCallback(TrainerCallback):
    def __init__(self):
        self.best_loss = float("inf")
        self.best_val_probs = None 
        self.trainer = None
        self.val_labels = None
        self.auc_roc_record = []
        self.val_loss_record = []
        self.train_loss_record = []

    def on_evaluate(self, args, state, control, **kwargs):
        metrics = kwargs.get('metrics')
        eval_loss = metrics.get('eval_loss')
        
        if self.trainer is None :
            print("Trainer or metrics are None, skipping evaluation.")
            return
        
        else:
            trainer = self.trainer


        predictions = trainer.predict(trainer.eval_dataset)
        val_preds = predictions.predictions
        val_probs = torch.sigmoid(torch.tensor(val_preds)).numpy()[:,1]
        

        auc_roc = roc_auc_score(self.val_labels, val_probs)
        print(f"########## AUC ROC: {auc_roc}")
        self.auc_roc_record.append(auc_roc)
        self.val_loss_record.append(eval_loss)


        if eval_loss < self.best_loss:
            print(f"New best model found with loss: {eval_loss}, save at{args.output_dir}")
            self.best_loss = eval_loss
            self.best_val_probs = val_probs
            trainer.save_model(args.output_dir)

    def on_log(self, args, state, control, **kwargs):
        logs = kwargs.get("logs", {})
        if "loss" in logs:
            self.last_train_loss = logs["loss"]
            self.last_train_loss = logs["loss"]
            # print(f"Train Loss: {self.last_train_loss}")
            self.train_loss_record.append(self.last_train_loss)
            
callback = SaveBestModelCallback()

trainer = Trainer(
    model=ESMmodel,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=[callback]  
)

callback.trainer = trainer
callback.val_labels = valid_labels

trainer.train()


ESMmodel.save_pretrained("./saved_lora_esm_cls")
tokenizer.save_pretrained("./saved_lora_esm_cls")



No label_names provided for model class `PeftModelForSequenceClassification`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Epoch,Training Loss,Validation Loss
1,No log,0.677935
2,0.566700,0.615915
3,0.566700,0.512141


########## AUC ROC: 0.7635000000000001
New best model found with loss: 0.6779350638389587, save at./lora_esm_classification
########## AUC ROC: 0.8734999999999999
New best model found with loss: 0.6159153580665588, save at./lora_esm_classification
########## AUC ROC: 0.9129999999999999
New best model found with loss: 0.5121405720710754, save at./lora_esm_classification


('./saved_lora_esm_cls/tokenizer_config.json',
 './saved_lora_esm_cls/special_tokens_map.json',
 './saved_lora_esm_cls/vocab.txt',
 './saved_lora_esm_cls/added_tokens.json')