# Chapter-4 Model Optimization using Onnx-Simplifier and Onnxruntime

#### In this notebook, we will try to optimize GPT2 ONNX model using Onnx-Simplifier and Onnxruntime. We will also see the impact of these optimizations on the model.

## Part-1 : Export GPT2 ONNX Model

In [1]:
# Install prerequisites
!pip install onnx onnxsim onnxruntime transformers netron

Collecting onnx
  Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)
Collecting onnxsim
  Downloading onnxsim-0.4.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.20.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting netron
  Downloading netron-8.1.8-py3-none-any.whl.metadata (1.5 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxsim-0.4.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86

In [2]:
# Load GPT2 model from HuggingFace: https://huggingface.co/openai-community/gpt2

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# Encode the input text (prompt) into tokens
input_text = "Once upon a time"
input_ids = tokenizer.encode(input_text, return_tensors='pt')
print("Input Ids shape: ", input_ids.shape)

# Generate text using the model
output = model.generate(input_ids, max_length=50, num_return_sequences=1, no_repeat_ngram_size=2)

# Decode the generated tokens back into text
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print("--"*30)
print(f"Given input: {input_text}")
print(f"Generated output: {generated_text}")
print("--"*30)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Input Ids shape:  torch.Size([1, 4])
------------------------------------------------------------
Given input: Once upon a time
Generated output: Once upon a time, the world was a place of great beauty and great danger. The world of the gods was the place where the great gods were born, and where they were to live.

The world that was created was not the same
------------------------------------------------------------


In [3]:
# Export GPT2 model to ONNX

import os
import torch
os.makedirs("./exported_models/", exist_ok=True)
static_shape_output_path = "./exported_models/gpt2_hf_static_shape.onnx"
dynamic_shape_output_path = "./exported_models/gpt2_hf_dynamic_shape.onnx"

# Export the model to ONNX with static shapes
dummy_static_input_ids = torch.ones([1, 128], dtype=torch.int32)
torch.onnx.export(
    model,
    args=(dummy_static_input_ids,),
    f=static_shape_output_path,
    input_names=["input_ids"],
    output_names=["logits"],
    opset_version=14  # ONNX opset version (use 14 or later for exporting models with sdpa attention)
)

# Export the model to ONNX with dynamic shapes
torch.onnx.export(
    model,
    args=(input_ids,),
    f=dynamic_shape_output_path,
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"},  # Dynamic axes for batch size and sequence length
                  "logits": {0: "batch_size", 1: "sequence_length"}},
    opset_version=14  # ONNX opset version (use 14 or later for exporting models with sdpa attention)
)

print(f"Model with static shapes successfully exported to {static_shape_output_path}")
print(f"Model with dynamic shapes successfully exported to {dynamic_shape_output_path}")

  if input_shape[-1] > 1 or self.sliding_window is not None:
  if past_key_values_length > 0:


Model with static shapes successfully exported to ./exported_models/gpt2_hf_static_shape.onnx
Model with dynamic shapes successfully exported to ./exported_models/gpt2_hf_dynamic_shape.onnx


In [4]:
# Visualize the exported model with static shapes

import IPython
import netron

port = 6006
netron.start(static_shape_output_path, port, browse=False)
IPython.display.IFrame(f"http://localhost:{port}", width=1000, height=500)

Serving './exported_models/gpt2_hf_static_shape.onnx' at http://localhost:6006


In [5]:
# Visualize the exported model with dynamic shapes

import IPython
import netron

port = 6006
netron.start(dynamic_shape_output_path, port, browse=False)
IPython.display.IFrame(f"http://localhost:{port}", width=1000, height=500)

Stopping http://localhost:6006
Serving './exported_models/gpt2_hf_dynamic_shape.onnx' at http://localhost:6006


In [6]:
# Below code is used to compare the original model with optimized model

import onnxruntime as ort
import numpy as np
import random
import time

def check_performance(model_path, input_data, num_iter=100):
    # Perform inference and measure timing
    session = ort.InferenceSession(model_path)
    start = time.time()
    for i in range(num_iter):
        outputs = session.run(None, input_data)
    end = time.time()

    time_diff = (end-start)/num_iter
    print(f"Inference time: {time_diff:.4f} seconds")

## Part-2 : Optimize model using Onnx-Simplifier

In [7]:
import onnx
from onnxsim import simplify

def optimize_model_using_simplifier(model_path, output_path):
    # Load onnx model
    onnx_model = onnx.load(model_path)

    # Simplify model using Onnx-Simplifier
    simplified_model, status = simplify(onnx_model)

    # Save simplified model
    onnx.save(simplified_model, output_path)
    print(f"Before Nodes: {len(onnx_model.graph.node)}")
    print(f"After Nodes: {len(simplified_model.graph.node)}")

print("Model with static shapes:")
opt_model_onnxsim_static_shape = "./exported_models/gpt2_hf_static_shapes_onnxsim.onnx"
optimize_model_using_simplifier(static_shape_output_path, opt_model_onnxsim_static_shape)

