# Online analytic for host pathogen protein interactions analysis
### Part 4: Finetuning ESM2b model
- We will be finetuning the `facebook/esm2_t33_650M_UR50D` protein language model. 
- The esm2 model is a newer version on the older `facebook/esm1b_t33_650M_UR50S`, which is based on the RoBERTa architecture.

### Step 0: Installations & Setup
Here we install the various libraries needed for our finetuning aspect of the project
- We also specify the checkpoint we are using for our model
- We will be utilizing the `facebook/esm2_t33_650M_UR50D` checkpoint as it is the most similar to esm1

In [1]:
# # Code to install biopython package; uncomment lines below to install the package if not already installed
# import sys
# !conda install --yes --prefix {sys.prefix} datasets transformers=4.28.0 torch scipy scikit-learn evaluate
# !conda install --yes --prefix {sys.prefix} datasets transformers torch scipy scikit-learn evaluate pytorch-cuda
# !conda install --yes --prefix {sys.prefix} pytorch-cuda=11.8 -c pytorch -c nvidia
# !pip install datasets transformers torch scikit-learn evaluate
!pip install evaluate


Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: evaluate
Successfully installed evaluate-0.4.1


In [2]:
import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandbkey = user_secrets.get_secret("wandb-key")
wandb.login(key=wandbkey)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
import pandas as pd # pandas package
import numpy as np # numpy package
import torch
import datasets
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, TrainingArguments, Trainer
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score



In [4]:
# model_checkpoint = "facebook/esm2_t33_650M_UR50D"
model_checkpoint = "facebook/esm2_t12_35M_UR50D"


In [5]:
torch.cuda.is_available()

True

In [6]:
## Settings
# Set to False to run crdh instead
using_full_seq = False
# Set to false to use balanced dataset instead
imbalanced_dataset = False

### Step 1: Import train & test data
We will be importing the dataset that was previously generated by our data cleaning and preparation portion of this project

In [7]:
if imbalanced_dataset:
    train_df = pd.read_csv("/kaggle/input/1024-featurized-dataset/1024_dataset_multiclass_featurized(train).csv")
    test_df = pd.read_csv("/kaggle/input/1024-featurized-dataset/1024_dataset_multiclass_featurized(test).csv")
else:
    train_df = pd.read_csv("/kaggle/input/1024-featurized-dataset/1024_dataset_multiclass_featurized_balanced(train).csv")
    test_df = pd.read_csv("/kaggle/input/1024-featurized-dataset/1024_dataset_multiclass_featurized_balanced(test).csv")

print(train_df.shape)
print(test_df.shape)

(31099, 4)
(10367, 4)


In [8]:
train_df.dtypes

comb_seq_full     object
comb_seq_cdr3     object
virus_sequence    object
label              int64
dtype: object

In [9]:
# def standardize_df(input_df):
#     df = input_df.copy()
#     df["label"] = df["label"].astype(int)
#     df = df[["heavy_chain","light_chain","cdrh3","cdrl3","virus_sequence","label"]]
#     if using_full_seq:
#         df["comb_seq"] = df["heavy_chain"]+""+df["light_chain"]
#     else:
#         df["comb_seq"] = df["cdrh3"]+""+df["cdrl3"]
#     return df

In [10]:
def standardize_df(input_df):
    df = input_df.copy()
    df["label"] = df["label"].astype(int)
    if using_full_seq:
        df["comb_seq"] = df["comb_seq_full"]
    else:
        df["comb_seq"] = df["comb_seq_cdr3"]
    return df

In [11]:
train_df = standardize_df(train_df)
test_df = standardize_df(test_df)

In [12]:
train_df.head()

Unnamed: 0,comb_seq_full,comb_seq_cdr3,virus_sequence,label,comb_seq
0,QVQLQESGPGLVKPSQTLSLTCTVSGGSISSGSYYWSWIRQPVGKG...,ARESSPASIPVRGVIWWFDPAAWDDSLNGSVV,PTESIVRFPNITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSV...,2,ARESSPASIPVRGVIWWFDPAAWDDSLNGSVV
1,EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYDMHWVRQTPGEGLE...,ARAGYDILTAYLDLQQSYIMPPWT,PNITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFF...,1,ARAGYDILTAYLDLQQSYIMPPWT
2,QVQLVESGGGVVQPGRSLRLSCAASGFTFSSYGMHWVRQAPGKGLE...,AKDMIRGETHFNYYMDVCSYAGSFVV,HHHHHHTNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFA...,0,AKDMIRGETHFNYYMDVCSYAGSFVV
3,EVQLVESGGGLVQPGGSLRLSCAASGLTVSSNYMRWVRQAPGKGLE...,ARDLYVFGMDVQQLNSDSST,MGILPSPGMPALLSLVSLLSVLLMGCVAETGTRFPNITNLCPFGEV...,2,ARDLYVFGMDVQQLNSDSST
4,QITLKESGPTLVKPTQTLTLTCTFYGFSLSTSGVGVGWIRQPPGKA...,AHTMLFEYGDFDYSSYKSSTTSRV,TNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFFTFK...,1,AHTMLFEYGDFDYSSYKSSTTSRV


In [13]:
test_df.head()

