# Fine-tuning protein language model (pLM) to protein classification

This notebook demonstrates how to use HF transformers library and the trained pLM to classify proteins as knotted vs. unknotted.

To run this notebook you are supposed to have two CSV files (knotted and unknotted proteins), each haveing one column "seq". Make sure the separator is a comma "," and not a semicolon ";". 

In [1]:
import transformers, datasets, torch

transformers.__version__, datasets.__version__, torch.__version__

2023-02-07 08:07:43.683858: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


('4.24.0', '2.7.1', '1.13.0.post200')

In [2]:
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollator, Trainer, TrainingArguments
from datasets import Dataset, load_metric, load_dataset, Features, Value
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, recall_score

In [21]:
# Hugging face model to be fine-tuned
HF_MODEL = "Rostlab/prot_bert_bfd"  # see https://huggingface.co/Rostlab, https://huggingface.co/yarongef/DistilProtBert 

### Training parameters 
EPOCHS = 1  # number of epochs for training
LR = 1e-5   # learning rate 
BS = 1      # batch size
GA = 8     # gradient acumulation

## 1) Loading dataset, tokenizer and the model

Choose one dataset for training + testing

### Alphafold Pawel's dataset

In [4]:
dss = load_dataset('EvaKlimentova/knots_AF')
dss = dss.rename_column("uniprotSequence", "seq")
print(dss)
print(dss['train'][0])

