In [21]:
import torch
import torchvision
 
def get_model():
    model = torchvision.models.wide_resnet50_2(pretrained=0)
    return model
 
def get_onnx(model, onnx_save_path, example_tensor):
 
    example_tensor = example_tensor.cuda()
    model = model.cuda()
 
    _ = torch.onnx.export(model,  # model being run
                                  example_tensor,  # model input (or a tuple for multiple inputs)
                                  onnx_save_path,
                                  verbose=False,  # store the trained parameter weights inside the model file
                                  do_constant_folding=True,
                                  input_names=['input'],
                                  output_names=['output']
                                  )

if __name__ == '__main__':
 
    model = get_model()
    onnx_save_path = "onnx/resnet50_2.onnx"
    example_tensor = torch.randn(1, 3, 288, 512, device='cuda')
 
         # Export model
    get_onnx(model, onnx_save_path, example_tensor)

In [32]:
from tensorflow.python.compiler.tensorrt import trt_convert as trt
def ONNX2TRT(args, calib=None):
    ''' convert onnx to tensorrt engine, use mode of ['fp32', 'fp16', 'int8']
    :return: trt engine
    '''
 
    #assert args.mode.lower() in 'fp32'
 
    G_LOGGER = trt.Logger(trt.Logger.WARNING)
    with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, \
            trt.OnnxParser(network, G_LOGGER) as parser:
 
        builder.max_batch_size = args.batch_size
        builder.max_workspace_size = 1 << 30
        if args.mode.lower() == 'int8':
            assert (builder.platform_has_fast_int8 == True), "not support int8"
            builder.int8_mode = True
            builder.int8_calibrator = calib
        elif args.mode.lower() == 'fp16':
            assert (builder.platform_has_fast_fp16 == True), "not support fp16"
            builder.fp16_mode = True
 
        print('Loading ONNX file from path {}...'.format(args.onnx_file_path))
        with open(args.onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parser.parse(model.read())
        print('Completed parsing of ONNX file')
 
        print('Building an engine from file {}; this may take a while...'.format(args.onnx_file_path))
        engine = builder.build_cuda_engine(network)
        print("Created engine success! ")
 
        # Save plan file
        print('Saving TRT engine file to path {}...'.format(args.engine_file_path))
        with open(args.engine_file_path, "wb") as f:
            f.write(engine.serialize())
        print('Engine file has already saved to {}!'.format(args.engine_file_path))
        return engine

In [36]:
import tensorrt

ModuleNotFoundError: No module named 'tensorrt'

In [33]:
aa = ONNX2TRT("onnx/resnet50_2.onnx")

AttributeError: module 'tensorflow.python.compiler.tensorrt.trt_convert' has no attribute 'Logger'

In [23]:
def loadEngine2TensorRT(filepath):
    G_LOGGER = trt.Logger(trt.Logger.WARNING)
         # Deserialization engine
    with open(filepath, "rb") as f, trt.Runtime(G_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
        return engine

In [None]:
# Create engine through engine file
engine = loadEngine2TensorRT('path_to_engine_file')
 
 # Prepare input and output data
img = Image.open('XXX.jpg')
img = D.transform(img).unsqueeze(0)
img = img.numpy()
output = np.empty((1, 2), dtype=np.float32)
 
 #Create context
context = engine.create_execution_context()
 
 # Allocate memory 
d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
bindings = [int(d_input), int(d_output)]
 
 # pycuda operation buffer
stream = cuda.Stream()
 
 # Put the input data into the device
cuda.memcpy_htod_async(d_input, img, stream)
 
 # Execution model
context.execute_async(100, bindings, stream.handle, None)
 
 # Take the prediction result from the buffer
cuda.memcpy_dtoh_async(output, d_output, stream)
 # Thread synchronization
stream.synchronize()
 
print(output)