# Export Bio_ClinicalBERT NLP classification model

Export the fine-tuned classifier to ONNX and validate the exported graphs with ONNX Runtime.

In [67]:
import numpy as np
import onnx
import onnxruntime
import onnxruntime_extensions
import torch

from sklearn.metrics import roc_auc_score, accuracy_score, f1_score

from torch.utils.data import Dataset
from torch import nn
from tqdm import tqdm
from datetime import date
from pathlib import Path
from onnxruntime_extensions import pnp, OrtPyFunction, get_library_path as _lib_path
import transformers
from transformers import logging, AutoTokenizer, AutoConfig, AutoModel
from transformers.onnx import FeaturesManager
logging.set_verbosity_error()

In [24]:
# Record the library versions used for this export (for reproducibility).
for _label, _module in (
    ("pytorch", torch),
    ("onnxruntime", onnxruntime),
    ("onnx", onnx),
    ("transformers", transformers),
):
    print(f"{_label}:", _module.__version__)

pytorch: 2.2.1
onnxruntime: 1.17.1
onnx: 1.15.0
transformers: 4.39.0


## Data Mapping

In [2]:
# Label name -> integer class id; ids follow the order of the tuple below.
target_variables_dict = {
    label: class_id
    for class_id, label in enumerate((
        'no_finding',
        'atelectasis',
        'cardiomegaly',
        'lung_opacity',
        'pleural_effusion',
    ))
}

## Model Setup

In [3]:
# Base encoder checkpoint: Bio_ClinicalBERT (passed to AutoModel/AutoTokenizer.from_pretrained).
MODEL_CHECKPOINT = 'emilyalsentzer/Bio_ClinicalBERT'

In [4]:
# Model / training hyper-parameters
NUM_CLASSES = 5              # width of the classifier output (one logit per label)
MAX_SEQUENCE_LENGTH = 512    # tokens per sample after padding / truncation
NUM_EPOCHS = 15
BATCH_SIZE = 16
LEARNING_RATE = 5e-5

## Classes/Functions

In [5]:
class TokenizerDataset(Dataset):
    """Map-style dataset that tokenizes one text sample per ``__getitem__`` call.

    Args:
        X_data: collection of raw texts, read as ``X_data[index]``
            (NOTE(review): for a pandas Series this lookup is label-based, not
            positional — confirm callers pass a list/array or a RangeIndex).
        y_data: class labels aligned with ``X_data``.
        tokenizer: tokenizer exposing ``batch_encode_plus`` (e.g. a Hugging Face tokenizer).
        max_seq_length: every sample is padded/truncated to exactly this many tokens.
    """

    def __init__(self, X_data, y_data, tokenizer, max_seq_length):
        self.X_data = X_data
        self.y_data = y_data
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.X_data)

    def __getitem__(self, index):
        """Return the encoded sample at ``index`` as 1-D token tensors plus its label."""
        inputs = self.tokenizer.batch_encode_plus(
            [self.X_data[index]],
            add_special_tokens=True,
            max_length=self.max_seq_length,
            padding='max_length',
            return_tensors='pt',
            truncation=True,
        )

        # The tokenizer encodes a batch of one, so each field is (1, max_seq_length).
        # Take row 0 rather than squeeze(): squeeze() drops *every* size-1 axis and
        # would turn a max_seq_length == 1 sample into a 0-d tensor.
        return {
            'index': index,
            'input_ids': inputs['input_ids'][0],
            'token_type_ids': inputs['token_type_ids'][0],
            'attention_mask': inputs['attention_mask'][0],
            'labels': torch.tensor(self.y_data[index]).long(),
        }

