In [None]:
import tensorrt as trt
import common

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Default cap for the TensorRT builder's WORKSPACE memory pool: 512 MiB.
DEFAULT_MAX_WORKSPACE_BYTES = 512 * 1024 * 1024


# The ONNX path is used for ONNX models.
def build_engine_onnx(model_file, max_workspace_bytes=DEFAULT_MAX_WORKSPACE_BYTES):
    """Parse an ONNX model file and build a serialized TensorRT engine from it.

    Args:
        model_file: Path to the ``.onnx`` model to parse.
        max_workspace_bytes: Limit, in bytes, for the builder's WORKSPACE memory
            pool. Defaults to 512 MiB, the value this function previously hard-coded.

    Returns:
        The object returned by ``builder.build_serialized_network``, or ``None``
        if parsing or building failed. Failures are printed, not raised.
    """
    builder = trt.Builder(TRT_LOGGER)
    # Explicit-batch network definition; the flag bitmask comes from the samples' common module.
    network = builder.create_network(common.EXPLICIT_BATCH)
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_bytes)

    # Load the ONNX model and parse it in order to populate the TensorRT network.
    with open(model_file, "rb") as model:
        if not parser.parse(model.read()):
            print("ERROR: Failed to parse the ONNX file.")
            for error_index in range(parser.num_errors):
                print(parser.get_error(error_index))
            return None

    serialized_engine = builder.build_serialized_network(network, config)
    # build_serialized_network signals failure by returning None; report it like parse errors.
    if serialized_engine is None:
        print("ERROR: Failed to build the TensorRT engine.")
    return serialized_engine


In [None]:
# Artifact locations: ONNX source under ../onnx/, built engine in the current working directory.
model_name = 'findCenter_folded'
onnx_path = f'../onnx/{model_name}.onnx'
engine_path = f'{model_name}.trt'

In [None]:
# Build the engine only when no cached engine file exists. Building is slow, so later runs
# reuse the file on disk, and a fresh checkout still survives Restart & Run All.
import os

if not os.path.exists(engine_path):
    serialized_engine = build_engine_onnx(onnx_path)
    if serialized_engine is None:
        raise RuntimeError(f"Failed to build a TensorRT engine from {onnx_path}")
    with open(engine_path, 'wb') as f:
        f.write(serialized_engine)

In [None]:
# Load the serialized engine from disk and deserialize it with a TensorRT runtime.
runtime = trt.Runtime(TRT_LOGGER)
with open(engine_path, "rb") as f:
    serialized_engine = f.read()
engine = runtime.deserialize_cuda_engine(serialized_engine)
# deserialize_cuda_engine returns None instead of raising on failure. Fail here with a clear
# message rather than later with an AttributeError on None.
if engine is None:
    raise RuntimeError(f"Failed to deserialize TensorRT engine from {engine_path}")

In [None]:
# Execution context used by the inference call; None indicates TensorRT could not create one.
context = engine.create_execution_context()
if context is None:
    raise RuntimeError("Failed to create a TensorRT execution context.")

In [None]:
# Allocate I/O buffers for the engine via the samples' common helper. It returns the input
# buffers, output buffers, binding list and stream that do_inference_v2 consumes below.
# NOTE(review): presumably one host/device buffer pair per I/O tensor — confirm in common.py.
inputs, outputs, bindings, stream = common.allocate_buffers(engine)

In [None]:
# Inspect the input buffers returned by common.allocate_buffers.
inputs

In [None]:
# Inspect the output buffers before inference has run.
outputs

In [None]:
import os
import pickle

import numpy as np

pickle_dir = "/workspace/centerformer/work_dirs/partition/sample_data/"

# SECURITY NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(os.path.join(pickle_dir, "findcenter_input.pkl"), 'rb') as handle:
    input_tensor = pickle.load(handle)

# The pickled object is used like a torch tensor (.detach().cpu().numpy()). Convert it to a
# C-contiguous numpy array so its memory can be copied flat into the engine input buffer.
input_tensor = np.ascontiguousarray(input_tensor.detach().cpu().numpy())

# Copy into the buffer allocated by common.allocate_buffers instead of rebinding it. That buffer
# is presumably the one do_inference_v2 transfers to the device. np.copyto also casts to the
# buffer's dtype, and the size check fails loudly on a shape mismatch instead of sending garbage.
host_buffer = inputs[0].host
if host_buffer.size != input_tensor.size:
    raise ValueError(
        f"Input has {input_tensor.size} elements but the engine input buffer holds {host_buffer.size}"
    )
np.copyto(host_buffer, input_tensor.ravel())

In [None]:
# Run the engine once on the buffers prepared above.
# NOTE(review): trt_outputs is presumably a list with one host array per engine output — confirm in common.py.
trt_outputs = common.do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)

In [None]:
# Re-inspect the output buffers after inference.
# NOTE(review): presumably their host arrays now hold the results — confirm in common.py.
outputs

In [None]:
# NOTE(review): output_shape is not used in this cell. It looks intended for reshaping the first
# (flat) output buffer — confirm before relying on it.
output_shape = [(1, 1, 128, 128), ]

# Print every output the engine produced, rather than assuming exactly six.
for i, output in enumerate(trt_outputs):
    print(f"trt_outputs[{i}]:")
    print(output)