### Finetuning ESM2

requirements: transformers, datasets, peft, torch

In [1]:
# # log into huggingface for pushing the model to model hub
# from huggingface_hub import notebook_login
# notebook_login()

In [2]:
import torch
from datasets import load_dataset

# add paths to datasets on your machine
dataset = load_dataset("csv", data_files={"train": "tox_train.tsv", 
                                          "validation":"tox_validation.tsv",
                                          "test":"tox_test.tsv"}, delimiter="\t")
dataset

DatasetDict({
    train: Dataset({
        features: ['protein_id', 'seq_len', 'sequence', 'label'],
        num_rows: 10368
    })
    validation: Dataset({
        features: ['protein_id', 'seq_len', 'sequence', 'label'],
        num_rows: 1296
    })
    test: Dataset({
        features: ['protein_id', 'seq_len', 'sequence', 'label'],
        num_rows: 1297
    })
})

### Tokenization

In [3]:
# dataset['train'][0]
dataset['train'].features
# dataset.save_to_disk("datasetdict")

{'protein_id': Value(dtype='string', id=None),
 'seq_len': Value(dtype='int64', id=None),
 'sequence': Value(dtype='string', id=None),
 'label': Value(dtype='int64', id=None)}

In [4]:
# Use the same tokenizer that was used for pre-training
from transformers import AutoTokenizer

# checkpoint = "facebook/esm2_t12_35M_UR50D"
checkpoint = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
def tokenize_function(examples):
    return tokenizer(examples['sequence']) # leave out padding=True, pad dynamically when batches are created

tokenized_dataset = dataset.map(tokenize_function, batched=True) # batched -> speedup tokenization
# model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2, problem_type='single_label_classification')

In [5]:
tokenizer.decode(tokenized_dataset['train'][0]['input_ids'])

'<cls> R F R F R V K C S K G T Y <eos>'

In [6]:
# tokenized_dataset['train'][0]
# Columns to keep
keep_columns = ['input_ids', 'attention_mask', 'label']

for split in tokenized_dataset:
    tokenized_dataset[split] = tokenized_dataset[split] \
        .remove_columns([col for col in tokenized_dataset[split].column_names if col not in keep_columns])

for split in tokenized_dataset:
    # Rename the 'label' column to 'labels'
    tokenized_dataset[split] = tokenized_dataset[split].rename_column('label', 'labels')
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 10368
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 1296
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 1297
    })
})

In [7]:
# Dynamic padding
# from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)

# for step, batch in enumerate(train_dataloader):
#     print(batch['input_ids'].shape)
#     if step > 5:
#         break

In [8]:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(checkpoint)
print(config)
type(config)

EsmConfig {
  "_name_or_path": "facebook/esm2_t33_650M_UR50D",
  "architectures": [
    "EsmForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "emb_layer_norm_before": false,
  "esmfold_config": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1280,
  "initializer_range": 0.02,
  "intermediate_size": 5120,
  "is_folding_model": false,
  "layer_norm_eps": 1e-05,
  "mask_token_id": 32,
  "max_position_embeddings": 1026,
  "model_type": "esm",
  "num_attention_heads": 20,
  "num_hidden_layers": 33,
  "pad_token_id": 1,
  "position_embedding_type": "rotary",
  "token_dropout": true,
  "torch_dtype": "float32",
  "transformers_version": "4.45.2",
  "use_cache": true,
  "vocab_list": null,
  "vocab_size": 33
}



transformers.models.esm.configuration_esm.EsmConfig

#### Model loading

In [9]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2, problem_type="single_label_classification")
print(type(model))
print(isinstance(model, torch.nn.Module))

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<class 'transformers.models.esm.modeling_esm.EsmForSequenceClassification'>
True


These warnings are telling us that the model is discarding some weights that it used for language modelling (the lm_head) and adding some weights for sequence classification (the classifier). This is exactly what we expect when we want to fine-tune a language model on a sequence classification task!

## PEFT

In [10]:
from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, 
    inference_mode=False,
    r=16, 
    lora_alpha=32, 
    lora_dropout=0.05,
    # target_modules=["query",
    #                 "key",
    #                 "value"]
    target_modules="all-linear"
    )
[e.value for e in TaskType]

['SEQ_CLS',
 'SEQ_2_SEQ_LM',
 'CAUSAL_LM',
 'TOKEN_CLS',
 'QUESTION_ANS',
 'FEATURE_EXTRACTION']

In [11]:
from peft import get_peft_model
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 13,879,410 || all params: 666,297,385 || trainable%: 2.0831


In [10]:
# Metric for evaluation
import numpy as np
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    matthews_corrcoef,
)

# metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    # Calculate confusion matrix values: TN, FP, FN, TP
    tn, fp, fn, tp = confusion_matrix(labels, predictions).ravel()
    # Calculate other metrics
    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions)
    recall = recall_score(labels, predictions)
    f1 = f1_score(labels, predictions)
    try:
        auc = roc_auc_score(labels, predictions)
    except ValueError:
        auc = "N/A"  # Handle the case when there are no positive labels

    mcc = matthews_corrcoef(labels, predictions)
    # Create a pandas dataframe to store the results
    metrics = {
        "TP": tp,
        "TN": tn,
        "FP": fp,
        "FN": fn,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1-score": f1,
        "AUC": auc,
        "MCC": mcc,
    }
    return metrics

In [11]:
# Trainer arguments
from transformers import TrainingArguments

batch_size = 16
model_name = checkpoint.split("/")[-1]

args = TrainingArguments(
    f"{model_name}-finetuned_toxi",
    # evaluation_strategy = "epoch",
    evaluation_strategy="steps",
    eval_steps=400,
    save_strategy ="steps",
    save_steps=400,
    # save_strategy="epoch",
    learning_rate=2e-5,
    # learning_rate=2e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=4,
    weight_decay=0.01,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="MCC",
    push_to_hub=False,
    report_to="wandb"
)



In [12]:
# example of callback for tracking memory
import torch
from transformers import TrainerCallback

class CUDAMemoryLogger(TrainerCallback):
    def __init__(self, log_every_n_steps=100):
        self.log_every_n_steps = log_every_n_steps
    
    def on_step_end(self, args, state, control, model, tokenizer, logs=None, **kwargs):
        # Log CUDA memory usage every 'log_every_n_steps' steps
        if state.global_step % self.log_every_n_steps == 0:
            allocated_memory = torch.cuda.memory_allocated() / 1024**2  # in MB
            reserved_memory = torch.cuda.memory_reserved() / 1024**2  # in MB
            print(f"Step {state.global_step}:")
            print(f"  Allocated memory: {allocated_memory:.2f} MB")
            print(f"  Reserved memory: {reserved_memory:.2f} MB")
            
        return control

In [15]:
# import torch
# torch.cuda.memory_snapshot()

In [13]:
from transformers import Trainer

trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator, # DataCollatorWithPadding is the default in Trainer, but we specified it
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[CUDAMemoryLogger(log_every_n_steps=400)]
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [17]:
# seq='MIDYVGSFLGAYFLGFALFYGIGFFKSISNRIIIGI'
# input = tokenizer(seq, return_tensors='pt')
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# input['labels'] = torch.tensor(0)
# out = model(**input.to(device))
# out

