In [1]:
from time import time
from typing import Callable, Dict, Set

import numpy as np
import onnx
import tensorrt as trt
import torch
from onnx import ModelProto
from tensorrt import ICudaEngine
from tensorrt.tensorrt import Logger, Runtime
from torch.nn import Linear
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PretrainedConfig, T5ForConditionalGeneration, TensorType
from transformers.generation_utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput
from transformers.models.t5.modeling_t5 import T5Stack

from transformer_deploy.backends.ort_utils import create_model_for_provider, inference_onnx_binding, optimize_onnx
from transformer_deploy.backends.pytorch_utils import convert_to_onnx
from transformer_deploy.backends.trt_utils import (
    TensorRTShape,
    add_output_nodes,
    build_engine,
    get_adjency_dict,
    get_fix_fp16_network_func,
    get_list_fp32_nodes,
    load_engine,
    save_engine,
)

In [2]:
# Reference PyTorch model and tokenizer; the model is put in evaluation mode on the GPU.
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model: T5ForConditionalGeneration = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval().to("cuda")

# Short prompt used as example input throughout the notebook.
input_ids: torch.Tensor = tokenizer("Studies show that", return_tensors=TensorType.PYTORCH).input_ids.to("cuda")

# Reference outputs reused by later cells: the encoder alone, then the full model
# fed with the same token ids on the encoder side and on the decoder side.
out_enc: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=input_ids)
out_full: Seq2SeqLMOutput = model(input_ids=input_ids, decoder_input_ids=input_ids)

In [3]:
# Export the encoder to ONNX, optimize it, and check the ONNX Runtime output against PyTorch.
model = model.to("cuda")

# Raw export of the encoder sub-module to "test-enc.onnx".
convert_to_onnx(
    model_pytorch=model.encoder,
    output_path="test-enc.onnx",
    inputs_pytorch={"input_ids": input_ids},
    var_output_seq=True,
    quantization=False,
)
# Optimized FP16 copy written to "test-enc-opt.onnx".
# NOTE(review): `architecture="bert"` is used for a T5 encoder — confirm this is the intended optimizer profile.
optimize_onnx(
    onnx_path="test-enc.onnx", onnx_optim_model_path="test-enc-opt.onnx", architecture="bert", use_cuda=True, fp16=True
)

enc_onnx = create_model_for_provider("test-enc-opt.onnx", "CUDAExecutionProvider")
# The expected output shape is the input shape extended with the encoder hidden size.
enc_onnx_out = inference_onnx_binding(
    model_onnx=enc_onnx,
    inputs={"input_ids": input_ids},
    device=input_ids.device.type,
    output_shape=tuple(input_ids.shape) + (int(model.encoder.config.d_model),),
)["output"]
# Loose tolerance: the optimized model was produced with fp16=True.
assert np.allclose(enc_onnx_out.detach().cpu().numpy(), out_enc.last_hidden_state.detach().cpu().numpy(), atol=1e-2)

In [4]:
from typing import Tuple


class ExportT5(torch.nn.Module):
    """Wrap a T5 decoder stack and its LM head so that both are traced as a single module.

    The wrapper is self-contained: the rescaling factor applied before the vocabulary
    projection is read from the wrapped decoder's own configuration and not from a
    notebook-level variable, so the module can be built from any decoder / head pair.
    """

    def __init__(self, decoder: T5Stack, lm_head: Linear):
        """
        :param decoder: decoder stack; its `config.d_model` gives the hidden size used for rescaling
        :param lm_head: linear layer projecting the decoder hidden states on the vocabulary
        """
        super(ExportT5, self).__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        # hidden size of the wrapped decoder, used to rescale hidden states in `forward`
        self.model_dim: int = decoder.config.d_model

    def forward(self, input_ids: torch.Tensor, encoder_hidden_states: torch.Tensor, past_key_values: Tuple = None):
        """Run the decoder, then replace its `last_hidden_state` by the vocabulary logits.

        :param input_ids: decoder token ids
        :param encoder_hidden_states: output of the encoder
        :param past_key_values: optional cache returned by a previous decoder call
        :return: the decoder output object, with `last_hidden_state` overwritten by the logits
        """
        out_dec = self.decoder.forward(
            input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, past_key_values=past_key_values
        )
        # Rescale output before projecting on vocab
        out_dec["last_hidden_state"] = out_dec["last_hidden_state"] * (self.model_dim**-0.5)
        out_dec["last_hidden_state"] = self.lm_head(out_dec["last_hidden_state"])
        return out_dec


# Build the export wrapper around the decoder and LM head of the reference model.
model.cuda()
model_decoder = ExportT5(decoder=model.decoder, lm_head=model.lm_head).eval()
# The wrapper returns the decoder output object (indexed by key below), not a bare tensor.
out_model_export: BaseModelOutputWithPastAndCrossAttentions = model_decoder(input_ids=input_ids, encoder_hidden_states=out_enc.last_hidden_state)
# The wrapper's `last_hidden_state` must match the logits of the full model computed earlier.
assert np.allclose(
    out_model_export["last_hidden_state"].detach().cpu().numpy(), out_full.logits.detach().cpu().numpy(), atol=1e-5
)

In [5]:
model_decoder.cuda()
# decoder output one step before
out_dec_pytorch = model_decoder(input_ids=input_ids[:, :-1], encoder_hidden_states=out_enc.last_hidden_state)

# Inputs of the decoder *with* cache: only the last token plus the cache of the previous step.
model_inputs = {
    "input_ids": input_ids[:, -1:].type(torch.int32),
    "encoder_hidden_states": out_enc.last_hidden_state,
    "past_key_values": out_dec_pytorch.past_key_values,
}

# Inputs of the decoder *without* cache: the whole sequence.
model_inputs_no_cache = {
    "input_ids": input_ids.type(torch.int32),
    "encoder_hidden_states": out_enc.last_hidden_state,
}

# Names and dynamic axes are generated from the real cache length instead of being
# written by hand for 6 layers, so the cell also works for other model sizes.
nb_decoder_layers = len(out_dec_pytorch.past_key_values)