In [6]:
class MulticlassClassification(nn.Module):
    """Pretrained encoder followed by a small MLP head producing one logit per class.

    Args:
        checkpoint: model id/path passed to ``AutoModel.from_pretrained``.
        num_classes: number of output logits.
        hidden_size: width of the intermediate layer of the head.
        dropout_prob: dropout probability applied before the final projection.
        freeze_bert: if True, every encoder parameter gets ``requires_grad=False``
            so only the head is trainable (see ``unfreeze_bert_layers``).
    """

    def __init__(self, checkpoint, num_classes, hidden_size=201, dropout_prob=0.3, freeze_bert=True):
        super().__init__()

        self.model = AutoModel.from_pretrained(checkpoint)
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.num_classes = num_classes
        self.freeze_bert = freeze_bert

        for param in self.model.parameters():
            param.requires_grad = not self.freeze_bert

        # encoder pooled output -> hidden_size
        self.pooler_layer = nn.Linear(self.model.config.hidden_size, hidden_size)
        self.relu = nn.ReLU()  # non-linearity between the two head layers
        self.dropout = nn.Dropout(dropout_prob)  # regularization
        # hidden_size -> one logit per target class
        self.classification_layer = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Return unnormalized logits; the last axis has ``num_classes`` entries."""
        outputs = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask)

        hidden = self.pooler_layer(outputs.pooler_output)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.classification_layer(hidden)

    def unfreeze_bert_layers(self, n_layers):
        """Set ``requires_grad=True`` on the top ``n_layers`` encoder layers.

        ``n_layers <= 0`` is a no-op; a value larger than the number of encoder
        layers unfreezes all of them.
        """
        # Guard required: layer[-0:] is the *whole* list, so n_layers == 0 would
        # otherwise unfreeze every encoder layer.
        if n_layers <= 0:
            return
        for layer in self.model.encoder.layer[-n_layers:]:
            for param in layer.parameters():
                param.requires_grad = True

In [7]:
def test_multiclass_classification_model(
        model,
        test_dataloader,
        checkpoint_folder,
        model_name_folder,
        model_variation,
        class_weights=None
    ):
    """Evaluate a trained classifier on a test set, print metrics and save raw results.

    Args:
        model: module called as ``model(input_ids, token_type_ids, attention_mask)``
            returning logits with the class axis at dim 1.
        test_dataloader: sized iterable of dict batches with ``index``, ``input_ids``,
            ``token_type_ids``, ``attention_mask`` and ``labels`` tensors.
        checkpoint_folder: root folder for the saved results.
        model_name_folder: sub-folder (under ``checkpoint_folder``) for this model.
        model_variation: tag embedded in the results filename.
        class_weights: optional tensor of per-class weights for the cross-entropy loss.

    Returns:
        dict with keys ``test_indexes``, ``test_loss``, ``test_prob``, ``test_preds``
        and ``test_labels``; each value is a one-element list holding this run's result.

    Raises:
        ValueError: if ``test_dataloader`` is empty.

    Side effects:
        Creates the results folder if needed and writes the dict with ``torch.save`` to
        ``{checkpoint_folder}/{model_name_folder}/{DDMMYYYY}__{model_variation}__test_results.pt``.
    """
    if len(test_dataloader) == 0:
        raise ValueError("test_dataloader is empty; nothing to evaluate")

    today_date = date.today().strftime("%d%m%Y")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if class_weights is not None:
        # CrossEntropyLoss takes per-class weights via ``weight``; ``pos_weight`` only
        # exists on BCEWithLogitsLoss and raises a TypeError here.
        criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    else:
        criterion = nn.CrossEntropyLoss()

    model.eval()
    total_test_loss = 0.0

    all_test_indexes = []
    all_test_prob = []
    all_test_preds = []
    all_test_labels = []

    with torch.no_grad():
        for test_batch in tqdm(test_dataloader, desc="Test"):
            indexes = test_batch['index'].to(device)
            input_ids = test_batch['input_ids'].to(device)
            token_type_ids = test_batch['token_type_ids'].to(device)
            attention_mask = test_batch['attention_mask'].to(device)
            labels = test_batch['labels'].to(device)

            outputs = model(input_ids, token_type_ids, attention_mask)
            loss = criterion(outputs, labels)
            total_test_loss += loss.item()

            probabilities = torch.softmax(outputs, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

            all_test_indexes.extend(indexes.cpu().tolist())
            all_test_prob.extend(probabilities.cpu().tolist())
            all_test_preds.extend(predictions.cpu().tolist())
            all_test_labels.extend(labels.cpu().tolist())

    test_loss = total_test_loss / len(test_dataloader)

    all_test_prob = np.array(all_test_prob)
    all_test_preds = np.array(all_test_preds)
    all_test_labels = np.array(all_test_labels)

    # Calculate metrics for testing ("ovr" = one-vs-rest AUC over the class probabilities)
    test_auc = roc_auc_score(all_test_labels, all_test_prob, multi_class="ovr")
    test_accuracy = accuracy_score(all_test_labels, all_test_preds)
    test_f1_average = f1_score(all_test_labels, all_test_preds, average='macro')

    print(f"Test AUC: {test_auc:.4f} | "
          f"Test Loss: {test_loss:.4f} | "
          f"Test Accuracy: {test_accuracy:.4f} | "
          f"Test F1 (average): {test_f1_average:.4f}")

    # Each value is wrapped in a one-element list, matching the saved results format.
    results = {
        "test_indexes": [all_test_indexes],
        "test_loss": [test_loss],
        "test_prob": [all_test_prob],
        "test_preds": [all_test_preds],
        "test_labels": [all_test_labels],
    }

    results_path = f"{checkpoint_folder}/{model_name_folder}/{today_date}__{model_variation}__test_results.pt"
    # Create the folder up front so a finished evaluation is not lost at save time.
    Path(results_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(results, results_path)

    return results

In [8]:
def load_results(results_path):
    """Load a results dict saved with ``torch.save``, mapping tensors to the CPU.

    The saved results contain NumPy arrays, which require full unpickling, so
    ``weights_only=False`` is passed explicitly (newer PyTorch releases default it
    to True and would reject these files). Only load files you trust: unpickling
    can execute arbitrary code.
    """
    return torch.load(results_path, map_location=torch.device('cpu'), weights_only=False)

In [9]:
def load_model(model, checkpoint_path):
    """Fill ``model`` in place from a saved ``state_dict`` (tensors mapped to CPU) and return it."""
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    return model

## Export to ONNX

In [11]:
# Folders holding training checkpoints / this model's artifacts.
CHECKPOINT_FOLDER = './checkpoints'
MODEL_NAME_FOLDER = './model_findings'

In [12]:
model = MulticlassClassification(
    checkpoint=MODEL_CHECKPOINT,
    num_classes=NUM_CLASSES,
    freeze_bert=False,
)

# Best checkpoint from training, stored under MODEL_NAME_FOLDER.
CHECKPOINT_FILE = f"{MODEL_NAME_FOLDER}/bio_clinical_bert__balanced__unfrozen_layers__best.pt"

# Reuse the shared helper (strict load_state_dict, tensors mapped to CPU).
model = load_model(model, CHECKPOINT_FILE)

<All keys matched successfully>

In [13]:
model_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_CHECKPOINT)  # tokenizer matching the encoder

In [14]:
# Encode a toy sentence padded/truncated to the model's fixed sequence length.
vals = model_tokenizer(
    'hello world',
    return_tensors='pt',
    add_special_tokens=True,
    padding='max_length',
    max_length=MAX_SEQUENCE_LENGTH,
    truncation=True,
)

In [15]:
print(*(getattr(vals, key).squeeze().shape for key in ('input_ids', 'attention_mask', 'token_type_ids')))

torch.Size([512]) torch.Size([512]) torch.Size([512])


In [16]:
model.eval()
# `model(**inputs)` binds the tokenizer outputs to forward() by keyword name, so key order is irrelevant.
inputs = model_tokenizer('hello world', return_tensors='pt', add_special_tokens=True, padding='max_length', max_length=MAX_SEQUENCE_LENGTH, truncation=True)
# Inference only: skip building an autograd graph.
with torch.no_grad():
    output = model(**inputs)

In [17]:
prob = output.softmax(dim=1)  # logits -> per-class probabilities
prob

tensor([[0.2416, 0.1649, 0.2561, 0.2026, 0.1349]], grad_fn=<SoftmaxBackward0>)

In [18]:
predictions = prob.argmax(dim=1)  # most likely class id per row
predictions

tensor([2])

In [61]:
EXPORT_FOLDER = "../../onnx_models"
export_model_path = Path(EXPORT_FOLDER) / "bioclinicalbert_nlp_cls_best.onnx"  # FP32 classifier graph


In [29]:
export_dir = Path(EXPORT_FOLDER)
if not export_dir.exists():
    print(f"Creating folder {EXPORT_FOLDER}")
    export_dir.mkdir(parents=True)

# Inputs are (batch_size, max_seq_len). The logits output is (batch_size, NUM_CLASSES),
# so only its batch axis is dynamic; reusing the input axes would wrongly declare the
# fixed class axis as a dynamic 'max_seq_len' dimension in the exported graph.
input_axes = {0: 'batch_size', 1: 'max_seq_len'}
output_axes = {0: 'batch_size'}

model.eval()
with torch.no_grad():
    torch.onnx.export(
        model=model,
        # Positional order must match forward(input_ids, token_type_ids, attention_mask).
        args=(inputs['input_ids'], inputs['token_type_ids'], inputs['attention_mask']),
        f=str(export_model_path),
        do_constant_folding=True,
        input_names=['input_ids', 'token_type_ids', 'attention_mask'],
        output_names=['output'],
        dynamic_axes={
            'input_ids': input_axes,
            'token_type_ids': input_axes,
            'attention_mask': input_axes,
            'output': output_axes,
        },
    )
    print(f"Model exported to {export_model_path}")


Model exported to ../../onnx_models/bioclinicalbert_nlp_cls_best.onnx


## Test the ONNX export

In [30]:
# Reload the exported graph and run ONNX's structural validator (raises if invalid).
onnx_model = onnx.load_model(export_model_path)
onnx.checker.check_model(onnx_model)

In [40]:
# CPU-only ONNX Runtime session over the FP32 export.
ort_session = onnxruntime.InferenceSession(
    path_or_bytes=export_model_path,
    providers=['CPUExecutionProvider'],
)

def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU, detached from autograd.

    ``detach()`` is valid for every tensor, so no ``requires_grad`` branch is needed.
    """
    return tensor.detach().cpu().numpy()

