In [None]:
import sys
# Prepend the local onnxt5 directory to the import path (kept ahead of every other import, as before).
sys.path.insert(0, "/home/abel/onnxt5")
import time
from datetime import datetime

import numpy as np
import plotly
import plotly.graph_objects as go
import torch
from onnxruntime import InferenceSession, SessionOptions, ExecutionMode
from transformers import T5Tokenizer

from onnxt5 import create_t5_encoder_decoder, GenerativeT5


pretrained_model = 't5-base'  # This can be a pretrained version, or the path to a huggingface model
# ONNX graphs for the encoder and the decoder-with-LM-head.
# NOTE(review): presumably exported from the same `pretrained_model` — confirm.
ENCODER_ONNX_PATH = '/home/abel/t5-encoder.onnx'
DECODER_ONNX_PATH = '/home/abel/t5-decoder-with-lm-head.onnx'

# PyTorch baseline: both modules are moved to the GPU (`.cuda()` moves a module in place).
simplified_encoder, decoder_with_lm_head = create_t5_encoder_decoder(pretrained_model)
tokenizer = T5Tokenizer.from_pretrained(pretrained_model)
generative_t5_pytorch = GenerativeT5(simplified_encoder.cuda(), decoder_with_lm_head.cuda(), tokenizer, cuda=True)

# ONNX Runtime sessions. The options must exist *before* the sessions are created and be
# passed via `sess_options`; mutating a SessionOptions object afterwards does not affect
# sessions that were already built.
options = SessionOptions()
options.intra_op_num_threads = 1
options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
# NOTE(review): no `providers` are passed, so ONNX Runtime uses the default execution
# provider of the installed build — confirm whether ONNX runs on CPU or GPU, since the
# PyTorch baseline above runs on CUDA and the comparison depends on it.
decoder_sess = InferenceSession(DECODER_ONNX_PATH, sess_options=options)
encoder_sess = InferenceSession(ENCODER_ONNX_PATH, sess_options=options)
# Same tokenizer as the PyTorch pipeline; there is no need to load it again.
generative_t5_onnx = GenerativeT5(encoder_sess, decoder_sess, tokenizer, onnx=True)


In [None]:
# Benchmark greedy generation (temperature=0) from the prompt 'Start' for increasing
# generation lengths. Each length is timed GEN_REPEATS times per backend; the mean and
# (population) standard deviation per length are collected for plotting.
GEN_LENGTHS = list(range(2, 1005, 50))
GEN_REPEATS = 10

# Warm-up: the first call of each backend may include one-off start-up costs
# (lazy initialisation, kernel/graph preparation) that would otherwise inflate the
# first, shortest measurement.
generative_t5_pytorch('Start', GEN_LENGTHS[0], temperature=0.)
torch.cuda.synchronize()
generative_t5_onnx('Start', GEN_LENGTHS[0], temperature=0.)

pt_means = []
pt_stds = []
onnx_means = []
onnx_stds = []

for length in GEN_LENGTHS:
  pytorch_times = []
  onnx_times = []
  for _ in range(GEN_REPEATS):
    # perf_counter is a monotonic, high-resolution clock meant for measuring intervals;
    # datetime.now() is wall-clock time and can jump.
    t1 = time.perf_counter()
    generative_t5_pytorch('Start', length, temperature=0.)
    # CUDA work is queued asynchronously; wait for it so it is charged to PyTorch.
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    generative_t5_onnx('Start', length, temperature=0.)
    t3 = time.perf_counter()
    pytorch_times.append(t2 - t1)
    onnx_times.append(t3 - t2)
  pt_means.append(np.mean(pytorch_times))
  pt_stds.append(np.std(pytorch_times))
  onnx_means.append(np.mean(onnx_times))
  onnx_stds.append(np.std(onnx_times))


In [None]:
plotly.io.renderers.default = 'colab'

# Must match the generation lengths benchmarked above (one point per length).
x = list(range(2, 1005, 50))
assert len(x) == len(pt_means) == len(onnx_means), "x values do not match the benchmark results"