input_names = ["input_ids", "encoder_hidden_states"]
output_names = ["logits"]
dynamic_axis = {
    # decoder token ids have their own sequence length, unrelated to the encoder one
    "input_ids": {0: "batch", 1: "decoder_sequence"},
    "encoder_hidden_states": {0: "batch", 1: "encoder_sequence"},
    "logits": {0: "batch", 1: "decoder_sequence"},
}
for layer_index in range(nb_decoder_layers):
    # per layer, the flattened cache order is: decoder key, decoder value, encoder key, encoder value
    for attention_kind, past_axis, present_axis in (
        ("decoder", "past_decoder_sequence", "past_decoder_sequence + sequence"),
        ("encoder", "past_encoder_sequence", "past_encoder_sequence"),
    ):
        for tensor_kind in ("key", "value"):
            past_name = f"past_key_values.{layer_index}.{attention_kind}.{tensor_kind}"
            present_name = f"present.{layer_index}.{attention_kind}.{tensor_kind}"
            input_names.append(past_name)
            output_names.append(present_name)
            dynamic_axis[past_name] = {0: "batch", 2: past_axis}
            dynamic_axis[present_name] = {0: "batch", 2: present_axis}


def _export_decoder(inputs: dict, output_path: str, names: list) -> None:
    """Export `model_decoder` to ONNX.

    :param inputs: named arguments of the decoder wrapper used to trace the model
    :param output_path: where the ONNX file is written
    :param names: ONNX input names, in the order of the flattened `inputs`
    """
    with torch.no_grad():
        model.config.return_dict = True
        model.eval()

        # export can work with named args but the dict containing named args has to be the last element of the args tuple
        torch.onnx.export(
            model_decoder,
            (inputs,),
            f=output_path,
            input_names=names,
            output_names=output_names,
            # only keep the axes of tensors which exist in this export
            dynamic_axes={k: v for k, v in dynamic_axis.items() if k in names or k in output_names},
            do_constant_folding=True,
            use_external_data_format=False,
            enable_onnx_checker=True,
            opset_version=13,
        )


_export_decoder(inputs=model_inputs, output_path="test-dec-cache.onnx", names=input_names)
_export_decoder(
    inputs=model_inputs_no_cache, output_path="test-dec-no-cache.onnx", names=list(model_inputs_no_cache.keys())
)

  if causal_mask.shape[1] < attention_mask.shape[1]:


In [18]:
import onnx
from onnx import GraphProto, ModelProto, helper

onnx_model_cache = onnx.load("test-dec-cache.onnx")
onnx_model_no_cache = onnx.load("test-dec-no-cache.onnx")


def _initializer_content_key(tensor: onnx.TensorProto):
    """Return a hashable description of an initializer content: element type, shape and raw bytes."""
    return tensor.data_type, tuple(tensor.dims), tensor.raw_data


# Index the initializers of the no-cache model by content (first occurrence wins).
# Tensors with an empty `raw_data` can't be compared through their bytes, so they are never shared.
no_cache_name_by_content = dict()
for node_no_cache in onnx_model_no_cache.graph.initializer:
    if node_no_cache.raw_data:
        no_cache_name_by_content.setdefault(_initializer_content_key(node_no_cache), node_no_cache.name)

# Share the weights between both graphs: an initializer of the cache model which also exists in the
# no-cache model is replaced by a reference to it, the others are renamed and copied.
prefix = "cache_node_"
mapping_initializer_cache_to_no_cache = dict()
to_add = list()
for node_cache in onnx_model_cache.graph.initializer:
    original_name = node_cache.name
    shared_name = None
    if node_cache.raw_data:
        shared_name = no_cache_name_by_content.get(_initializer_content_key(node_cache))
    if shared_name is not None:
        mapping_initializer_cache_to_no_cache[original_name] = shared_name
        continue
    node_cache.name = prefix + original_name
    to_add.append(node_cache)
    mapping_initializer_cache_to_no_cache[original_name] = node_cache.name
    print(f"name: {node_cache.name} - size: {len(node_cache.raw_data)/1024:.2f}")

onnx_model_no_cache.graph.initializer.extend(to_add)
# I/O model names should not be prefixed
model_io_names = {n.name for n in list(onnx_model_cache.graph.input) + list(onnx_model_cache.graph.output)}

# Prefix every internal name of the cache graph to avoid collisions with the no-cache graph.
for node in onnx_model_cache.graph.node:
    for index, input_name in enumerate(node.input):
        # an empty name is an omitted optional input, it has to stay empty
        if not input_name or input_name in model_io_names:
            continue
        node.input[index] = mapping_initializer_cache_to_no_cache.get(input_name, prefix + input_name)
    for index, output_name in enumerate(node.output):
        if not output_name or output_name in model_io_names:
            continue
        node.output[index] = prefix + output_name
    node.name = prefix + node.name

# Initializers of the no-cache graph named like a model input / output are renamed...
prefix = "init_"
cache = dict()
for node in onnx_model_no_cache.graph.initializer:
    if node.name in model_io_names:
        new_name = prefix + node.name
        cache[node.name] = new_name
        node.name = new_name

# ... and the nodes reading them are updated accordingly.
for node in onnx_model_no_cache.graph.node:
    for index, n in enumerate(node.input):
        node.input[index] = cache.get(n, n)

# mandatory for subgraph in if/else node
assert len(onnx_model_cache.graph.output) == len(onnx_model_no_cache.graph.output)

# Both branches are graphs without inputs / initializers: they read them from the enclosing graph.
graph_cache: onnx.GraphProto = onnx.helper.make_graph(
    nodes=list(onnx_model_cache.graph.node),
    name="graph-cache",
    inputs=[],
    outputs=list(onnx_model_cache.graph.output),
    initializer=[],
)

graph_no_cache: onnx.GraphProto = onnx.helper.make_graph(
    nodes=list(onnx_model_no_cache.graph.node),
    name="graph-no-cache",
    inputs=[],
    outputs=list(onnx_model_no_cache.graph.output),
    initializer=[],
)