Using custom data configuration EvaKlimentova--knots_AF-293560de9ceccb3f
Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    test: Dataset({
        features: ['ID', 'latestVersion', 'globalMetricValue', 'uniprotStart', 'uniprotEnd', 'seq', 'Length', 'Domain_architecture', 'InterPro', 'Max_Topology', 'Max Freq', 'Knot Core', 'label', 'FamilyName'],
        num_rows: 39412
    })
    train: Dataset({
        features: ['ID', 'latestVersion', 'globalMetricValue', 'uniprotStart', 'uniprotEnd', 'seq', 'Length', 'Domain_architecture', 'InterPro', 'Max_Topology', 'Max Freq', 'Knot Core', 'label', 'FamilyName'],
        num_rows: 157644
    })
})
{'ID': 'A0A7V9G3C1', 'latestVersion': 4, 'globalMetricValue': 82.62, 'uniprotStart': 1, 'uniprotEnd': 924, 'seq': 'MIAFDHILALILALPAAAAWWLWFRRGHGRVVRAVALAALVLAAAGPHVDFGRGGSDVVVVVDRSASMGEALQRQDEALRAIGEQRRGHDRLAVVAFGERALVVQAPQETGVPRLADAIAGDSGSELADGLEAGWSVLPAGRSGRVVVLSDGEFTGLEPRMAGARFALARVPIDVLPEVRSTAADAAIMDVELPTALRLGESFIGAARLLSDRDERRPWRIMRLHARGGGAEDKLVASGVAELSSLRPTIVTFADRPPAAGVAQYRVELDERDDRQPLNNRARAVLRVTGGERVLVLGGDGTPGNIATALSAAGMTVVCRAEGPVSLAELVGVSCLVLEQVPAD

### Load model

In [22]:
# load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, do_lower_case=False)
model = AutoModelForSequenceClassification.from_pretrained(HF_MODEL)

loading configuration file config.json from cache at /home/jovyan/.cache/huggingface/hub/models--Rostlab--prot_bert_bfd/snapshots/6c5c8a55a52ff08a664dfd584aa1773f125a0487/config.json
Model config BertConfig {
  "_name_or_path": "Rostlab/prot_bert_bfd",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 40000,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 30,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.24.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30
}

loading file vocab.txt from cache at /home/jovyan/.cache/huggingface/hub/models--Rostlab--prot_bert_bfd/snapshots/6c5c8a55a52ff08a664dfd584aa1773f125a0487/vocab.txt
loading file tokenizer

## 2) Dataset tokenization

In [8]:
tokenizer

PreTrainedTokenizerFast(name_or_path='Rostlab/prot_bert_bfd', vocab_size=30, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [9]:
def tokenize_function(s):
  seq_split = " ".join(s['seq'])
  return tokenizer(seq_split)

tokenize_function({'seq': "B"})

{'input_ids': [2, 27, 3], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}

In [10]:
tokenized_datasets = dss.map(tokenize_function, remove_columns=['seq', 'ID', 'latestVersion', 'globalMetricValue', 'uniprotStart', 'uniprotEnd', 'Length', 'Domain_architecture', 'InterPro', 'Max_Topology', 'Max Freq', 'Knot Core', 'FamilyName'], num_proc=4)
tokenized_datasets.set_format("pt")
print(tokenized_datasets)
print(tokenized_datasets['train'][0])

       

#0:   0%|          | 0/9853 [00:00<?, ?ex/s]

 

#1:   0%|          | 0/9853 [00:00<?, ?ex/s]

#2:   0%|          | 0/9853 [00:00<?, ?ex/s]

#3:   0%|          | 0/9853 [00:00<?, ?ex/s]

        

#1:   0%|          | 0/39411 [00:00<?, ?ex/s]

#2:   0%|          | 0/39411 [00:00<?, ?ex/s]

#3:   0%|          | 0/39411 [00:00<?, ?ex/s]

#0:   0%|          | 0/39411 [00:00<?, ?ex/s]

DatasetDict({
    test: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 39412
    })
    train: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 157644
    })
})
{'label': tensor(0), 'input_ids': tensor([ 2, 21, 11,  6, 19, 14, 22, 11,  5,  6,  5, 11,  5,  6,  5, 16,  6,  6,
         6,  6, 24, 24,  5, 24, 19, 13, 13,  7, 22,  7, 13,  8,  8, 13,  6,  8,
         6,  5,  6,  6,  5,  8,  5,  6,  6,  6,  7, 16, 22,  8, 14, 19,  7, 13,
         7,  7, 10, 14,  8,  8,  8,  8,  8, 14, 13, 10,  6, 10, 21,  7,  9,  6,
         5, 18, 13, 18, 14,  9,  6,  5, 13,  6, 11,  7,  9, 18, 13, 13,  7, 22,
        14, 13,  5,  6,  8,  8,  6, 19,  7,  9, 13,  6,  5,  8,  8, 18,  6, 16,
        18,  9, 15,  7,  8, 16, 13,  5,  6, 14,  6, 11,  6,  7, 14, 10,  7, 10,
         9,  5,  6, 14,  7,  5,  9,  6,  7, 24, 10,  8,  5, 16,  6,  7, 13, 10,
         7, 13,  8,  8,  8,  5, 10, 14,  7,  9, 1

## 3) Training

In [23]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    specificity=recall_score(labels, preds, pos_label=0)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall (TPR)': recall,
        'specificity (TNR)': specificity
    }

In [24]:
training_args = TrainingArguments('outputs2', learning_rate=LR, warmup_ratio=0.1, lr_scheduler_type='cosine', fp16=True,
            evaluation_strategy="epoch", save_strategy="epoch", per_device_train_batch_size=BS, per_device_eval_batch_size=BS, gradient_accumulation_steps=GA,
            num_train_epochs=EPOCHS, load_best_model_at_end=True, save_total_limit=1, weight_decay=0.01, report_to='none', gradient_checkpointing=True, optim="adafactor")  # taken from https://huggingface.co/docs/transformers/v4.18.0/en/performance

PyTorch: setting up devices


In [25]:
trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"], # downsampled before
    eval_dataset=tokenized_datasets["test"], # downsampled before
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Using cuda_amp half precision backend


In [None]:
trainer.train()

***** Running training *****
  Num examples = 157644
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 8
  Total optimization steps = 19705
  Number of trainable parameters = 419933186
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss


In [None]:
trainer.save_model("Alphafold_dataset/ProtBertBFD_ALPHAFOLD_v3.2")

In [None]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 39412
  Batch size = 1


### Evaluate on different dataset

#### Based on family

In [34]:
FAMILIES = ['SPOUT', 'TDD', 'DUF', 'AdoMet synthase', 'Carbonic anhydrase', 'UCH', 'ATCase/OTCase', 'ribosomal-mitochondrial', 'membrane', 
            'VIT', 'biosynthesis of lantibiotics', 'PGluconate dehydrogenase']

In [35]:
for family in FAMILIES:
    print(family)
    dataset = dss['test'].filter(lambda x: (x['FamilyName'] is not None) and (family in x['FamilyName']))
    dataset = dataset.map(tokenize_function, remove_columns=['seq', 'ID', 'latestVersion', 'globalMetricValue', 'uniprotStart', 'uniprotEnd', 'Length', 'Domain_architecture', 'InterPro', 'Max_Topology', 'Max Freq', 'Knot Core', 'FamilyName'], num_proc=4)
    dataset.set_format("pt")

    trainer = Trainer(
        model,
        training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    metrics = trainer.evaluate()
    print(f"Accuracy {round(metrics['eval_accuracy'], 4)}\nTPR {round(metrics['eval_recall (TPR)'], 4)}\nTNR {round(metrics['eval_specificity (TNR)'], 4)}")
    print()

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-96c870fd24106dba.arrow


VIT
     

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-033dbe67db36d00a.arrow


 

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-c280e8508625a1de.arrow


 

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-3190a59783ef6e70.arrow


 

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-d817b142ea0e6ac9.arrow
Using cuda_amp half precision backend
***** Running Evaluation *****
  Num examples = 14347
  Batch size = 1


Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8d95e7ab0fb1d3e0.arrow


Accuracy 0.9873
TPR 0.9415
TNR 0.9935

biosynthesis of lantibiotics
     

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-14b0e51931e1d0e4.arrow


  

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-d30b655b19018aca.arrow
Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-51a5f6664abd12eb.arrow


 

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-5c4c52d44e8fa664.arrow
Using cuda_amp half precision backend
***** Running Evaluation *****
  Num examples = 392
  Batch size = 1


Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-e8998b4029e84774.arrow
num_proc must be <= 1. Reducing num_proc to 1 for dataset of size 1.
Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-73cc6266d7b22bf5.arrow
Using cuda_amp half precision backend
***** Running Evaluation *****
  Num examples = 1
  Batch size = 1


Accuracy 0.9719
TPR 0.9811
TNR 0.9685

PGluconate dehydrogenase


Accuracy 1.0
TPR 1.0
TNR 0.0



  _warn_prf(average, modifier, msg_start, len(result))


---------------------------------------------------------------------

## 4) Optional: Uploading the fine-tuned model to HF

In [36]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [37]:
model.push_to_hub("EvaKlimentova/knots_protbertBFD_alphafold")

Configuration saved in /tmp/tmpza3mb48c/config.json
Model weights saved in /tmp/tmpza3mb48c/pytorch_model.bin
Uploading the following files to EvaKlimentova/knots_protbertBFD_alphafold: config.json,pytorch_model.bin


CommitInfo(commit_url='https://huggingface.co/EvaKlimentova/knots_protbertBFD_alphafold/commit/a7d1a84d23998ed2fc577d3d712dfe0a43cce942', commit_message='Upload BertForSequenceClassification', commit_description='', oid='a7d1a84d23998ed2fc577d3d712dfe0a43cce942', pr_url=None, pr_revision=None, pr_num=None)

In [38]:
tokenizer.push_to_hub("EvaKlimentova/knots_protbertBFD_alphafold")

tokenizer config file saved in /tmp/tmptxn9q5_6/tokenizer_config.json
Special tokens file saved in /tmp/tmptxn9q5_6/special_tokens_map.json
Uploading the following files to EvaKlimentova/knots_protbertBFD_alphafold: tokenizer_config.json,special_tokens_map.json,vocab.txt,tokenizer.json


CommitInfo(commit_url='https://huggingface.co/EvaKlimentova/knots_protbertBFD_alphafold/commit/b1edd1ef533bd520add165493f0f324667a496d0', commit_message='Upload tokenizer', commit_description='', oid='b1edd1ef533bd520add165493f0f324667a496d0', pr_url=None, pr_revision=None, pr_num=None)