# ONNX Format
---

This notebooks walks through and example of saving a PyTorch model to the ONNX format and then conducting inference through the ONNX runtime.

In [0]:
# Install dependencies

# capture - disables cell output
%%capture
! pip install transformers
! pip install onnx
! pip install onnxruntime

In [3]:
# Mount our gdrive to the notebook for saving a model

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [4]:
import torch
from transformers import *

import warnings
warnings.filterwarnings('ignore')

In [5]:
# BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Pretrained BERT model
model = BertModel.from_pretrained('bert-base-uncased')

# Example
input_ids = torch.tensor([tokenizer.encode("Here is some text to encode", add_special_tokens=True)])

HBox(children=(IntProgress(value=0, description='Downloading', max=231508, style=ProgressStyle(description_wid…




HBox(children=(IntProgress(value=0, description='Downloading', max=361, style=ProgressStyle(description_width=…




HBox(children=(IntProgress(value=0, description='Downloading', max=440473133, style=ProgressStyle(description_…




In [6]:
# Run the sample through the model
with torch.no_grad():
    last_hidden_states = model(input_ids)[0]

# Output from the language model
last_hidden_states

tensor([[[-0.0549,  0.1053, -0.1065,  ..., -0.3550,  0.0686,  0.6506],
         [-0.5759, -0.3650, -0.1383,  ..., -0.6782,  0.2092, -0.1639],
         [-0.1641, -0.5597,  0.0150,  ..., -0.1603, -0.1346,  0.6216],
         ...,
         [ 0.2448,  0.1254,  0.1587,  ..., -0.2749, -0.1163,  0.8809],
         [ 0.0481,  0.4950, -0.2827,  ..., -0.6097, -0.1212,  0.2527],
         [ 0.9046,  0.2137, -0.5897,  ...,  0.3040, -0.6172, -0.1950]]])

In [7]:
last_hidden_states.shape

torch.Size([1, 9, 768])

# ONNX

So we now have a model and we want to serialize and package this model for production.

In [0]:
import onnx
import onnxruntime as ort

In [0]:
# Path to save the model in mounted gdrive
# CHANGE THIS TO THE LOCATION THAT YOU WOULD LIKE FOR YOUR DRIVE
onnx_path = f"/content/drive/My Drive/MLOPS/hands_on/onnx/model"

In [0]:
# Sample text tensor for tracing
sample_input_ids = torch.tensor([tokenizer.encode("Here is some text to encode", add_special_tokens=True)])

In [0]:
# Eporting model to the ONNX framework
torch.onnx.export(model,
                  sample_input_ids,
                  f"{onnx_path}/bert_static.onnx",
                  input_names=['input'],
                  output_names=['output'])

In [13]:
# We can see out model has been save to my gdrive
! ls /content/drive/My\ Drive/MLOPS/hands_on/onnx/model

bert_static.onnx


In [0]:
# Load the ONNX model
onnx_static_model = onnx.load(f"{onnx_path}/bert_static.onnx")

In [15]:
# Inspect the input
onnx_static_model.graph.input[0]

name: "input"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 9
      }
    }
  }
}

In [16]:
# Inspect the output
onnx_static_model.graph.output[0]

name: "output"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 9
      }
      dim {
        dim_value: 768
      }
    }
  }
}

In [0]:
# Initialize the runtime inference session
ort_session = ort.InferenceSession(f"{onnx_path}/bert_static.onnx")

In [18]:
# Test input text
input_ids = torch.tensor([tokenizer.encode("Here is some more text to encode", add_special_tokens=True)])

# Create the input dictionary for the ONNX runtime - accepts numpy arrays
inference_input = {onnx_static_model.graph.input[0].name: input_ids.numpy()}

# .run() conducts inference
outputs = ort_session.run(None, inference_input)

print(outputs[0])

InvalidArgument: ignored


# Dynamic

Here we will see how to export a model to the ONNX format with dynamic axes.

This is a great reference:
https://pytorch.org/docs/stable/onnx.html

In [0]:
# Sample tensore for tracing
sample_input_ids = torch.tensor([tokenizer.encode("Here is some text to encode", add_special_tokens=True)])

# Exportintg with dynamic axes
torch.onnx.export(model,
                  sample_input_ids,
                  f"{onnx_path}/bert_dynamic.onnx",
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {1: 'sequence'}})

In [20]:
# We can see our saved model in the gdrive
! ls /content/drive/My\ Drive/MLOPS/hands_on/onnx/model

bert_dynamic.onnx  bert_static.onnx


In [0]:
# Load the ONNX model
onnx_dynamic_model = onnx.load(f"{onnx_path}/bert_dynamic.onnx")

In [22]:
onnx_dynamic_model.graph.input[0]

name: "input"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_param: "sequence"
      }
    }
  }
}

In [23]:
onnx_dynamic_model.graph.output[0]

name: "output"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 9
      }
      dim {
        dim_value: 768
      }
    }
  }
}

In [0]:
# Initialize the runtime inference session
ort_session = ort.InferenceSession(f"{onnx_path}/bert_dynamic.onnx")

In [25]:
# Test input text
input_ids = torch.tensor([tokenizer.encode("Here is some more and more text to encode", add_special_tokens=True)])

# Create the input dictionary for the ONNX runtime - accepts numpy arrays
inference_input = {onnx_dynamic_model.graph.input[0].name: input_ids.numpy()}

# .run() conducts inference
outputs = ort_session.run(None, inference_input)

print(outputs[0])

[[[ 1.50481716e-01  2.69410014e-01  2.00014189e-01 ... -2.39659578e-01
    3.17833632e-01  6.80566788e-01]
  [-3.93219739e-01 -3.43886405e-01  1.03734061e-01 ... -2.16759071e-01
    6.49665594e-01 -2.37899646e-02]
  [ 1.87670708e-01 -3.85836780e-01  1.58024997e-01 ...  1.34617090e-04
    6.79188967e-02  6.33356869e-01]
  ...
  [ 3.33162278e-01  1.59700871e-01  5.05449474e-01 ... -1.12704746e-01
   -9.19650793e-02  5.34562707e-01]
  [ 4.43043977e-01  4.08770025e-01  5.10673895e-02 ... -2.96072185e-01
    6.52517378e-02  6.66331872e-02]
  [ 8.38668346e-01  2.93955117e-01 -1.72309980e-01 ...  1.22411609e-01
   -5.04576206e-01 -1.15954645e-01]]]
