**Check which GPUs are available**

In [1]:
# Show the visible NVIDIA GPUs, driver/CUDA versions and current memory use.
!nvidia-smi

Tue May 25 00:17:44 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM4-40GB      On   | 00000000:07:00.0 Off |                    0 |
| N/A   34C    P0    72W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  A100-SXM4-40GB      On   | 00000000:0F:00.0 Off |                    0 |
| N/A   33C    P0    53W / 400W |      0MiB / 40537MiB |      0%      Default |
|       

**Load the necessary libraries, including Hugging Face Transformers**

In [49]:
import torch

from transformers import AutoModelForSequenceClassification, BertTokenizerFast
from transformers import Trainer, TrainingArguments
from transformers import EvalPrediction

from datasets import load_dataset
from torch.utils.data import Dataset

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

import re
import gc
import os
import pandas as pd
import numpy as np
import requests
from tqdm.auto import tqdm

In [3]:
# Record the PyTorch version this notebook was run with.
torch.__version__

'1.8.1'

**Select Model**

In [4]:
# ProtBert-BFD is the checkpoint fine-tuned below.
# Alternative: "Rostlab/prot_t5_xl_uniref50" (embedding extraction only, not fine-tuning).
model_name = "Rostlab/prot_bert_bfd"

**Load the vocabulary**

In [5]:
# Load the tokenizer that matches `model_name`. do_lower_case=False keeps the
# residue letters exactly as given (the sequences are upper-case letters).
if "bert" in model_name:
    seq_tokenizer = BertTokenizerFast.from_pretrained(model_name, do_lower_case=False)
else:
    # Fail fast: every later cell needs `seq_tokenizer`, so merely printing here
    # would only resurface as a confusing NameError further down.
    raise ValueError(f"Unknown model name: {model_name!r} (only BERT-style models are supported)")

Download the binding affinity dataset (first 90% for training, last 10% for validation)

In [6]:
# The dataset ships a single 'train' split; hold out its last 10% for validation.
train, validation = (
    load_dataset("jglaser/binding_affinity", split=split_spec)
    for split_spec in ('train[:90%]', 'train[90%:]')
)

Using custom data configuration default
Reusing dataset binding_affinity (/home/xvg/.cache/huggingface/datasets/binding_affinity/default/1.0.0/3a5703f991b1ebb97be31a097633a9c2a357ad6e2fdfb5579aefc24535b8a1e3)
Using custom data configuration default
Reusing dataset binding_affinity (/home/xvg/.cache/huggingface/datasets/binding_affinity/default/1.0.0/3a5703f991b1ebb97be31a097633a9c2a357ad6e2fdfb5579aefc24535b8a1e3)


Tokenize sequences

In [7]:
# Keep only the first n examples of each split (a small subset for quick runs);
# setting n = None makes the slices keep every example.
n = 10

train_seqs, train_labels = train['seq'][:n], train['neg_log10_affinity_M'][:n]
val_seqs, val_labels = validation['seq'][:n], validation['neg_log10_affinity_M'][:n]

In [18]:
def expand(seqs):
    """Turn raw sequence strings into per-character lists for the tokenizer.

    For every sequence: all whitespace is removed, each of the upper-case letters
    U, Z, O and B is replaced by X (presumably rare/ambiguous residue codes mapped
    to the unknown residue), and the result is split into single characters so it
    can be tokenized with ``is_split_into_words=True``.

    Args:
        seqs: iterable of sequence strings.

    Returns:
        list of lists of single-character strings, one inner list per input sequence.
    """
    rare_residues = re.compile(r"[UZOB]")
    expanded = []
    for raw in seqs:
        compact = "".join(raw.split())
        expanded.append(list(rare_residues.sub("X", compact)))
    return expanded

In [19]:
max_length = 1024


def encode_sequences(seqs):
    """Tokenize sequence strings residue-by-residue with ``seq_tokenizer``.

    Sequences are split into single-character lists by ``expand``, encoded with
    special tokens, padded to the longest sequence in ``seqs`` and truncated to
    ``max_length`` tokens. Offset mappings are returned as well.

    Args:
        seqs: list of sequence strings.

    Returns:
        the tokenizer's encoding object (mapping of feature name -> per-example lists).
    """
    return seq_tokenizer(expand(seqs),
                         is_split_into_words=True,
                         return_offsets_mapping=True,
                         truncation=True,
                         padding=True,
                         add_special_tokens=True,
                         max_length=max_length)


# One shared helper keeps the train and validation encodings guaranteed identical in settings.
train_seqs_encodings = encode_sequences(train_seqs)
val_seqs_encodings = encode_sequences(val_seqs)

In [21]:
# don't want to pass this to the model
# offset_mapping is only a tokenizer by-product; it must not be passed to the model.
for encodings in (train_seqs_encodings, val_seqs_encodings):
    _ = encodings.pop("offset_mapping")

In [22]:
# Sanity check: every sequence in each subset must have a matching label.
assert len(train_seqs) == len(train_labels), "train sequences/labels length mismatch"
assert len(val_seqs) == len(val_labels), "validation sequences/labels length mismatch"
len(train_seqs)