# compute ONNX Runtime output prediction, feeding every session input by name
ort_inputs = {
    session_input.name: to_numpy(inputs[session_input.name])
    for session_input in ort_session.get_inputs()
}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(output), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!


## Export pre/post processors (tokenizer/softmax)

In [48]:
export_model_path = Path(EXPORT_FOLDER) / "bioclinicalbert_nlp_cls_best.onnx"
export_augmented_model_path = Path(EXPORT_FOLDER) / "bioclinicalbert_nlp_cls_best_complete.onnx"
# NOTE(review): these encoding kwargs are given at load time, not at call time. The
# exported tokenizer graph produced 7-token outputs for the 5-word test sentence (no
# padding to MAX_SEQUENCE_LENGTH) — confirm whether they affect the ONNX tokenizer at all.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_CHECKPOINT,
    return_tensors='pt',
    add_special_tokens=True,
    padding='max_length',
    max_length=MAX_SEQUENCE_LENGTH,
    truncation=True,
)


In [44]:
def post_processing_forward(*pred):
    """Convert the first model output (logits, classes on dim 1) into probabilities."""
    logits = pred[0]
    return torch.softmax(logits, dim=1)

def mapping_token_output(_1, _2, _3):
    """Prepend a batch axis of size 1 to each of the three tokenizer outputs.

    The ONNX tokenizer emits un-batched 1-D tensors, while the exported classifier
    takes (batch, seq_len) inputs.
    """
    return tuple(tensor.unsqueeze(0) for tensor in (_1, _2, _3))