enable_cache = onnx.helper.make_tensor_value_info(name="enable_cache", elem_type=onnx.TensorProto.BOOL, shape=[1])

if_node = onnx.helper.make_node(
    op_type="If",
    inputs=["enable_cache"],
    outputs=[o.name for o in list(onnx_model_no_cache.graph.output)],
    then_branch=graph_cache,
    else_branch=graph_no_cache,
)

# The merged model takes the inputs of the cache model (a superset of the no-cache ones) plus the
# flag, and owns all the initializers (shared + specific to the cache graph).
if_graph_def: GraphProto = helper.make_graph(
    nodes=[if_node],
    name="if-model",
    inputs=list(onnx_model_cache.graph.input) + [enable_cache],
    outputs=list(onnx_model_no_cache.graph.output),
    initializer=list(onnx_model_no_cache.graph.initializer),
)

# Reuse the opset of the exported graphs: by default `make_model` declares the latest opset known
# by the installed onnx package, which may not be the one the nodes were exported with.
model_def: ModelProto = helper.make_model(
    if_graph_def,
    producer_name="onnx-example",
    opset_imports=list(onnx_model_no_cache.opset_import),
)

onnx.checker.check_model(model_def)

name: cache_node_1260 - size: 0.01
name: cache_node_1261 - size: 0.01
name: cache_node_1271 - size: 0.01
name: cache_node_1272 - size: 0.01


In [7]:
# Everything in this cell runs on CPU: move the PyTorch modules and the prompt accordingly.
model = model.cpu()
model_decoder = model_decoder.cpu()
input_ids = input_ids.cpu()

out_enc: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=input_ids)
# decoder applied to all tokens but the last one: provides the cache for the last step
out_dec_pytorch = model_decoder(input_ids=input_ids[:, :-1], encoder_hidden_states=out_enc.last_hidden_state)

ort_cachable = create_model_for_provider(model_def.SerializeToString(), "CPUExecutionProvider")

# First run, cache disabled: the whole sequence is provided.
input_ort = {
    "input_ids": input_ids,
    "encoder_hidden_states": out_enc.last_hidden_state,
    "enable_cache": torch.tensor([False], device="cpu", dtype=torch.bool),
}
output_shape = {"logits": tuple(input_ort["input_ids"].shape) + (int(model.config.vocab_size),)}

result_no_cache = inference_onnx_binding(
    model_onnx=ort_cachable,
    inputs=input_ort,
    device=input_ids.device.type,
    output_shape=output_shape,
)

# Second run, cache enabled: only the last token is provided, plus the cache tensors of each layer.
input_ort["enable_cache"] = torch.tensor([True], device="cpu", dtype=torch.bool)
input_ort["input_ids"] = input_ort["input_ids"][:, -1:].type(torch.int32)
output_shape = {"logits": tuple(input_ort["input_ids"].shape) + (int(model.config.vocab_size),)}

for index, (k_dec, v_dec, k_enc, v_enc) in enumerate(out_dec_pytorch.past_key_values):
    input_ort.update(
        {
            f"past_key_values.{index}.decoder.key": k_dec,
            f"past_key_values.{index}.decoder.value": v_dec,
            f"past_key_values.{index}.encoder.key": k_enc,
            f"past_key_values.{index}.encoder.value": v_enc,
        }
    )

result_cache = inference_onnx_binding(
    model_onnx=ort_cachable,
    inputs=input_ort,
    device=input_ids.device.type,
    output_shape=output_shape,
)

# Both branches have to agree on the logits of the last token...
assert np.allclose(a=result_cache["logits"][:, -1:, :], b=result_no_cache["logits"][:, -1:, :], atol=1e-2)

# ... and so does the PyTorch wrapper.
result_python = model_decoder(input_ids=input_ids, encoder_hidden_states=out_enc.last_hidden_state)

assert np.allclose(
    a=result_no_cache["logits"][:, -1:, :], b=result_python.last_hidden_state[:, -1:, :].detach().numpy(), atol=1e-2
)

# `result_python` and `result_cache` are kept on purpose, the rest is released.
del result_no_cache, ort_cachable, input_ort, out_enc, out_dec_pytorch

2022-04-24 21:31:32.155804671 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1207'. It is not used by any node and should be removed from the model.
2022-04-24 21:31:32.155836105 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1157'. It is not used by any node and should be removed from the model.
2022-04-24 21:31:32.155841281 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1073'. It is not used by any node and should be removed from the model.
2022-04-24 21:31:32.155845637 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1056'. It is not used by any node and should be removed from the model.
2022-04-24 21:31:32.155849310 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1006'. It is not used by any node and should be removed from t

In [19]:
# Bring the PyTorch modules and the prompt back on the GPU, in evaluation mode.
model = model.to("cuda").eval()
model_decoder = model_decoder.cuda().eval()
input_ids = input_ids.cuda()

# ONNX Runtime sessions: optimized encoder and merged (cache / no-cache) decoder.
enc_onnx = create_model_for_provider("test-enc-opt.onnx", "CUDAExecutionProvider")
# the decoder session can also be built from "test-dec-no-cache.onnx" instead of the merged model
dec_onnx = create_model_for_provider(model_def.SerializeToString(), "CUDAExecutionProvider")


def decoder_pytorch_inference(input_ids: torch.Tensor, last_hidden_state: torch.Tensor):
    """Run the PyTorch decoder wrapper (`model_decoder`) and return its logits tensor.

    :param input_ids: decoder token ids
    :param last_hidden_state: encoder output, forwarded as `encoder_hidden_states`
    :return: `last_hidden_state` of the wrapper output
    """
    decoder_output = model_decoder(input_ids=input_ids, encoder_hidden_states=last_hidden_state)
    return decoder_output.last_hidden_state