10

Create the dataset

In [23]:
class AffinityDataset(Dataset):
    """Map-style dataset pairing tokenized sequences with binarized affinity labels.

    Each continuous label is turned into a class index: 1 when it is strictly
    greater than ``pKd_cutoff``, otherwise 0.

    Args:
        encodings: mapping from feature name (e.g. ``input_ids``) to a per-example
            sequence of values; ``encodings[name][idx]`` is converted to a tensor.
        labels: per-example affinity values (presumably -log10 affinity, pKd-like;
            only the comparison with ``pKd_cutoff`` is used here).
        pKd_cutoff: threshold above which an example gets label 1.
    """

    def __init__(self, encodings, labels, pKd_cutoff=5.0):
        self.encodings, self.labels, self.pKd_cutoff = encodings, labels, pKd_cutoff

    def __getitem__(self, idx):
        features = {}
        for name, column in self.encodings.items():
            features[name] = torch.tensor(column[idx])
        above_cutoff = self.labels[idx] > self.pKd_cutoff
        features['labels'] = torch.tensor(np.where(above_cutoff, 1, 0))
        return features

    def __len__(self):
        # One example per label.
        return len(self.labels)

In [24]:
# Wrap the tokenized subsets; labels are binarized with AffinityDataset's default cutoff.
train_dataset = AffinityDataset(encodings=train_seqs_encodings, labels=train_labels)
val_dataset = AffinityDataset(encodings=val_seqs_encodings, labels=val_labels)

Define the evaluation metrics

In [50]:
def compute_metrics(p: EvalPrediction):
    """Compute binary classification metrics for the Trainer's evaluation loop.

    Args:
        p: predictions and labels collected by the Trainer. ``p.predictions`` holds
            per-class scores (argmax over axis 1 gives the predicted class);
            ``p.label_ids`` holds the 0/1 targets.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1`` (positive class = 1).
    """
    scores = p.predictions
    # Some models return extra outputs next to the logits; the logits come first.
    if isinstance(scores, tuple):
        scores = scores[0]
    preds = np.argmax(scores, axis=1)
    labels = p.label_ids

    # zero_division=0 reports 0 instead of raising UndefinedMetricWarning when there
    # are no predicted positives (precision) or no true positives (recall), e.g. on
    # tiny evaluation sets or early in training.
    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
        "f1": f1_score(labels, preds, zero_division=0),
    }


Create the model

In [51]:
def model_init():
    """Build a fresh sequence-classification model from the ``model_name`` checkpoint.

    The model has a 2-way classification head and gradient checkpointing disabled.
    It is handed to ``Trainer(model_init=...)`` as a factory rather than as an instance.
    """
    classifier_kwargs = dict(
        num_labels=2,                  # binary classifier
        gradient_checkpointing=False,
    )
    return AutoModelForSequenceClassification.from_pretrained(model_name, **classifier_kwargs)

Define the training arguments and start training

In [None]:
training_args = TrainingArguments(
    output_dir='./results',          # output directory for checkpoints
    num_train_epochs=5,              # total number of training epochs
    per_device_train_batch_size=1,   # batch size per device during training
    per_device_eval_batch_size=8,    # batch size per device for evaluation
    # NOTE(review): with the n=10 subset and 32 accumulation steps, each epoch presumably
    # yields only ~1 optimizer step, so these 200 warmup steps are never completed and the
    # learning rate stays far below its peak (consistent with the flat validation loss
    # in the output). Scale warmup to the real number of optimizer steps for full runs.
    warmup_steps=200,                # number of warmup steps for learning rate scheduler
    learning_rate=3e-05,             # peak learning rate
    weight_decay=0.0,                # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=200,               # how often (in optimizer steps) to log the training loss
    do_train=True,                   # Perform training
    do_eval=True,                    # Perform evaluation
    evaluation_strategy="epoch",     # evaluate after each epoch
    gradient_accumulation_steps=32,  # batches whose gradients are accumulated before each optimizer step
    fp16=True,                       # use mixed precision
    fp16_opt_level="O2",             # Apex AMP opt level: letter "O" + digit ("02" is invalid)
    run_name="ProtBert-BFD-BindingAffinity",  # experiment name
    seed=3,                          # seed for experiment reproducibility
    load_best_model_at_end=True,     # restore the best checkpoint (by the metric below) after training
    metric_for_best_model="eval_accuracy",
    greater_is_better=True,
)

trainer = Trainer(
    model_init=model_init,                # factory that builds a fresh model (not an instance)
    args=training_args,                   # training arguments, defined above
    train_dataset=train_dataset,          # training dataset
    eval_dataset=val_dataset,             # evaluation dataset
    compute_metrics=compute_metrics,      # evaluation metrics
)

trainer.train()

Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not init

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.675773,0.6,0.6,1.0,0.75
2,No log,0.675787,0.6,0.6,1.0,0.75
3,No log,0.675785,0.6,0.6,1.0,0.75


