# Finetune the BERT model

### load data

In [None]:
import pandas as pd
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

df = pd.read_csv("/data/human_virus_600k_seq_label_20aa.csv")
seq_20aa = df['sequence'].to_list()
label_seq = df['label'].to_list()
label_20aa = [1 if v == 'human' else 0 for v in label_seq]

### tokenizer

In [7]:
# tokenizer
from transformers import BertTokenizer

#  Initial Tokenizer
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY") + ["X"]  
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
VOCAB = SPECIAL_TOKENS + AMINO_ACIDS  

print(VOCAB)  
with open("vocab.txt", "w") as f:
    for token in VOCAB:
        f.write(token + "\n")

with open("vocab.txt", "r") as f:
    file_vocab = [line.strip() for line in f.readlines()]

tokenizer = BertTokenizer(
    vocab_file="vocab.txt",
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
    do_lower_case=False,  
    tokenize_chinese_chars=False  
)

['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 'X']


### dataset

In [8]:
import torch
from torch.utils.data import Dataset

class ProteinDataset(Dataset):
    def __init__(self, seqs, labels, tokenizer, max_length=1024):
        self.sequences = seqs
        self.labels = labels  # Ê∑ªÂä†Ê†áÁ≠æ
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = ' '.join(self.sequences[idx])  # Ê†ºÂºèÂåñÊ∞®Âü∫ÈÖ∏Â∫èÂàó
        encoding = self.tokenizer(sequence, 
                                  truncation=True, 
                                  max_length=self.max_length, 
                                  padding="max_length",
                                  return_tensors="pt")
        
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long)
        }


### load bert model

In [9]:

sequences = seq_20aa[:300]+seq_20aa[-300:]
labels = label_20aa[:300]+label_20aa[-300:]

In [None]:
from transformers import BertForSequenceClassification
from transformers import Trainer, default_data_collator
from sklearn.metrics import roc_auc_score
from transformers import TrainingArguments
from torch.utils.data import random_split

training_args = TrainingArguments(
    output_dir="./protein_finetune/",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    save_strategy="epoch",
    num_train_epochs=10,  # ÂàÜÁ±ª‰ªªÂä°‰∏ÄËà¨Â∞ëËÆ≠Âá†ËΩÆ
    logging_dir="./logs",
    logging_strategy="epoch",
    fp16=True,
    save_total_limit=2
)

import numpy as np
import torch


from transformers import TrainerCallback
class SaveBestModelCallback(TrainerCallback):
    def __init__(self):
        self.best_loss = float("inf")
        self.best_val_probs = None  
        self.trainer = None
        self.val_labels = None
        self.auc_roc_record = []
        self.val_loss_record = []
        self.train_loss_record = []

    def on_evaluate(self, args, state, control, **kwargs):
        metrics = kwargs.get('metrics')
        eval_loss = metrics.get('eval_loss')
        if self.trainer is None :
            print("Trainer or metrics are None, skipping evaluation.")
            return
        else:
            trainer = self.trainer

        predictions = trainer.predict(trainer.eval_dataset)
        val_preds = predictions.predictions
        val_probs = torch.sigmoid(torch.tensor(val_preds)).numpy()[:,1]
        
        auc_roc = roc_auc_score(self.val_labels, val_probs)
        print(f"########## AUC ROC: {auc_roc}")
        self.auc_roc_record.append(auc_roc)
        self.val_loss_record.append(eval_loss)

        # updata best_val_probs
        if eval_loss < self.best_loss:
            print(f"New best model found with loss: {eval_loss}, save at{args.output_dir}")
            self.best_loss = eval_loss
            self.best_val_probs = val_probs
            trainer.save_model(args.output_dir)

    def on_log(self, args, state, control, **kwargs):
        logs = kwargs.get("logs", {})
        if "loss" in logs:
            self.last_train_loss = logs["loss"]
            self.last_train_loss = logs["loss"]
            self.train_loss_record.append(self.last_train_loss)


valid_probs_all = []
test_probs_all = []


model = BertForSequenceClassification.from_pretrained("./step1_pretrain_bert_with_layer4", num_labels=2)  # load the pretrained bert model, here you can make a bert model from step1_Pretrain_BERT_model.ipynb or just download our model


from sklearn.model_selection import train_test_split  
train_pep, X_temp, train_labels, y_temp = train_test_split(sequences, labels, test_size=0.3, random_state=42)  
valid_pep, test_pep, valid_labels, test_labels = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)  


train_dataset = ProteinDataset(train_pep, train_labels, tokenizer)
val_dataset = ProteinDataset(valid_pep, valid_labels, tokenizer)
test_dataset = ProteinDataset(test_pep, test_labels, tokenizer)

callback = SaveBestModelCallback()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=default_data_collator,
    callbacks=[callback]  )

callback.trainer = trainer
callback.val_labels = valid_labels

trainer.train()

eval_results = trainer.evaluate()
print(f"Eval Loss: {eval_results['eval_loss']}")

val_labels = np.array(valid_labels)
val_preds = trainer.predict(val_dataset).predictions
val_preds = torch.sigmoid(torch.tensor(val_preds)).numpy()[:,1] 

roc_auc = roc_auc_score(val_labels, val_preds)
print(f"ROC AUC: {roc_auc}")
torch.save(callback.val_loss_record, f"./protein_finetune/continue_val_loss_record.pt")
torch.save(callback.auc_roc_record, f"./protein_finetune/continue_auc_roc_record.pt")
torch.save(callback.train_loss_record, f"./protein_finetune/continue_train_loss_record.pt")



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./protein_bert/best_model_4 and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss
1,0.59,0.585202
2,0.4343,0.586911
3,0.4257,0.58938
4,0.3295,0.455044
5,0.2866,0.398539
6,0.2369,0.395273
7,0.228,0.550527
8,0.1761,0.413641
9,0.1497,0.43384
10,0.1458,0.454779


########## AUC ROC: 0.7680000000000001
New best model found with loss: 0.585202157497406, save at./protein_finetune/




########## AUC ROC: 0.863




########## AUC ROC: 0.8750000000000001




########## AUC ROC: 0.88
New best model found with loss: 0.4550439715385437, save at./protein_finetune/




########## AUC ROC: 0.892
New best model found with loss: 0.398539274930954, save at./protein_finetune/




########## AUC ROC: 0.9025000000000001
New best model found with loss: 0.3952731192111969, save at./protein_finetune/




########## AUC ROC: 0.9095




########## AUC ROC: 0.91




########## AUC ROC: 0.912




########## AUC ROC: 0.9095




########## AUC ROC: 0.9095
Eval Loss: 0.45477885007858276




ROC AUC: 0.9095