In [47]:
# Show the docs of the onnxruntime-extensions BERT tokenizer wrapper; its docstring states the
# output order (input_ids, token_type_ids, attention_mask), which matches the exported model's inputs.
help(pnp.PreHuggingFaceBert)

Help on class PreHuggingFaceBert in module onnxruntime_extensions.pnp._nlp:

class PreHuggingFaceBert(onnxruntime_extensions.pnp._base.ProcessingTracedModule)
 |  PreHuggingFaceBert(hf_tok=None, vocab_file=None, do_lower_case=0, strip_accents=1)
 |  
 |  # v1. Order of outputs - input_ids, token_type_ids, attention_mask
 |  #    (this is NOT consistent with the HuggingFace implementation of the tokenizer)
 |  
 |  Method resolution order:
 |      PreHuggingFaceBert
 |      onnxruntime_extensions.pnp._base.ProcessingTracedModule
 |      torch.nn.modules.module.Module
 |      onnxruntime_extensions.pnp._base._ProcessingModule
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, hf_tok=None, vocab_file=None, do_lower_case=0, strip_accents=1)
 |      Initialize internal Module state, shared by both nn.Module and ScriptModule.
 |  
 |  export(self, *args, **kwargs)
 |  
 |  forward(self, text)
 |      Define the computation performed at every call.
 |      
 |     