In [18]:
# LoRA
best_model = trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mnikolamilicevic[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Tp,Tn,Fp,Fn,Accuracy,Precision,Recall,F1-score,Auc,Mcc
400,No log,0.382058,506,578,86,126,0.83642,0.85473,0.800633,0.826797,0.835557,0.67343
800,0.465000,0.358143,531,572,92,101,0.85108,0.852327,0.84019,0.846215,0.850818,0.701944
1200,0.371200,0.350189,512,591,73,120,0.85108,0.875214,0.810127,0.841413,0.850093,0.703305
1600,0.349700,0.343169,522,585,79,110,0.854167,0.868552,0.825949,0.846715,0.853487,0.708624
2000,0.320800,0.334348,524,582,82,108,0.853395,0.864686,0.829114,0.846527,0.85281,0.706891
2400,0.320800,0.334452,514,596,68,118,0.856481,0.883162,0.813291,0.846787,0.855441,0.71438


Step 400:
  Allocated memory: 2767.32 MB
  Reserved memory: 16458.00 MB
Step 800:
  Allocated memory: 2767.06 MB
  Reserved memory: 16458.00 MB
Step 1200:
  Allocated memory: 2767.06 MB
  Reserved memory: 16458.00 MB
Step 1600:
  Allocated memory: 2766.99 MB
  Reserved memory: 16458.00 MB
Step 2000:
  Allocated memory: 2767.09 MB
  Reserved memory: 16458.00 MB
Step 2400:
  Allocated memory: 2767.06 MB
  Reserved memory: 16458.00 MB


In [14]:
# all-params
best_model = trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mnikolamilicevic[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Tp,Tn,Fp,Fn,Accuracy,Precision,Recall,F1-score,Auc,Mcc
400,No log,0.337577,541,563,101,91,0.851852,0.842679,0.856013,0.849294,0.851952,0.70372
800,0.405600,0.309057,547,588,76,85,0.875772,0.87801,0.865506,0.871713,0.875524,0.751379
1200,0.283300,0.287741,558,595,69,74,0.88966,0.889952,0.882911,0.886418,0.889498,0.779167
1600,0.212100,0.333742,546,613,51,86,0.89429,0.914573,0.863924,0.888527,0.893558,0.789325
2000,0.129100,0.360917,551,606,58,81,0.892747,0.904762,0.871835,0.887994,0.892243,0.785671
2400,0.129100,0.431226,550,607,57,82,0.892747,0.906096,0.870253,0.887813,0.892205,0.785745


Step 400:
  Allocated memory: 10228.17 MB
  Reserved memory: 19330.00 MB
Step 800:
  Allocated memory: 10227.91 MB
  Reserved memory: 19362.00 MB
Step 1200:
  Allocated memory: 10227.91 MB
  Reserved memory: 19364.00 MB
Step 1600:
  Allocated memory: 10227.84 MB
  Reserved memory: 19364.00 MB
Step 2000:
  Allocated memory: 10227.94 MB
  Reserved memory: 19364.00 MB
Step 2400:
  Allocated memory: 10227.91 MB
  Reserved memory: 19364.00 MB


In [16]:
# best checkpoint after finetuning finished
print(f"Best checkpoint: {trainer.state.best_model_checkpoint}")

Best checkpoint: esm2_t33_650M_UR50D-finetuned_toxi/checkpoint-1600


In [17]:
trainer.evaluate(tokenized_dataset["test"])

{'eval_loss': 0.4631272256374359,
 'eval_TP': 560,
 'eval_TN': 583,
 'eval_FP': 82,
 'eval_FN': 72,
 'eval_accuracy': 0.8812644564379337,
 'eval_precision': 0.8722741433021807,
 'eval_recall': 0.8860759493670886,
 'eval_f1-score': 0.8791208791208791,
 'eval_AUC': 0.8813838393451984,
 'eval_MCC': 0.7625590504037529,
 'eval_runtime': 10.2666,
 'eval_samples_per_second': 126.332,
 'eval_steps_per_second': 7.987,
 'epoch': 4.0}

In [20]:
best_model.metrics

{'train_runtime': 1056.3174,
 'train_samples_per_second': 29.446,
 'train_steps_per_second': 1.84,
 'total_flos': 4479845036486976.0,
 'train_loss': 0.25209634294235167,
 'epoch': 3.0}

In [25]:
# checkpoint_path = "./esm2_t33_650M_UR50D-finetuned_v2/checkpoint-1296/"
# other_model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path, num_labels=2)


In [32]:
# Evaluate epoch 1 
# checkpoint_path_e1 = "./esm2_t33_650M_UR50D-finetuned_v2/checkpoint-648/"
# model_e1 = AutoModelForSequenceClassification.from_pretrained(checkpoint_path_e1, num_labels=2)
# trainer_e1 = Trainer(
#     model=model_e1,
#     args=args,
#     train_dataset=train_dataset,
#     eval_dataset=val_dataset,
#     tokenizer=tokenizer,
#     compute_metrics=compute_metrics,
# )
# trainer_e1.evaluate(test_dataset)


Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


{'eval_loss': 0.3016924560070038,
 'eval_model_preparation_time': 0.0055,
 'eval_TP': 534,
 'eval_TN': 599,
 'eval_FP': 65,
 'eval_FN': 98,
 'eval_accuracy': 0.8742283950617284,
 'eval_precision': 0.8914858096828047,
 'eval_recall': 0.8449367088607594,
 'eval_f1-score': 0.867587327376117,
 'eval_AUC': 0.8735225712978496,
 'eval_MCC': 0.7489617263426595,
 'eval_runtime': 10.0133,
 'eval_samples_per_second': 129.427,
 'eval_steps_per_second': 8.089}

In [14]:
# repo_name = "milka1g/esm2_650M_finetuned_toxicity"
# checkpoint_path_e3 = "./esm2_t33_650M_UR50D-finetuned_v2/checkpoint-1944/"
# model_e3 = AutoModelForSequenceClassification.from_pretrained(checkpoint_path_e3, num_labels=2)
# model_e3.push_to_hub(repo_name)
# tokenizer.push_to_hub(repo_name)

model.safetensors:   0%|          | 0.00/2.61G [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/milka1g/esm2_650M_finetuned_toxicity/commit/13f2d12998cf8f67b019aed33b633c24f4ea6d0e', commit_message='Upload tokenizer', commit_description='', oid='13f2d12998cf8f67b019aed33b633c24f4ea6d0e', pr_url=None, repo_url=RepoUrl('https://huggingface.co/milka1g/esm2_650M_finetuned_toxicity', endpoint='https://huggingface.co', repo_type='model', repo_id='milka1g/esm2_650M_finetuned_toxicity'), pr_revision=None, pr_num=None)