## Train BERT model

### Load the dataset

In order to fine-tune the BERT models for the cord19 application we need to generate a set of query-document features as well as labels that indicate which documents are relevant for the specific queries. For this exercise we will use the `query` string to represent the query and the `title` string to represent the documents.

The file `labelled_data.json` contains the `query` strings, and the file `training_all_judgement_data.csv` contains the labels and `title` strings. Those files were created and covered elsewhere, but you can download them [here](https://drive.google.com/file/d/1R2hZTF6QBKPMaiuS4Du6aXQlBVnfZOAA/view?usp=sharing) and [here](https://drive.google.com/file/d/18jNRM7G7agbO1Mg9t0l1pvqsz-qwEXpz/view?usp=sharing). The loading code below refers to them as `labelled_data_all.json` and `training_all_jugdments_data.csv`, so adjust the paths to match your downloaded copies.

In [3]:
import json
from pandas import read_csv

# Load the query metadata and the relevance judgments.
with open("labelled_data_all.json", "r") as f:
    labelled_data = json.load(f)
training_data = read_csv("training_all_jugdments_data.csv")

`training_data` has almost everything we need, except the `query` string.

In [5]:
training_data.head()

| Unnamed: 0 | document_id | query_id | label | title-full |
|---|---|---|---|---|
| 0 | 005b2j4b | 1 | 2 | Monophyletic Relationship between Severe Acute... |
| 1 | 00fmeepz | 1 | 1 | Comprehensive overview of COVID-19 based on cu... |
| 2 | 010vptx3 | 1 | 2 | The SARS, MERS and novel coronavirus (COVID-19... |
| 3 | 0194oljo | 1 | 1 | Evidence for zoonotic origins of Middle East r... |
| 4 | 021q9884 | 1 | 1 | Deadly virus effortlessly hops species |


The query string can be obtained from the `labelled_data`.

In [8]:
print(labelled_data[0]["query_id"], labelled_data[0]["query"])

1 coronavirus origin
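
Because each entry of `labelled_data` carries both a `query_id` and a `query`, one way to attach the query string to a judgment row is a plain lookup table. The sketch below uses small stand-in records rather than the real files; the second query and the `doc-x` row are made up for illustration, and the names `toy_labelled_data`, `toy_judgments` and `query_by_id` are hypothetical.

```python
# Toy stand-ins for the loaded data (not the real files).
toy_labelled_data = [
    {"query_id": 1, "query": "coronavirus origin"},
    {"query_id": 2, "query": "hypothetical second query"},  # made-up entry
]
toy_judgments = [
    {"document_id": "005b2j4b", "query_id": 1, "label": 2},
    {"document_id": "00fmeepz", "query_id": 1, "label": 1},
    {"document_id": "doc-x", "query_id": 2, "label": 0},  # made-up row
]

# Map each query_id to its query string once, then look it up per row.
query_by_id = {item["query_id"]: item["query"] for item in toy_labelled_data}
pairs = [
    (query_by_id[row["query_id"]], row["document_id"], row["label"])
    for row in toy_judgments
]
```

Each element of `pairs` then holds the query string, the document id and the label for one judgment.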


### Compatible BERT encodings

Since we are training a model that will be deployed in a search application, we need to ensure that the training encodings are compatible with the encodings used at serving time. At serving time, document encodings are computed offline when the documents are fed to the search engine, while the query encoding is computed at run time when the query arrives. In addition, it can be useful to set different maximum lengths for queries and documents.

In [9]:
def create_bert_encodings(queries, docs, tokenizer, query_input_size, doc_input_size):
    """Encode query-document pairs as [CLS] query [SEP] doc [SEP], zero-padded to a fixed length."""
    # Leave room for [CLS] and [SEP] around the query and for the final [SEP] after the document.
    queries_encodings = tokenizer(
        queries, truncation=True, max_length=query_input_size - 2, add_special_tokens=False
    )
    docs_encodings = tokenizer(
        docs, truncation=True, max_length=doc_input_size - 1, add_special_tokens=False
    )

    # Token ids from the BERT uncased vocabulary.
    TOKEN_NONE = 0   # [PAD]
    TOKEN_CLS = 101  # [CLS]
    TOKEN_SEP = 102  # [SEP]
    MAX_INPUT_SIZE = 128  # fixed length that every encoding is padded to

    input_ids = []
    token_type_ids = []
    attention_mask = []
    for query_input_ids, doc_input_ids in zip(queries_encodings["input_ids"], docs_encodings["input_ids"]):
        query_part = [TOKEN_CLS] + query_input_ids + [TOKEN_SEP]
        doc_part = doc_input_ids + [TOKEN_SEP]
        number_tokens = len(query_part) + len(doc_part)
        padding_length = max(MAX_INPUT_SIZE - number_tokens, 0)
        # Token ids: query part, document part, then padding.
        input_ids.append(query_part + doc_part + [TOKEN_NONE] * padding_length)
        # Segment ids: 0 for the query part, 1 for the document part, 0 for padding.
        token_type_ids.append([0] * len(query_part) + [1] * len(doc_part) + [0] * padding_length)
        # Attention mask: 1 for real tokens, 0 for padding.
        attention_mask.append([1] * number_tokens + [0] * padding_length)

    encodings = {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask
    }
    return encodings
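
To make the layout concrete, here is a minimal pure-Python sketch of the same idea that does not need a tokenizer. It assumes the token ids are already available and uses made-up ids; the helpers `query_part`, `doc_part` and `combine` are hypothetical names, not part of any library. Splitting the work this way mirrors the serving setup described above, where the document part can be prepared offline and the query part at run time.

```python
TOKEN_PAD, TOKEN_CLS, TOKEN_SEP = 0, 101, 102  # BERT uncased vocabulary ids

def query_part(query_ids, query_input_size):
    """[CLS] + query + [SEP], truncated so the part never exceeds query_input_size."""
    return [TOKEN_CLS] + query_ids[: query_input_size - 2] + [TOKEN_SEP]

def doc_part(doc_ids, doc_input_size):
    """doc + [SEP], truncated so the part never exceeds doc_input_size."""
    return doc_ids[: doc_input_size - 1] + [TOKEN_SEP]

def combine(q_part, d_part, total_size):
    """Concatenate the two parts and zero-pad all three inputs to total_size."""
    n_tokens = len(q_part) + len(d_part)
    padding = max(total_size - n_tokens, 0)
    return {
        "input_ids": q_part + d_part + [TOKEN_PAD] * padding,
        "token_type_ids": [0] * len(q_part) + [1] * len(d_part) + [0] * padding,
        "attention_mask": [1] * n_tokens + [0] * padding,
    }

# Made-up token ids, for illustration only.
q = query_part([2001, 2002, 2003], query_input_size=4)  # third query token is truncated
d = doc_part([3001, 3002], doc_input_size=8)
example = combine(q, d, total_size=10)
```

The query part and the document part are built independently and only concatenated at the end, which is what keeps the training encodings compatible with encodings produced separately at serving time.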

### Create Datasets

Create lists of queries (represented by the query string), docs (represented by the doc titles) and labels from the `labelled_data` and `training_data` that we loaded earlier. The graded labels are binarized: any label greater than 0 counts as relevant.

In [12]:
train_queries = []
train_docs = []
train_labels = []
for data_point in labelled_data:
    # Select the judged documents for this query once.
    judged = training_data[training_data["query_id"] == data_point["query_id"]]
    titles = judged["title-full"].tolist()
    train_docs.extend(titles)
    # Binarize the graded labels: any label above 0 counts as relevant.
    train_labels.extend([1 if x > 0 else 0 for x in judged["label"].tolist()])
    # Repeat the query string once per judged document.
    train_queries.extend([data_point["query"]] * len(titles))

We use a simple split into train and validation sets for illustration purposes. The cord19 use case would probably benefit from cross-validation, since only 50 queries have relevance judgments.

In [13]:
from sklearn.model_selection import train_test_split
train_queries, val_queries, train_docs, val_docs, train_labels, val_labels = train_test_split(
    train_queries, train_docs, train_labels, test_size=.2
)
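
The split above assigns individual query-document pairs at random, so the same query can end up in both sets. As a rough sketch of the cross-validation alternative, the hypothetical helper `query_grouped_folds` below builds folds in which all rows of a query stay together. It assumes a list with one query id per row, which the loop above does not keep, so `row_query_ids` is made-up illustration data.

```python
def query_grouped_folds(query_ids, n_folds):
    """Yield (train_idx, val_idx) pairs so that all rows of a query share a fold."""
    unique_ids = sorted(set(query_ids))
    # Assign queries to folds round-robin.
    fold_of_query = {qid: i % n_folds for i, qid in enumerate(unique_ids)}
    for fold in range(n_folds):
        train_idx = [i for i, qid in enumerate(query_ids) if fold_of_query[qid] != fold]
        val_idx = [i for i, qid in enumerate(query_ids) if fold_of_query[qid] == fold]
        yield train_idx, val_idx

# Hypothetical per-row query ids (one entry per query-document pair).
row_query_ids = [1, 1, 1, 2, 2, 3, 4, 4]
folds = list(query_grouped_folds(row_query_ids, n_folds=2))
```

Each pair of index lists can then be used to select the matching queries, docs and labels for one round of training and validation.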

Create train and validation encodings.

In [14]:
model_name = "google/bert_uncased_L-4_H-512_A-8"
query_input_size=24
doc_input_size=64

In [15]:
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained(model_name)

train_encodings = create_bert_encodings(
    queries=train_queries, 
    docs=train_docs, 
    tokenizer=tokenizer, 
    query_input_size=query_input_size, 
    doc_input_size=doc_input_size
)

val_encodings = create_bert_encodings(
    queries=val_queries, 
    docs=val_docs, 
    tokenizer=tokenizer, 
    query_input_size=query_input_size, 
    doc_input_size=doc_input_size
)

Wrap the encodings and labels in a torch dataset.

In [10]:
import torch

class Cord19Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = Cord19Dataset(train_encodings, train_labels)
val_dataset = Cord19Dataset(val_encodings, val_labels)

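Fine-tune the model, updating only the task-specific weights. The parameters of the base BERT model are frozen, so only the classification head is trained.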

In [None]:
from transformers import BertForSequenceClassification, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)

model = BertForSequenceClassification.from_pretrained(model_name)
for param in model.base_model.parameters():
    param.requires_grad = False

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)

trainer.train()

Some weights of the model checkpoint at google/bert_uncased_L-4_H-512_A-8 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification w

{'loss': 0.8461690902709961, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.0013995801259622112, 'total_flos': 3534603141120, 'step': 10}
{'loss': 0.8681703567504883, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.0027991602519244225, 'total_flos': 7069206282240, 'step': 20}
{'loss': 0.857829475402832, 'learning_rate': 3e-06, 'epoch': 0.004198740377886634, 'total_flos': 10603809423360, 'step': 30}
{'loss': 0.8333715438842774, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.005598320503848845, 'total_flos': 14138412564480, 'step': 40}
{'loss': 0.8387279510498047, 'learning_rate': 5e-06, 'epoch': 0.006997900629811057, 'total_flos': 17673015705600, 'step': 50}
{'loss': 0.8409191131591797, 'learning_rate': 6e-06, 'epoch': 0.008397480755773267, 'total_flos': 21207618846720, 'step': 60}
{'loss': 0.8141700744628906, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.00979706088173548, 'total_flos': 24742221987840, 'step': 70}
{'loss': 0.7777870178222657, 'learning_rate': 8.000000
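
The logged values can be cross-checked with a little arithmetic. The sketch below assumes the `Trainer` defaults that the training arguments leave untouched, namely a base learning rate of 5e-5 and a linear warmup schedule; `linear_warmup_lr` is a hypothetical helper written for this check, not a library function.

```python
def linear_warmup_lr(step, base_lr=5e-5, warmup_steps=500):
    """Learning rate during linear warmup (assumed schedule and base rate)."""
    return base_lr * min(step, warmup_steps) / warmup_steps

# Learning rates for the first logged steps (logging_steps=10).
warmup_lrs = {step: linear_warmup_lr(step) for step in (10, 20, 30)}

# The first log line reports step 10 at this fraction of an epoch,
# which implies the number of optimization steps per epoch.
logged_epoch_at_step_10 = 0.0013995801259622112
steps_per_epoch = round(10 / logged_epoch_at_step_10)
```

Under these assumptions the learning rate grows by about 1e-6 every 10 steps, which matches the logged `learning_rate` values, and the epoch fraction implies roughly 7145 steps per epoch.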

### Export the model to ONNX

In [None]:
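from torch.onnx import export
from pathlib import Path

model_onnx_path = Path(model_name + ".onnx")
# model_name contains a "/", so make sure the parent directory exists.
model_onnx_path.parent.mkdir(parents=True, exist_ok=True)

# Positional inputs must follow the order of the model's forward() signature:
# input_ids, attention_mask, token_type_ids.
input_names = ["input_ids", "attention_mask", "token_type_ids"]
output_names = ["logits"]
sample = train_dataset[0]
# Add a batch dimension and place the tensors on the same device as the model.
dummy_input = tuple(
    sample[name].unsqueeze(0).to(model.device) for name in input_names
)
export(
    model, dummy_input, str(model_onnx_path), input_names=input_names,
    output_names=output_names, verbose=False, opset_version=11
)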

Check the name, type and shape of the exported model's output.

In [None]:
import onnxruntime as ort
m = ort.InferenceSession(model_name + ".onnx") 
print(m.get_outputs()[0].name)
print(m.get_outputs()[0].type)
print(m.get_outputs()[0].shape)

Another way to check the output type.

In [None]:
import onnx
m = onnx.load(model_name + ".onnx")
m.graph.output