Finetune BERT models

In [1]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim.adamw import AdamW

from datasets import Dataset, DatasetDict
from transformers import AutoModel, AutoTokenizer

from onnxruntime import InferenceSession

import duckdb

  from .autonotebook import tqdm as notebook_tqdm


### Load iris data

In [2]:
con = duckdb.connect("../imdb.db")
imdb = con.sql("SELECT * FROM imdb").df()
con.close()

imdb['label'] = imdb['label'].astype(int)
imdb.head()

Unnamed: 0,text,label,stage
0,I rented I AM CURIOUS-YELLOW from my video sto...,0,train
1,"""I Am Curious: Yellow"" is a risible and preten...",0,train
2,If only to avoid making this type of film in t...,0,train
3,This film was probably inspired by Godard's Ma...,0,train
4,"Oh, brother...after hearing about this ridicul...",0,train


In [3]:
imdb_test = Dataset.from_pandas(imdb[imdb['stage'] == 'test'].drop(columns=['stage']))
imdb_train = Dataset.from_pandas(imdb[imdb['stage'] == 'train'].drop(columns=['stage']))

dataset = DatasetDict()
dataset['train'] = imdb_train
dataset['test'] = imdb_test

dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 25000
    })
})

### Train the model

In [4]:
model_name = 'bert-tiny'
model_path = f"/homes/ukumaras/scratch/Models/{model_name}"

tokenizer = AutoTokenizer.from_pretrained(model_path)
bert_model = AutoModel.from_pretrained(model_path)

In [5]:
def preprocess(batch):
    return tokenizer(batch['text'], return_tensors="pt", padding="max_length", truncation=True, max_length=512)

dataset_encoded = dataset.map(preprocess, batched=True, batch_size=None)
# def preprocess_function(examples):
#     return tokenizer(examples["text"], truncation=True)


# tokenized_imdb = imdb.map(preprocess_function, batched=True)
# tokenized_imdb

Map: 100%|██████████| 25000/25000 [00:11<00:00, 2142.38 examples/s]
Map: 100%|██████████| 25000/25000 [00:11<00:00, 2136.08 examples/s]


In [6]:
bert_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-1): 2 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=128, out_features=128, bias=True)
            (key): Linear(in_features=128, out_features=128, bias=True)
            (value): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=128, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
    

In [7]:
# dataset_encoded['train']['token_type_ids'][10]

### Text Classification

In [8]:
from transformers import AutoModelForSequenceClassification

num_labels = 2
class WrappedModel(torch.nn.Module):
    def __init__(self, model_path, num_labels):
        super(WrappedModel, self).__init__()
        self.auto_model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=num_labels)

    def forward(self,
                input_ids,
                attention_mask,
                labels=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None):
        return self.auto_model(input_ids,
                                token_type_ids=None,
                                attention_mask=attention_mask,
                                labels=labels,
                                position_ids=position_ids,
                                head_mask=head_mask,
                                inputs_embeds=inputs_embeds)
model = WrappedModel(model_path, num_labels)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /homes/ukumaras/scratch/Models/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}

In [11]:
from transformers import Trainer, TrainingArguments

batch_size = 64
logging_steps = len(dataset_encoded["train"])
model_name = f"{model_name}-finetuned"
training_args = TrainingArguments(output_dir=model_name,
                                  num_train_epochs=5,
                                  learning_rate=2e-5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=0.01,
                                  evaluation_strategy="epoch",
                                  disable_tqdm=False,
                                  logging_steps=logging_steps,
                                  push_to_hub=False, 
                                  log_level="error")

In [12]:
from transformers import Trainer

trainer = Trainer(model=model, args=training_args, 
                  compute_metrics=compute_metrics,
                  train_dataset=dataset_encoded["train"],
                  eval_dataset=dataset_encoded["test"],
                  tokenizer=tokenizer)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


In [13]:
trainer.train()


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.536486,0.75448,0.754226
2,No log,0.419991,0.8166,0.816461
3,No log,0.374859,0.83672,0.836717
4,No log,0.361707,0.84432,0.844263
5,No log,0.35717,0.84752,0.847517


TrainOutput(global_step=1955, training_loss=0.44515237247242645, metrics={'train_runtime': 2417.8273, 'train_samples_per_second': 51.699, 'train_steps_per_second': 0.809, 'total_flos': 0.0, 'train_loss': 0.44515237247242645, 'epoch': 5.0})

In [14]:
trainer.save_model(model_path + '-imdb-cls')

In [15]:
dummy_model_input = tokenizer("This is a sample", return_tensors="pt")

In [20]:
dummy_model_input

{'input_ids': tensor([[ 101, 2023, 2003, 1037, 7099,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [23]:
model_path

'/homes/ukumaras/scratch/Models/bert-tiny'

In [22]:
model.eval()
torch.onnx.export(
    model, 
    tuple([dummy_model_input['input_ids'], dummy_model_input['attention_mask']]),
    f=model_path+"/model.onnx",  
    input_names=['input_ids', 'attention_mask'], 
    output_names=['logits'], 
    dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence'}, 
                  'attention_mask': {0: 'batch_size', 1: 'sequence'}, 
                  'logits': {0: 'batch_size', 1: 'sequence'}}, 
    do_constant_folding=True, 
    opset_version=17, 
)

In [24]:
onnx_model_path = f"{model_path}"

tokenizer = AutoTokenizer.from_pretrained(onnx_model_path)
session = InferenceSession(onnx_model_path + "/model.onnx")

In [29]:

inputs = tokenizer("Using DistilBERT with ONNX Runtime!", return_tensors="np")
del inputs['token_type_ids']
inputs

{'input_ids': array([[  101,  2478,  4487, 16643, 23373,  2007,  2006, 26807,  2448,
         7292,   999,   102]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [30]:

outputs = session.run(output_names=["logits"], input_feed=dict(inputs))

In [31]:
outputs

[array([[-0.00219239, -0.47350207]], dtype=float32)]

In [None]:
from transformers.models.distilbert import DistilBertConfig, DistilBertOnnxConfig

config = DistilBertConfig()
onnx_config = DistilBertOnnxConfig(config)
print(list(onnx_config.outputs.keys()))

In [None]:
onnx_config.inputs