# TODO export past present values from model
def decoder_onnx_inference(
    input_ids: torch.Tensor, last_hidden_state: torch.Tensor, enable_cache: torch.Tensor = None
):
    """Run the ONNX decoder session (`dec_onnx`) and wrap its logits in a model output object.

    :param input_ids: decoder token ids
    :param last_hidden_state: encoder output, sent to the model as `encoder_hidden_states`
    :param enable_cache: boolean tensor selecting the branch of the merged model. Optional: when it is
        not provided, no `enable_cache` input is sent, so the function can also be called with the two
        tensor arguments only.
    :return: output object whose `last_hidden_state` contains the logits
    """
    inputs = {"input_ids": input_ids, "encoder_hidden_states": last_hidden_state}
    if enable_cache is not None:
        inputs["enable_cache"] = enable_cache
    # logits shape: input shape extended with the vocabulary size
    output_shape = {"logits": tuple(input_ids.shape) + (int(model.config.vocab_size),)}
    result_dict = inference_onnx_binding(
        model_onnx=dec_onnx,
        inputs=inputs,
        device=input_ids.device.type,
        output_shape=output_shape,
    )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=result_dict["logits"],
    )  # past_key_values=((torch.tensor(1, device="cuda"),),)


# Sanity check of the ONNX decoder (cache disabled) against the logits of the full PyTorch model.
out_enc: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=input_ids)
dec_onnx_out = decoder_onnx_inference(
    input_ids=input_ids,
    last_hidden_state=out_enc.last_hidden_state,
    enable_cache=torch.tensor([False], device="cuda"),
).last_hidden_state
assert np.allclose(
    a=dec_onnx_out.detach().cpu().numpy(),
    b=out_full.logits.detach().cpu().numpy(),
    atol=1e-1,
)


def encoder_onnx_inference(input_ids: torch.Tensor, **_) -> BaseModelOutputWithPastAndCrossAttentions:
    """Run the ONNX encoder session (`enc_onnx`) on `input_ids`.

    Extra keyword arguments are accepted and ignored.

    :param input_ids: encoder token ids
    :return: output object whose `last_hidden_state` is the "output" tensor of the session
    """
    hidden_size = int(model.encoder.config.d_model)
    binding_outputs = inference_onnx_binding(
        model_onnx=enc_onnx,  # noqa: F821
        inputs={"input_ids": input_ids},
        output_shape=tuple(input_ids.shape) + (hidden_size,),
        device=input_ids.device.type,
    )
    print("encoder")
    return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=binding_outputs["output"])


def encoder_pytorch_inference(input_ids, **_) -> BaseModelOutputWithPastAndCrossAttentions:
    """Run the PyTorch encoder of `model` on `input_ids`; extra keyword arguments are ignored."""
    encoder_output = model.encoder(input_ids=input_ids)
    return encoder_output


