In [1]:
import sys
sys.path.insert(0, "/home/abel/onnxt5")
from transformers import T5Tokenizer
from datetime import datetime
from onnxruntime import InferenceSession, SessionOptions, ExecutionMode
from onnxt5 import create_t5_encoder_decoder, GenerativeT5
import torch
import numpy as np


pretrained_model = 't5-base' # This can be a pretrained version, or the path to a huggingface model
simplified_encoder, decoder_with_lm_head = create_t5_encoder_decoder(pretrained_model)
tokenizer = T5Tokenizer.from_pretrained(pretrained_model)
generative_t5_pytorch = GenerativeT5(simplified_encoder.cuda(), decoder_with_lm_head.cuda(), tokenizer, cuda=True)


decoder_sess = InferenceSession('/home/abel/t5-decoder-with-lm-head.onnx')
encoder_sess = InferenceSession('/home/abel/t5-encoder.onnx')
options = SessionOptions()
options.intra_op_num_threads = 1
options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
tokenizer = T5Tokenizer.from_pretrained(pretrained_model)
generative_t5_onnx = GenerativeT5(encoder_sess, decoder_sess, tokenizer, onnx=True)


Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at t5-base and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 400/400 [00:09<00:00, 43.84it/s]
100%|██████████| 400/400 [00:08<00:00, 46.11it/s]


('Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start Start St

In [2]:
pt_means = []
pt_stds = []
onnx_means = []
onnx_stds = []

for i in range(2, 1005, 50):
  temp_pytorch = []
  temp_onnx = []
  for o in range(10):
    t1 = datetime.now()
    generative_t5_pytorch('Start', i, temperature=0.)
    t2 = datetime.now()
    generative_t5_onnx('Start', i, temperature=0.)
    t3 = datetime.now()
    pytorch_t = (t2-t1).total_seconds()
    onnx_t = (t3-t2).total_seconds()
    temp_pytorch.append(pytorch_t)
    temp_onnx.append(onnx_t)
  pt_means.append(np.mean(temp_pytorch))
  pt_stds.append(np.std(temp_pytorch))
  onnx_means.append(np.mean(temp_onnx))
  onnx_stds.append(np.std(temp_onnx))


100%|██████████| 2/2 [00:00<00:00, 49.86it/s]
100%|██████████| 2/2 [00:00<00:00, 145.21it/s]
100%|██████████| 2/2 [00:00<00:00, 51.90it/s]
100%|██████████| 2/2 [00:00<00:00, 155.11it/s]
100%|██████████| 2/2 [00:00<00:00, 51.35it/s]
100%|██████████| 2/2 [00:00<00:00, 155.32it/s]
100%|██████████| 2/2 [00:00<00:00, 53.54it/s]
100%|██████████| 2/2 [00:00<00:00, 159.77it/s]
100%|██████████| 2/2 [00:00<00:00, 52.30it/s]
100%|██████████| 2/2 [00:00<00:00, 156.40it/s]
100%|██████████| 2/2 [00:00<00:00, 51.97it/s]
100%|██████████| 2/2 [00:00<00:00, 141.00it/s]
100%|██████████| 2/2 [00:00<00:00, 52.82it/s]
100%|██████████| 2/2 [00:00<00:00, 159.93it/s]
100%|██████████| 2/2 [00:00<00:00, 51.90it/s]
100%|██████████| 2/2 [00:00<00:00, 155.52it/s]
100%|██████████| 2/2 [00:00<00:00, 53.50it/s]
100%|██████████| 2/2 [00:00<00:00, 159.45it/s]
100%|██████████| 2/2 [00:00<00:00, 52.79it/s]
100%|██████████| 2/2 [00:00<00:00, 148.45it/s]
100%|██████████| 52/52 [00:01<00:00, 49.55it/s]
100%|██████████| 52/52

In [15]:
import plotly
plotly.io.renderers.default = 'colab'
import plotly.graph_objects as go

x = list(range(2, 1005, 50))

fig = go.Figure(data=[
    # Pytorch
    go.Scatter(
        x=x,
        name='PyTorch',
        y=pt_means,
        error_y=dict(
            type='data',
            array=pt_stds,
            visible=True)
    ),
    go.Scatter(
        x=x,
        name='ONNX',
        y=onnx_means,
        error_y=dict(
            type='data',
            array=onnx_stds,
            visible=True)
    ),
    ])
fig.update_layout(
    title="Benchmark of Inference time per number of characters to generate (increasing context)",
    xaxis_title="Number of characters to generate (with expanding context)",
    yaxis_title="Seconds to complete",
    legend_title="Framework",
)
fig.show()

In [3]:
pt_means_emb = []
pt_stds_emb = []
onnx_means_emb = []
onnx_stds_emb = []

for i in range(2, 1005, 50):
  temp_pytorch = []
  temp_onnx = []
  for o in range(10):
    inputs = torch.tensor([[1] * i]).cuda()
    inputs_numpy = inputs.cpu().numpy()
    t1 = datetime.now()
    _ = decoder_with_lm_head(inputs, simplified_encoder(inputs))
    t2 = datetime.now()
    encoder_output = encoder_sess.run(None, {"input_ids": inputs_numpy})[0]
    # To generate the full model's embeddings
    decoder_output = decoder_sess.run(None, {
                                            "input_ids": inputs_numpy,
                                            "encoder_hidden_states": encoder_output
        })[0]
    t3 = datetime.now()
    pytorch_t = (t2-t1).total_seconds()
    onnx_t = (t3-t2).total_seconds()
    print(f'{i}: pt {pytorch_t}s, onnx {onnx_t}s')
    temp_pytorch.append(pytorch_t)
    temp_onnx.append(onnx_t)
  pt_means_emb.append(np.mean(temp_pytorch))
  pt_stds_emb.append(np.std(temp_pytorch))
  onnx_means_emb.append(np.mean(temp_onnx))
  onnx_stds_emb.append(np.std(temp_onnx))

2: pt 0.038241s, onnx 0.01003s
2: pt 0.039678s, onnx 0.010294s
2: pt 0.039551s, onnx 0.009932s
2: pt 0.04018s, onnx 0.010839s
2: pt 0.04051s, onnx 0.010325s
2: pt 0.039115s, onnx 0.009822s
2: pt 0.037788s, onnx 0.009947s
2: pt 0.036658s, onnx 0.009754s
2: pt 0.037662s, onnx 0.00951s
2: pt 0.03748s, onnx 0.009797s
52: pt 0.046762s, onnx 0.015137s
52: pt 0.056467s, onnx 0.01482s
52: pt 0.041236s, onnx 0.014591s
52: pt 0.043615s, onnx 0.01486s
52: pt 0.042543s, onnx 0.01466s
52: pt 0.041517s, onnx 0.01454s
52: pt 0.043747s, onnx 0.014455s
52: pt 0.043772s, onnx 0.014576s
52: pt 0.041639s, onnx 0.014583s
52: pt 0.042061s, onnx 0.014757s
102: pt 0.055383s, onnx 0.019927s
102: pt 0.052079s, onnx 0.019665s
102: pt 0.043977s, onnx 0.019478s
102: pt 0.042678s, onnx 0.019405s
102: pt 0.042837s, onnx 0.01958s
102: pt 0.043119s, onnx 0.01956s
102: pt 0.042628s, onnx 0.019403s
102: pt 0.043493s, onnx 0.019908s
102: pt 0.042388s, onnx 0.019466s
102: pt 0.044001s, onnx 0.019627s
152: pt 0.053718s, on

In [4]:
import plotly
plotly.io.renderers.default = 'colab'
import plotly.graph_objects as go

x = list(range(2, 1005, 50))

fig = go.Figure(data=[
    # Pytorch
    go.Scatter(
        x=x,
        name='PyTorch',
        y=pt_means_emb,
        error_y=dict(
            type='data',
            array=pt_stds_emb,
            visible=True)
    ),
    go.Scatter(
        x=x,
        name='ONNX',
        y=onnx_means_emb,
        error_y=dict(
            type='data',
            array=onnx_stds_emb,
            visible=True)
    ),
    ])
fig.update_layout(
    title="Benchmark of embedding time per number of characters to embed",
    xaxis_title="Number of characters to embed",
    yaxis_title="Seconds to complete",
    legend_title="Framework",
)
fig.show()