Unnamed: 0,comb_seq_full,comb_seq_cdr3,virus_sequence,label,comb_seq
0,EVQLVESGGGVVQPGGSRRLSCVASGFTFTSYDIHWVRQGTGKSLE...,VRAYPFYDMLTGDTYHYYGLDVQQYGRSPPLT,HHHHHHTNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFA...,0,VRAYPFYDMLTGDTYHYYGLDVQQYGRSPPLT
1,QVQLVQSGAEVKKPGSSVKVSCQASGGTFSSYAISWVRQAPGQGLE...,AQRSEMASVQAWDSSTEV,HHHHHHTNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLA...,2,AQRSEMASVQAWDSSTEV
2,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSNYALSWVRQAPGQGLE...,ARLDGYSFGHDRYYQDGMDDLQQNIYPRT,PKITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFF...,2,ARLDGYSFGHDRYYQDGMDDLQQNIYPRT
3,QITLKESGPTLVKPTQTLTLTCTFSGFSLSTSGVGVGWIRQPPGKA...,AHKYQLAAFDYCQQYDNLWT,HHHHHHTNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFA...,2,AHKYQLAAFDYCQQYDNLWT
4,QGQLVQSGSELQKPGASVRVSCKASGFTLTSYAINWVRQAPGQGLE...,ARVGRYSISWLDDAFDIQQYYSTPLT,PSKPSKRSFIEDLLFNKVTLADAGF,2,ARVGRYSISWLDDAFDIQQYYSTPLT


### Step 2: Tokenize inputs

In [14]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [15]:
def preprocess_function(df):
    return tokenizer(df["comb_seq"], df["virus_sequence"])

In [16]:
# test_dataset = datasets.Dataset.from_pandas(test_df[["comb_seq","virus_sequence","label"]])
# test_dataset = test_dataset.map(preprocess_function, batched=True)


In [17]:
train_dataset = datasets.Dataset.from_pandas(train_df[["comb_seq","virus_sequence","label"]])
test_dataset = datasets.Dataset.from_pandas(test_df[["comb_seq","virus_sequence","label"]])

train_dataset = train_dataset.map(preprocess_function, batched=True)
train_dataset = train_dataset.remove_columns(["comb_seq","virus_sequence"])
train_dataset = train_dataset.rename_column("label","labels")
train_dataset.set_format(type="torch")

test_dataset = test_dataset.map(preprocess_function, batched=True)
test_dataset = test_dataset.remove_columns(["comb_seq","virus_sequence"])
test_dataset = test_dataset.rename_column("label","labels")
test_dataset.set_format(type="torch")

  0%|          | 0/32 [00:00<?, ?ba/s]

  0%|          | 0/11 [00:00<?, ?ba/s]

In [18]:
train_dataset

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 31099
})

In [19]:
test_dataset

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 10367
})

### Step 3: Perform training

In [20]:
# if multiclass_prediction:
#     num_labels = 3
#     id2Label = {0: "NOT_NEUTRALISING", 1: "WEAK_NEUTRALISATION", 2: "NEUTRALISING"}
#     label2Id = {"NOT_NEUTRALISING":0,"WEAK_NEUTRALISATION":1,"NEUTRALISING":2}
# else:
#     num_labels = 2
#     id2Label = {0: "NOT_NEUTRALISING", 1: "NEUTRALISING"}
#     label2Id = {"NOT_NEUTRALISING":0,"NEUTRALISING":2}
num_labels = 3
id2Label = {0: "NOT_NEUTRALISING", 1: "WEAK_NEUTRALISATION", 2: "NEUTRALISING"}
label2Id = {"NOT_NEUTRALISING":0,"WEAK_NEUTRALISATION":1,"NEUTRALISING":2}
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels, id2label=id2Label, label2id=label2Id)

Downloading (…)lve/main/config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/136M [00:00<?, ?B/s]

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
model_name = "neu-pred"
args = TrainingArguments(
    output_dir="./neu-pred",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir="./logs",
)

In [22]:
from evaluate import load
import numpy as np

metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    new_df = pd.DataFrame()
    new_df['pred_label'] = predictions
    new_df['true_label'] = labels
    new_df.to_csv('result_prediction.csv', header=True)
    print("=============================")
    print("Predictions:")
    print(predictions)
    print("Labels:")
    print(labels)
    return metric.compute(predictions=predictions, references=labels)

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [23]:
train_dataset

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 31099
})

In [24]:
test_dataset

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 10367
})

In [25]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [26]:
trainer.train()
results = trainer.evaluate()
print(results)

[34m[1mwandb[0m: Currently logged in as: [33mqixyqix[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.15.12 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.15.9
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20231013_170021-8fg5v9k4[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mmisunderstood-hill-45[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/qixyqix/huggingface[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/qixyqix/huggingface/runs/8fg5v9k4[0m


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8571,0.701218,0.697116
2,0.4844,0.464676,0.818559
3,0.313,0.370864,0.861966
4,0.206,0.330238,0.897077
5,0.1488,0.34801,0.907206
6,0.1124,0.348505,0.916562
7,0.0838,0.36754,0.919745
8,0.0571,0.374796,0.926401
9,0.047,0.370868,0.930163
10,0.0241,0.379278,0.929391


Predictions:
[0 2 2 ... 2 0 1]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 0 1]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 1 0]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 1 0]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 1 0]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 1 0]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 1 0]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 1 0]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 1 0]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 1 0]
Labels:
[0 2 2 ... 2 1 0]




Predictions:
[0 2 2 ... 2 1 0]
Labels:
[0 2 2 ... 2 1 0]
{'eval_loss': 0.37086814641952515, 'eval_accuracy': 0.9301630172663259, 'eval_runtime': 70.3804, 'eval_samples_per_second': 147.3, 'eval_steps_per_second': 9.207, 'epoch': 10.0}