In [51]:
# Wrap the HF tokenizer as an ONNX-exportable op and reload the FP32 classifier graph.
test_sentence = ["This is a test sentence"]
ort_tok = pnp.PreHuggingFaceBert(hf_tok=tokenizer)
onnx_model = onnx.load_model(export_model_path)

In [53]:
augmented_model = pnp.export(
    pnp.SequentialProcessingModule(
        ort_tok,                  # raw text -> input_ids, token_type_ids, attention_mask
        mapping_token_output,     # add the batch axis
        onnx_model,               # classifier -> logits
        post_processing_forward,  # logits -> probabilities
    ),
    test_sentence,
    output_path=export_augmented_model_path,
)

[0;93m2024-03-21 13:25:22.804080 [W:onnxruntime:, execution_frame.cc:858 VerifyOutputSizes] Expected shape from model of {} does not match actual shape of {7} for output ot_9[m
[0;93m2024-03-21 13:25:22.804123 [W:onnxruntime:, execution_frame.cc:858 VerifyOutputSizes] Expected shape from model of {} does not match actual shape of {7} for output ot_10[m
[0;93m2024-03-21 13:25:22.804131 [W:onnxruntime:, execution_frame.cc:858 VerifyOutputSizes] Expected shape from model of {} does not match actual shape of {7} for output ot_11[m
  return _OnnxTracedFunction.apply(torch.tensor(self.func_id), *args, **kwargs)
  return _invoke_onnx_model(args[0].item(), *args[1:], **kwargs)
  results = func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs)
  [torch.from_numpy(_o) for _o in results]) if isinstance(results, tuple) else torch.from_numpy(results)


In [55]:
# Run the augmented model end-to-end on raw text via onnxruntime-extensions.
model_func = OrtPyFunction.from_model(augmented_model)
result = model_func(test_sentence)  # class probabilities
print(result)

[[0.21341915 0.16678403 0.2527243  0.22243719 0.14463529]]


[0;93m2024-03-21 13:30:53.987290 [W:onnxruntime:, graph.cc:3593 CleanUnusedInitializersAndNodeArgs] Removing initializer 'g2_ai.onnx.contrib::_ModelFunctionCall_9'. It is not used by any node and should be removed from the model.[m


In [56]:
# original model
model.eval()
# `model(**inputs)` binds the tokenizer outputs to forward() by keyword name, so key order is irrelevant.
inputs = model_tokenizer(test_sentence, return_tensors='pt', add_special_tokens=True, padding='max_length', max_length=MAX_SEQUENCE_LENGTH, truncation=True)
with torch.no_grad():
    output = model(**inputs)
probs = torch.softmax(output, dim=1)
print(probs)
# The augmented ONNX pipeline (tokenizer + classifier + softmax) must match PyTorch.
np.testing.assert_allclose(probs.numpy(), result, rtol=1e-03, atol=1e-05)

tensor([[0.2134, 0.1668, 0.2527, 0.2224, 0.1446]], grad_fn=<SoftmaxBackward0>)


In [59]:
# Structural validation of the augmented (tokenizer + classifier + softmax) graph.
test_model_augmented = onnx.load_model(export_augmented_model_path)
onnx.checker.check_model(test_model_augmented)

In [72]:
# The augmented graph uses onnxruntime-extensions custom ops, so their library is registered first.
so = onnxruntime.SessionOptions()
so.register_custom_ops_library(_lib_path())
ort_session = onnxruntime.InferenceSession(
    export_augmented_model_path,
    sess_options=so,
    providers=['CPUExecutionProvider'],
)
text_input_name = ort_session.get_inputs()[0].name
outputs = ort_session.run(None, {text_input_name: test_sentence})
print(outputs)

[array([[0.21341915, 0.16678403, 0.2527243 , 0.22243719, 0.14463529]],
      dtype=float32)]


[0;93m2024-03-21 14:58:24.378206 [W:onnxruntime:, graph.cc:3593 CleanUnusedInitializersAndNodeArgs] Removing initializer 'g2_ai.onnx.contrib::_ModelFunctionCall_9'. It is not used by any node and should be removed from the model.[m


In [68]:
# Show the docs of gen_processing_models, which generates pre-/post-processing ONNX models
# from an HF tokenizer/processor.
help(onnxruntime_extensions.gen_processing_models)

Help on function gen_processing_models in module onnxruntime_extensions.cvt:

gen_processing_models(processor: Union[str, object], pre_kwargs: dict = None, post_kwargs: dict = None, opset: int = None, **kwargs)
    Generate the pre- and post-processing ONNX model, basing on the name or HF class.
    
    Parameters
    ----------
    processor:
        the HF processor/tokenizer instance, or the name (str) of a Data Processor
        the instance is preferred, otherwise when name was given, the corresponding configuration for the processor
        has to be provided in the kwargs
    pre_kwargs: dict
        Keyword arguments for generating the pre-processing model
        WITH_DEFAULT_INPUTS: bool, add default inputs to the graph, default is True
        CAST_TOKEN_ID: bool, add a cast op to output token IDs to be int64 if needed, default is False
    post_kwargs: dict
        Keyword arguments for generating the post-processing model
    opset: int
        the target opset version of

## Compress the ONNX model (FP16 conversion)

In [113]:
from onnxruntime.quantization import quantize_dynamic, shape_inference, QuantType
from onnxconverter_common import float16
from onnxruntime.transformers import optimizer

In [112]:
model_fp32 = Path(EXPORT_FOLDER) / "bioclinicalbert_nlp_cls_best.onnx"  # source FP32 graph

In [120]:
model_fp16_path = Path(EXPORT_FOLDER) / "bioclinicalbert_nlp_cls_best_fp16.onnx"
# Apply ONNX Runtime's BERT-specific graph optimizations, cast the model to float16,
# and write it next to the FP32 export.
optimized_model = optimizer.optimize_model(str(model_fp32), model_type='bert')
optimized_model.convert_float_to_float16()
optimized_model.save_model_to_file(str(model_fp16_path))

In [121]:
# Structural validation of the FP16 graph.
onnx_model = onnx.load_model(model_fp16_path)
onnx.checker.check_model(onnx_model)

In [122]:
model.eval()
# `model(**inputs)` binds the tokenizer outputs to forward() by keyword name, so key order is irrelevant.
inputs = model_tokenizer('hello world', return_tensors='pt', add_special_tokens=True, padding='max_length', max_length=MAX_SEQUENCE_LENGTH, truncation=True)
with torch.no_grad():
    output = model(**inputs)

ort_session = onnxruntime.InferenceSession(model_fp16_path, providers=['CPUExecutionProvider'])
# compute ONNX Runtime output prediction, feeding every session input by name
ort_inputs = {
    session_input.name: to_numpy(inputs[session_input.name])
    for session_input in ort_session.get_inputs()
}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results; the FP16 graph needs looser tolerances than FP32
print(to_numpy(output), ort_outs[0])
np.testing.assert_allclose(to_numpy(output), ort_outs[0], rtol=1e-02, atol=1e-04)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

[[ 0.4246738   0.04263563  0.483116    0.24882282 -0.15789741]] [[ 0.42475727  0.04258233  0.4832514   0.24885708 -0.15792966]]
Exported model has been tested with ONNXRuntime, and the result looks good!