print("Model with dynamic shapes:")
opt_model_onnxsim_dynamic_shape = "./exported_models/gpt2_hf_dynamic_shapes_onnxsim.onnx"
optimize_model_using_simplifier(dynamic_shape_output_path, opt_model_onnxsim_dynamic_shape)

Model with static shapes:
Before Nodes: 1157
After Nodes: 672
Model with dynamic shapes:
Before Nodes: 2781
After Nodes: 1435


In [8]:
# Visualize the optimized model with static shapes

import IPython
import netron

port = 6006
netron.start(opt_model_onnxsim_static_shape, port, browse=False)
IPython.display.IFrame(f"http://localhost:{port}", width=1000, height=500)

Stopping http://localhost:6006
Serving './exported_models/gpt2_hf_static_shapes_onnxsim.onnx' at http://localhost:6006


In [9]:
# Visualize the optimized model with dynamic shapes

import IPython
import netron

port = 6006
netron.start(opt_model_onnxsim_dynamic_shape, port, browse=False)
IPython.display.IFrame(f"http://localhost:{port}", width=1000, height=500)

Stopping http://localhost:6006
Serving './exported_models/gpt2_hf_dynamic_shapes_onnxsim.onnx' at http://localhost:6006


In [10]:
# Let us make a dummy input tensor of shape [1, 128] for checking the performance of the models
input_data_for_static_shape = {"input_ids" : np.random.randint(low=0, high=100, size=(1, 128), dtype=np.int32)}

# Check performance for static shape model
print("Original model with static shapes")
check_performance(static_shape_output_path, input_data_for_static_shape)

print("Optimized model with static shapes")
check_performance(opt_model_onnxsim_static_shape, input_data_for_static_shape)

Original model with static shapes
Inference time: 0.5835 seconds
Optimized model with static shapes
Inference time: 0.5815 seconds


In [11]:
input_data_for_dynamic_shape = {"input_ids" : np.random.randint(low=0, high=100, size=(1, 128), dtype=np.int64)}

# Check performance for dynamic shape model
print("Original model with dynamic shapes")
check_performance(dynamic_shape_output_path, input_data_for_dynamic_shape)

print("Optimized model with dynamic shapes")
check_performance(opt_model_onnxsim_dynamic_shape, input_data_for_dynamic_shape)

Original model with dynamic shapes
Inference time: 0.5887 seconds
Optimized model with dynamic shapes
Inference time: 0.5895 seconds


## Part-3 : Optimize model using Onnxruntime

In [12]:
import onnxruntime as rt

def optimize_model_using_ort(model_path, output_path):
    # Load onnx model
    sess_options = rt.SessionOptions()
    sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

    # Below are the different levels of optimizations in onnxruntime
    # rt.GraphOptimizationLevel.ORT_DISABLE_ALL -> Disables all optimizations
    # rt.GraphOptimizationLevel.ORT_ENABLE_BASIC -> Enables basic optimizations
    # rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED -> Enables basic and extended optimizations
    # rt.GraphOptimizationLevel.ORT_ENABLE_ALL -> Enables all available optimizations including layout optimizations

    # To enable model serialization after graph optimization set this
    sess_options.optimized_model_filepath = output_path

    session = rt.InferenceSession(model_path, sess_options)

    # No need to run the model. Initializing the session will generate the optimized model

    orig_model = onnx.load(model_path)
    opt_model = onnx.load(output_path)
    print(f"Before Nodes: {len(orig_model.graph.node)}")
    print(f"After Nodes: {len(opt_model.graph.node)}")


In [13]:
print("Model with static shapes:")
opt_model_ort_static_shape = "./exported_models/gpt2_hf_static_shapes_ort.onnx"
optimize_model_using_ort(static_shape_output_path, opt_model_ort_static_shape)

print("Model with dynamic shapes:")
opt_model_ort_dynamic_shape = "./exported_models/gpt2_hf_dynamic_shapes_ort.onnx"
optimize_model_using_ort(dynamic_shape_output_path, opt_model_ort_dynamic_shape)

Model with static shapes:
Before Nodes: 1157
After Nodes: 377
Model with dynamic shapes:
Before Nodes: 2781
After Nodes: 1120


In [14]:
# Check performance for static shape model
print("Original model with static shapes")
check_performance(static_shape_output_path, input_data_for_static_shape)

print("Optimized model with static shapes")
check_performance(opt_model_ort_static_shape, input_data_for_static_shape)

Original model with static shapes
Inference time: 0.6063 seconds
Optimized model with static shapes
Inference time: 0.5949 seconds


In [15]:
# Check performance for dynamic shape model
print("Original model with dynamic shapes")
check_performance(dynamic_shape_output_path, input_data_for_dynamic_shape)

print("Optimized model with dynamic shapes")
check_performance(opt_model_ort_dynamic_shape, input_data_for_dynamic_shape)

Original model with dynamic shapes
Inference time: 0.5809 seconds
Optimized model with dynamic shapes
Inference time: 0.5868 seconds