# Mean time per length, with error bars of ± one standard deviation over the repeats.
fig = go.Figure(data=[
    go.Scatter(
        x=x,
        name='PyTorch',
        y=pt_means,
        error_y=dict(type='data', array=pt_stds, visible=True),
    ),
    go.Scatter(
        x=x,
        name='ONNX',
        y=onnx_means,
        error_y=dict(type='data', array=onnx_stds, visible=True),
    ),
])
# The x values are the generation length passed to GenerativeT5, i.e. a number of tokens.
fig.update_layout(
    title="Benchmark of inference time per number of tokens to generate (increasing context)",
    xaxis_title="Number of tokens to generate (with expanding context)",
    yaxis_title="Seconds to complete (mean ± std)",
    legend_title="Framework",
)
fig.show()

In [None]:
# Benchmark one full forward pass (encoder, then decoder with LM head) over an input of
# `length` copies of token id 1, for increasing lengths. Each length is timed EMB_REPEATS
# times per backend; the mean and (population) standard deviation are collected.
EMB_LENGTHS = list(range(2, 1005, 50))
EMB_REPEATS = 10

pt_means_emb = []
pt_stds_emb = []
onnx_means_emb = []
onnx_stds_emb = []


def _pytorch_forward(input_ids):
  """Run the PyTorch encoder + decoder-with-LM-head and block until the GPU has finished."""
  # Inference only: don't build an autograd graph inside the timed region.
  with torch.no_grad():
    output = decoder_with_lm_head(input_ids, simplified_encoder(input_ids))
  # CUDA kernels launch asynchronously; without this the timer would stop as soon as the
  # kernels are queued, under-reporting the PyTorch time.
  torch.cuda.synchronize()
  return output


def _onnx_forward(input_ids):
  """Run the ONNX encoder, then the ONNX decoder-with-LM-head on its hidden states."""
  encoder_output = encoder_sess.run(None, {"input_ids": input_ids})[0]
  return decoder_sess.run(None, {
      "input_ids": input_ids,
      "encoder_hidden_states": encoder_output,
  })[0]


# Warm-up both backends once so one-off start-up costs don't land in the first measurement.
_warmup_ids = torch.tensor([[1] * EMB_LENGTHS[0]]).cuda()
_pytorch_forward(_warmup_ids)
_onnx_forward(_warmup_ids.cpu().numpy())

for length in EMB_LENGTHS:
  # The input is identical across repeats, so build it (and its numpy copy) once per length.
  inputs = torch.tensor([[1] * length]).cuda()
  inputs_numpy = inputs.cpu().numpy()
  pytorch_times = []
  onnx_times = []
  for _ in range(EMB_REPEATS):
    t1 = time.perf_counter()
    _pytorch_forward(inputs)
    t2 = time.perf_counter()
    _onnx_forward(inputs_numpy)
    t3 = time.perf_counter()
    pytorch_times.append(t2 - t1)
    onnx_times.append(t3 - t2)
  pt_means_emb.append(np.mean(pytorch_times))
  pt_stds_emb.append(np.std(pytorch_times))
  onnx_means_emb.append(np.mean(onnx_times))
  onnx_stds_emb.append(np.std(onnx_times))
  # One progress line per length (instead of one per repeat) to keep the output readable.
  print(f'{length}: pt {pt_means_emb[-1]:.4f}s, onnx {onnx_means_emb[-1]:.4f}s (mean of {EMB_REPEATS})')

In [None]:
plotly.io.renderers.default = 'colab'

# Must match the input lengths benchmarked above (one point per length).
x = list(range(2, 1005, 50))
assert len(x) == len(pt_means_emb) == len(onnx_means_emb), "x values do not match the benchmark results"

# Mean time per length, with error bars of ± one standard deviation over the repeats.
fig = go.Figure(data=[
    go.Scatter(
        x=x,
        name='PyTorch',
        y=pt_means_emb,
        error_y=dict(type='data', array=pt_stds_emb, visible=True),
    ),
    go.Scatter(
        x=x,
        name='ONNX',
        y=onnx_means_emb,
        error_y=dict(type='data', array=onnx_stds_emb, visible=True),
    ),
])
# The x values are input lengths in tokens (the benchmark feeds `length` token ids).
fig.update_layout(
    title="Benchmark of embedding time per number of tokens to embed",
    xaxis_title="Number of tokens to embed",
    yaxis_title="Seconds to complete (mean ± std)",
    legend_title="Framework",
)
fig.show()