# https://github.com/NVIDIA/TensorRT/blob/main/demo/HuggingFace/T5/export.py
class ExtT5(torch.nn.Module, GenerationMixin):
    """Thin module plugging external encoder / decoder callables into `GenerationMixin`.

    The encoder and decoder are plain callables (ONNX Runtime or PyTorch based), which lets
    `generate()` drive any backend as long as both callables follow the expected signatures.
    """

    def __init__(self, config: PretrainedConfig, device: torch.device, encoder_func: Callable, decoder_func: Callable):
        """
        :param config: configuration of the original model
        :param device: device exposed through the `device` attribute
        :param encoder_func: callable returned by `get_encoder`
        :param decoder_func: callable returned by `get_decoder`; called with `input_ids`,
            `last_hidden_state` and `enable_cache`, and expected to return an object exposing
            `last_hidden_state`
        """
        super(ExtT5, self).__init__()
        self.main_input_name = "input_ids"  # https://github.com/huggingface/transformers/pull/14803
        self.config: PretrainedConfig = config
        self.device: torch.device = device

        self.encoder_func = encoder_func
        self.decoder_func = decoder_func

    def get_encoder(self):
        """Return the encoder callable."""
        return self.encoder_func

    def get_decoder(self):
        """Return the decoder callable."""
        return self.decoder_func

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Build the keyword arguments of `forward` for one generation step.

        `past` is ignored: the cache is not used during generation yet, so the whole sequence is
        sent at each step with `enable_cache` set to False.
        """
        return {
            self.main_input_name: input_ids,
            "encoder_hidden_states": kwargs["encoder_outputs"]["last_hidden_state"],
            # the flag is created on the device of the inputs instead of a hard-coded "cuda"
            "enable_cache": torch.tensor([False], device=input_ids.device, dtype=torch.bool),
        }

    def forward(self, input_ids: torch.Tensor, encoder_hidden_states: torch.Tensor, enable_cache: torch.Tensor, **_):
        """Call the decoder callable and expose its `last_hidden_state` as logits."""
        dec_output = self.get_decoder()(
            input_ids=input_ids, last_hidden_state=encoder_hidden_states, enable_cache=enable_cache
        )
        return Seq2SeqLMOutput(logits=dec_output.last_hidden_state)


# Generation model backed by the ONNX encoder / decoder callables
# (the PyTorch based callables are named in the trailing comments).
model_gen = (
    ExtT5(
        config=model.config,
        device=model.device,
        encoder_func=encoder_onnx_inference,  # encoder_pytorch_inference
        decoder_func=decoder_onnx_inference,  # decoder_pytorch_inference
    )
    .cuda()
    .eval()
)

# Same beam search settings on the ONNX backed model and on the PyTorch model,
# the two decoded strings are printed one after the other for comparison.
with torch.inference_mode():
    print(
        tokenizer.decode(
            model_gen.generate(inputs=input_ids, min_length=30, max_length=30, num_beams=7, no_repeat_ngram_size=2)[0],
            skip_special_tokens=False,
        )
    )
    print(
        tokenizer.decode(
            model.generate(inputs=input_ids, min_length=30, max_length=30, num_beams=7, no_repeat_ngram_size=2)[0],
            skip_special_tokens=False,
        )
    )

# Timing of 3 greedy generations of 500 tokens with the ONNX backed model.
# NOTE(review): this loop runs outside `torch.inference_mode()` while the PyTorch loop below runs
# inside it, and no explicit CUDA synchronization is done before reading the clock — confirm the
# two timings are comparable.
start = time()
for _ in range(3):
    model_gen.generate(inputs=input_ids, max_length=500, num_beams=1, no_repeat_ngram_size=2, min_length=500)
print(time() - start)

# Same timing with the PyTorch model, its cache being disabled first.
# NOTE(review): `use_cache` is left to False on `model.config` after this cell.
model.config.use_cache = False
with torch.inference_mode():
    start = time()
    for _ in range(3):
        model.generate(inputs=input_ids, max_length=500, num_beams=1, no_repeat_ngram_size=2, min_length=500)
    print(time() - start)

# Move the PyTorch modules back to CPU and drop the ONNX Runtime sessions.
model = model.cpu()
model_decoder = model_decoder.cpu()
del enc_onnx
del dec_onnx

2022-04-24 22:09:05.204572444 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1272'. It is not used by any node and should be removed from the model.
2022-04-24 22:09:05.204602931 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1261'. It is not used by any node and should be removed from the model.
2022-04-24 22:09:05.204612159 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1271'. It is not used by any node and should be removed from the model.
2022-04-24 22:09:05.204618382 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1260'. It is not used by any node and should be removed from the model.
2022-04-24 22:09:05.223103677 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer '1225'. It is not used by any node and should be removed from the model.
2

encoder
<pad> Studien studies show that study studies have shown that a number of studies study suggest that the study study is based on e Studies.
<pad> Studien studies show that study studies have shown that a number of studies study suggest that the study study is based on e Studies.
encoder
encoder
encoder
16.8681697845459
12.968188285827637


In [None]:
# model_to_export = ExportT5(decoder=model.decoder, lm_head=model.lm_head).eval()
# out_model_export: torch.Tensor = model_to_export(input_ids=input_ids, encoder_hidden_states=out_enc.last_hidden_state)
# assert np.allclose(out_model_export.detach().cpu().numpy(), out_full.logits.detach().cpu().numpy(), atol=1e-5)
#
# inputs_onnx = {"input_ids": input_ids, "encoder_hidden_states": out_enc.last_hidden_state}
#
# convert_to_onnx(
#     model_pytorch=model_to_export,
#     output_path="test-dec.onnx",
#     inputs_pytorch=inputs_onnx,
#     var_output_seq=False,
#     quantization=False,
#     fix_output_dim_size=False,  # specific to decoder part
# )
# optimize_onnx(
#     onnx_path="test-dec.onnx",
#     onnx_optim_model_path="test-dec-opt.onnx",
#     architecture="bert",
#     use_cuda=True,
#     fp16=True,
#     num_attention_heads=model.config.num_heads,
#     hidden_size=model.config.d_model,
# )

In [None]:
trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
trt_model_name = "trt-t5-dec.plan"

# TODO: build the engine only if the plan file does not exist, because it is slow to run...

# Hidden size is read from the PyTorch model configuration (512 for the small model, 768 for base)
hidden_size = int(model.config.d_model)
input_id_shape = TensorRTShape(min_shape=[5, 1], optimal_shape=[5, 500], max_shape=[5, 500], input_name="input_ids")
encoder_hidden_states_shape = TensorRTShape(
    min_shape=[5, 1, hidden_size],
    optimal_shape=[5, 500 // 2, hidden_size],
    max_shape=[5, 500, hidden_size],
    input_name="encoder_hidden_states",
)


model = model.cuda()
# NOTE(review): "test-dec.onnx" is only written by the commented-out export cell above — confirm
# the file exists before running this cell.
model_onnx: ModelProto = onnx.load("test-dec.onnx")
# copy of the model run through `add_output_nodes`, and its adjacency dict
model_onnx_all_nodes = add_output_nodes(model=model_onnx)
onnx_graph: Dict[str, Set[str]] = get_adjency_dict(model=model_onnx)
ort_model_all_nodes = create_model_for_provider(model_onnx_all_nodes.SerializeToString(), "CUDAExecutionProvider")


# use info from tokenizer size and max shape provided through the command line
def get_random_input():
    """Build a random decoder input: token ids of shape (5, 500) and the matching encoder output.

    :return: dict of numpy arrays with keys "input_ids" and "encoder_hidden_states"
    """
    random_ids = torch.randint(high=tokenizer.vocab_size, size=(5, 500), dtype=torch.int32, device="cuda")
    encoder_output = model.encoder(input_ids=random_ids)
    return {
        "input_ids": random_ids.detach().cpu().numpy(),
        "encoder_hidden_states": encoder_output.last_hidden_state.detach().cpu().numpy(),
    }


# Run the instrumented model on 200 random inputs; the result is later given to the FP16 fix
# of the TensorRT engine builder (presumably the nodes which must stay in FP32 — see helper).
keep_fp32 = get_list_fp32_nodes(
    onnx_graph=onnx_graph,
    model=ort_model_all_nodes,
    get_input=get_random_input,
    nb_try=200,
)
# the PyTorch model was only needed on GPU to produce the random encoder outputs
model = model.cpu()

In [None]:
# Build the TensorRT engine of the decoder, save it, and reload it as a callable model.
engine: ICudaEngine = build_engine(
    runtime=runtime,
    onnx_file_path="test-dec.onnx",
    logger=trt_logger,
    workspace_size=20000 * 1024**2,
    fp16=True,
    int8=False,
    input_shapes=[input_id_shape, encoder_hidden_states_shape],
    fp16_fix=get_fix_fp16_network_func(keep_fp32=keep_fp32),
)
save_engine(engine, trt_model_name)

tensorrt_model = load_engine(runtime=runtime, engine_file_path=trt_model_name)
a = tensorrt_model(
    {
        "input_ids": input_ids.type(torch.int32).repeat((5, 1)),
        "encoder_hidden_states": out_enc.last_hidden_state.repeat((5, 1, 1)),
    }
)
print(a[0])

benchmark_input = torch.ones((5, 500), dtype=torch.int32, device="cuda")
benchmark_enc_output = out_enc.last_hidden_state.repeat((5, 1, 1))


def _time_on_gpu(run: Callable, nb_warmup: int = 10, nb_runs: int = 100) -> float:
    """Return the wall-clock time, in seconds, of `nb_runs` calls of `run` after `nb_warmup` warm-up calls.

    CUDA kernels are launched asynchronously: the device is synchronized before starting and before
    stopping the clock so that the measure covers the real execution and not only the kernel launches.
    """
    for _ in range(nb_warmup):
        run()
    torch.cuda.synchronize()
    start = time()
    for _ in range(nb_runs):
        run()
    torch.cuda.synchronize()
    return time() - start


# TensorRT
print(
    _time_on_gpu(
        lambda: tensorrt_model(
            {
                "input_ids": benchmark_input,
                "encoder_hidden_states": benchmark_enc_output,
            }
        )
    )
)

# ONNX Runtime
# NOTE(review): `decoder_onnx_inference` is called without `enable_cache` here — confirm its
# signature accepts that, and that "test-dec-opt.onnx" exists (it is only written by the
# commented-out export cell).
dec_onnx = create_model_for_provider("test-dec-opt.onnx", "CUDAExecutionProvider")
dec_onnx_out = decoder_onnx_inference(input_ids=input_ids, last_hidden_state=out_enc.last_hidden_state)

print(
    _time_on_gpu(lambda: decoder_onnx_inference(input_ids=benchmark_input, last_hidden_state=benchmark_enc_output))
)

# Pytorch
model.cuda()
print(_time_on_gpu(lambda: model.decoder(input_ids=benchmark_input, encoder_hidden_states=benchmark_enc_output)))

# Previous measures (taken without explicit CUDA synchronization), in this order:
# TensorRT, ONNX Runtime, Pytorch

# sequence 500
# 0.8640644550323486
# 0.6695075035095215
# 1.1308434009552002

# sequence 250
# 0.9177014827728271
# 0.6861860752105713
# 1.1923034191131592

In [None]:
# Reference without cache: the decoder reads the whole sequence, the hidden state of the
# last position is displayed.
out_enc: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=input_ids)
decoder_out_full_sequence = model.decoder(
    input_ids=input_ids,
    encoder_hidden_states=out_enc.last_hidden_state,
    past_key_values=None,
)
decoder_out_full_sequence.last_hidden_state[:, -1, :]

In [None]:
# Same computation with the cache: a first call on all tokens but the last one produces
# `past_key_values`, a second call on the last token only reuses them.
out_dec_pytorch = model.decoder(input_ids=input_ids[:, :-1], encoder_hidden_states=out_enc.last_hidden_state)
decoder_out_last_token = model.decoder(
    input_ids=input_ids[:, -1:],
    encoder_hidden_states=out_enc.last_hidden_state,
    past_key_values=out_dec_pytorch.past_key_values,
)
decoder_out_last_token.last_hidden_state[:, -1, :]

In [None]:
# from itertools import chain
# from transformers.onnx.features import FeaturesManager
#
# feature = "seq2seq-lm-with-past"
# model = FeaturesManager.get_model_from_feature(feature, model_name)
# model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
# onnx_config = model_onnx_config(model.config)
#
# with torch.no_grad():
#     model.config.return_dict = True
#     model.eval()
#
#     # Check if we need to override certain configuration item
#     if onnx_config.values_override is not None:
#         for override_config_key, override_config_value in onnx_config.values_override.items():
#             setattr(model.config, override_config_key, override_config_value)
#
#     # Ensure inputs match
#     model_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
#     for k, v in model_inputs.items():
#         if isinstance(v, torch.Tensor):
#             model_inputs[k] = model_inputs[k].type(torch.int32)
#     onnx_outputs = list(onnx_config.outputs.keys())
#
#     onnx_config.patch_ops()
#
#     # export can works with named args but the dict containing named args as to be last element of the args tuple
#     torch.onnx.export(
#         model,
#         (model_inputs,),
#         f="test-dec-cache.onnx",
#         input_names=list(onnx_config.inputs.keys()),
#         output_names=onnx_outputs,
#         dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())},
#         do_constant_folding=True,
#         use_external_data_format=onnx_config.use_external_data_format(model.num_parameters()),
#         enable_onnx_checker=True,
#         opset_version=13,
#     )
#
#     onnx_config.restore_ops()

In [None]:
# Run the ONNX decoder exported *with* cache on CPU, fed with numpy arrays:
# last token only, encoder output, and the cache tensors of each layer.
ort_cache = create_model_for_provider("test-dec-cache.onnx", "CPUExecutionProvider")
input_ort = {
    "input_ids": input_ids[:, -1:].type(torch.int32).detach().cpu().numpy(),
    "encoder_hidden_states": out_enc.last_hidden_state.detach().cpu().numpy(),
}

for index, (k_dec, v_dec, k_enc, v_enc) in enumerate(out_dec_pytorch.past_key_values):
    input_ort.update(
        {
            f"past_key_values.{index}.decoder.key": k_dec.detach().cpu().numpy(),
            f"past_key_values.{index}.decoder.value": v_dec.detach().cpu().numpy(),
            f"past_key_values.{index}.encoder.key": k_enc.detach().cpu().numpy(),
            f"past_key_values.{index}.encoder.value": v_enc.detach().cpu().numpy(),
        }
    )

ort_cache.run(["logits"], input_ort)[0]

In [None]:
# Run the ONNX decoder exported *without* cache on CPU on the whole sequence and
# display the logits of the last position.
ort_no_cache = create_model_for_provider("test-dec-no-cache.onnx", "CPUExecutionProvider")
input_no_cache = {
    "input_ids": input_ids.type(torch.int32).detach().cpu().numpy(),
    "encoder_hidden_states": out_enc.last_hidden_state.detach().cpu().numpy(),
}

logits_no_cache = ort_no_cache.run(["logits"], input_no_cache)[0]
logits_no_cache[:, -1, :]

In [None]:
# ort_cache = create_model_for_provider("test-dec-cache.onnx", "CUDAExecutionProvider")
# input_cache = dict()
# input_cache["input_ids"] = input_ids[:, -1:]
# input_cache["encoder_hidden_states"] = out_enc.last_hidden_state.detach()
#
# for index, (k_dec, v_dec, k_enc, v_enc) in enumerate(
#     out_dec_pytorch.past_key_values
# ):  # type: int, (torch.Tensor, torch.Tensor, torch.Tensor)
#     input_cache[f"past_key_values.{index}.decoder.key"] = k_dec.cuda()
#     input_cache[f"past_key_values.{index}.decoder.value"] = v_dec.cuda()
#     input_cache[f"past_key_values.{index}.encoder.key"] = k_enc.cuda()
#     input_cache[f"past_key_values.{index}.encoder.value"] = v_enc.cuda()
#
#
# print(inference_onnx_binding(model_onnx=ort_cache, inputs=input_cache, device="cuda", output_shape={"logits": (1, 1, 512)})[
#     "logits"
# ])

In [None]:
# ort_cache = create_model_for_provider("test-dec-no-cache.onnx", "CUDAExecutionProvider")
# input_cache = dict()
# input_cache["input_ids"] = input_ids
# input_cache["encoder_hidden_states"] = out_enc.last_hidden_state.detach()
#
# print(inference_onnx_binding(model_onnx=ort_cache, inputs=input_cache, device="cuda", output_shape={"logits": (1, 4, 512)})[
#     "logits"
# ][:,-1,:])

In [None]:
# out_enc: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=input_ids)
# out_dec_pytorch = model.decoder(input_ids=input_ids[:, :-1], encoder_hidden_states=out_enc.last_hidden_state)
#
# ort_cache = create_model_for_provider(model_def.SerializeToString(), "CPUExecutionProvider")
# input_cache = dict()
# input_cache["input_ids"] = input_ids.type(torch.int32).detach().cpu().numpy()
# input_cache["encoder_hidden_states"] = out_enc.last_hidden_state.detach().cpu().numpy()
# input_cache["enable_cache"] = np.array([False])
#
# for index, (k_dec, v_dec, k_enc, v_enc) in enumerate(
#     out_dec_pytorch.past_key_values
# ):  # type: int, (torch.Tensor, torch.Tensor)
#     input_cache[f"past_key_values.{index}.decoder.key"] = k_dec.detach().cpu().numpy()
#     input_cache[f"past_key_values.{index}.decoder.value"] = v_dec.detach().cpu().numpy()
#     input_cache[f"past_key_values.{index}.encoder.key"] = k_enc.detach().cpu().numpy()
#     input_cache[f"past_key_values.{index}.encoder.value"] = v_enc.detach().cpu().numpy()
#
# print(ort_cache.run(["logits"], input_cache)[0][:,-1,:][:, :10])
# print(ort_cache.run(["logits"], input_cache)[0].shape)
#
#
# input_cache["enable_cache"] = np.array([True])
# input_cache["input_ids"] = input_cache["input_ids"][:, -1:]
# print(ort_cache.run(["logits"], input_cache)[0][:,-1,:][:, :10])
# print(ort_cache.run(["logits"], input_cache)[0].shape)

In [28]:
import gc

# Benchmark the ONNX Runtime decoder (single graph with an `enable_cache` switch)
# on a (4, 200) batch: first without cache on the full sequence, then with cache
# on the last token only.
# NOTE(review): the recorded run of this cell stopped with a CUDA cudaMemcpyAsync
# (host -> device) error; its root cause is not established here. The changes below
# (no autograd graph, contiguous inputs) remove two plausible sources of trouble but
# are not confirmed to resolve it.

gc.collect()
model.cuda()
input_ids_benchmark = torch.ones((4, 200), dtype=torch.int32, device="cuda")

# Only forward activations are needed; disabling autograd avoids keeping the graph
# (and its intermediate buffers) alive on the GPU while ONNX Runtime allocates.
with torch.no_grad():
    out_enc: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=input_ids_benchmark)
    out_dec_pytorch = model.decoder(
        input_ids=input_ids_benchmark[:, :-1], encoder_hidden_states=out_enc.last_hidden_state
    )

ort_cache = create_model_for_provider(model_def.SerializeToString(), "CUDAExecutionProvider")
input_ort = dict()
# Already int32 (see creation above), so no dtype conversion is required.
input_ort["input_ids"] = input_ids_benchmark
input_ort["encoder_hidden_states"] = out_enc.last_hidden_state
input_ort["enable_cache"] = torch.tensor([False], device="cuda", dtype=torch.bool)

for index, (k_dec, v_dec, k_enc, v_enc) in enumerate(
    out_dec_pytorch.past_key_values
):  # type: int, (torch.Tensor, torch.Tensor)
    # The tensors are made contiguous before being handed over.
    # NOTE(review): presumably the binding helper passes raw device pointers to
    # ONNX Runtime, in which case strided tensors would be misread -- confirm.
    input_ort[f"past_key_values.{index}.decoder.key"] = k_dec.contiguous()
    input_ort[f"past_key_values.{index}.decoder.value"] = v_dec.contiguous()
    input_ort[f"past_key_values.{index}.encoder.key"] = k_enc.contiguous()
    input_ort[f"past_key_values.{index}.encoder.value"] = v_enc.contiguous()


def _time_ort_decoder(model_onnx, inputs: Dict[str, torch.Tensor], hidden_size: int, nb_runs: int = 10):
    """
    Run the ONNX Runtime decoder `nb_runs` times on `inputs` and print the elapsed wall-clock time.

    :param model_onnx: ONNX Runtime session forwarded to `inference_onnx_binding`
    :param inputs: named input tensors; `inputs["input_ids"]` fixes the device and the leading output dims
    :param hidden_size: size of the last axis declared for the "logits" output
    :param nb_runs: number of timed inferences
    :return: the value returned by `inference_onnx_binding` for the last run
    """
    ids = inputs["input_ids"]
    output_shape = {"logits": tuple(ids.shape) + (hidden_size,)}
    result = None
    # CUDA work may be queued asynchronously; synchronize so that the timer only
    # covers the inferences launched inside the loop.
    torch.cuda.synchronize()
    start = time()
    for _ in range(nb_runs):
        result = inference_onnx_binding(
            model_onnx=model_onnx,
            inputs=inputs,
            device=ids.device.type,
            output_shape=output_shape,
        )
    torch.cuda.synchronize()
    print(time() - start)
    return result


# 1/ cache disabled: the whole sequence is processed
result_dict = _time_ort_decoder(model_onnx=ort_cache, inputs=input_ort, hidden_size=int(model.config.d_model))
print(result_dict["logits"][:, -1:, :][0, :, :10])

# 2/ cache enabled: only the last token is fed, the past key/values provide the rest
input_ort["enable_cache"] = torch.tensor([True], device="cuda", dtype=torch.bool)
# Slicing the last column yields a strided view; materialize it as a contiguous tensor.
input_ort["input_ids"] = input_ort["input_ids"][:, -1:].contiguous()
result_dict = _time_ort_decoder(model_onnx=ort_cache, inputs=input_ort, hidden_size=int(model.config.d_model))
print(result_dict["logits"][:, -1:, :][0, :, :10])

del input_ids_benchmark
del ort_cache
del input_ort
del out_enc
del out_dec_pytorch

2022-04-24 21:14:58.513001016 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1207'. It is not used by any node and should be removed from the model.
2022-04-24 21:14:58.513051249 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1157'. It is not used by any node and should be removed from the model.
2022-04-24 21:14:58.513062059 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1073'. It is not used by any node and should be removed from the model.
2022-04-24 21:14:58.513069042 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1056'. It is not used by any node and should be removed from the model.
2022-04-24 21:14:58.513076180 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'cache_node_1006'. It is not used by any node and should be removed from t

RuntimeError: Error in execution: CUDA error executing cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, GetStream(kCudaStreamDefault))

In [None]:
# Release the benchmark objects if they are still around.
# The benchmark cell above deletes the same names when it runs to completion, so
# plain `del` statements would raise NameError here; popping from the notebook
# namespace makes this cell safe to run in both cases (and safe to re-run).
for _leftover in ("input_ids_benchmark", "out_enc", "out_dec_pytorch", "input_ort", "ort_cache"):
    globals().pop(_leftover, None)
del _leftover

In [None]:
# out_enc: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=torch.range(1, 1000, dtype=torch.int32, device="cuda").unsqueeze(0))
# out_dec_pytorch = model.decoder(input_ids=torch.range(1, 1000, dtype=torch.int32, device="cuda").unsqueeze(0), encoder_hidden_states=out_enc.last_hidden_state)
#
# ort_cache = create_model_for_provider(model_def.SerializeToString(), "CUDAExecutionProvider")
# input_cache = dict()
# input_cache["input_ids"] = input_ids.type(torch.int32)
# input_cache["encoder_hidden_states"] = out_enc.last_hidden_state
# input_cache["enable_cache"] = torch.tensor([False], device="cuda", dtype=torch.bool)
#
# for index, (k_dec, v_dec, k_enc, v_enc) in enumerate(
#     out_dec_pytorch.past_key_values
# ):  # type: int, (torch.Tensor, torch.Tensor)
#     input_cache[f"past_key_values.{index}.decoder.key"] = torch.zeros((1,8,1,64), dtype=torch.float32, device="cuda") # k_dec.detach().cpu().numpy()
#     input_cache[f"past_key_values.{index}.decoder.value"] = torch.zeros((1,8,1,64), dtype=torch.float32, device="cuda") # v_dec.detach().cpu().numpy()
#     input_cache[f"past_key_values.{index}.encoder.key"] = torch.zeros((1,8,1,64), dtype=torch.float32, device="cuda") # k_enc.detach().cpu().numpy()
#     input_cache[f"past_key_values.{index}.encoder.value"] = torch.zeros((1,8,1,64), dtype=torch.float32, device="cuda") # v_enc.detach().cpu().numpy()
#
#
# start = time()
# for _ in range(10):
#     result_dict = inference_onnx_binding(
#     model_onnx=ort_cache,
#     inputs=input_cache,
#     device=input_ids.device.type,
#     output_shape={"logits" : tuple(input_ids.shape) + (int(model.config.vocab_size),)},
#     )
# print(time()-start)

In [None]:
# out_enc: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=torch.range(1, 1000, dtype=torch.int32, device="cuda").unsqueeze(0))
# out_dec_pytorch = model.decoder(input_ids=torch.range(1, 1000, dtype=torch.int32, device="cuda").unsqueeze(0), encoder_hidden_states=out_enc.last_hidden_state)
#
# ort_cache = create_model_for_provider(model_def.SerializeToString(), "CUDAExecutionProvider")
# input_cache = dict()
# input_cache["input_ids"] = input_ids[:, :-1].type(torch.int32).detach().cpu().numpy()
# input_cache["encoder_hidden_states"] = out_enc.last_hidden_state.detach().cpu().numpy()
# input_cache["enable_cache"] = torch.tensor([False], device="cuda", dtype=torch.bool)
#
# for index, (k_dec, v_dec, k_enc, v_enc) in enumerate(
#     out_dec_pytorch.past_key_values
# ):  # type: int, (torch.Tensor, torch.Tensor)
#     input_cache[f"past_key_values.{index}.decoder.key"] = k_dec.detach().cpu().numpy()
#     input_cache[f"past_key_values.{index}.decoder.value"] = v_dec.detach().cpu().numpy()
#     input_cache[f"past_key_values.{index}.encoder.key"] = k_enc.detach().cpu().numpy()
#     input_cache[f"past_key_values.{index}.encoder.value"] = v_enc.detach().cpu().numpy()
#
#
# start = time()
# for _ in range(10):
#     ort_cache.run(["logits"], input_cache)
# print(time()-start)