In [1]:
input_size = 225  # height/width of the square test input
win_size = 3      # pooling window size (win_size x win_size)
stride = 2        # pooling stride in both spatial dimensions
# Spatial size produced by an unpadded pooling: floor((in - win) / stride) + 1.
# The previous `input_size // stride` only coincides with this for some sizes
# (225 -> 112 here, but e.g. 224 would give 112 instead of the real 111).
output_size = (input_size - win_size) // stride + 1

In [2]:
import numpy as np
import torch

# Seeded generator so Restart & Run All reproduces the same input (and hence outputs).
rng = np.random.default_rng(0)
x = rng.standard_normal((3, input_size, input_size), dtype=np.float32)
input_tensor = torch.tensor(x)
# Add a leading batch dimension -> (1, 3, input_size, input_size), the shape the
# TRT networks below declare for their 'input' tensor.
input_batch = input_tensor.unsqueeze(0)

In [3]:
# Guard the invariant the TRT networks rely on before displaying it.
assert input_batch.shape == (1, 3, input_size, input_size), input_batch.shape
input_batch.shape

torch.Size([1, 3, 225, 225])

In [4]:
# Raw input batch; torch's repr elides the middle rows/columns of large tensors ("...").
input_batch

tensor([[[[ 3.5725e-01, -1.6917e+00,  5.2136e-01,  ..., -6.0795e-01,
            1.5753e+00, -3.8070e-01],
          [-5.2666e-01, -5.3604e-01,  2.8471e-01,  ..., -4.2220e-01,
            3.7675e-02, -2.1327e-01],
          [ 5.2262e-01, -9.2496e-01,  1.1425e+00,  ..., -8.2697e-01,
           -1.4993e+00,  1.0177e+00],
          ...,
          [-2.4542e-01,  7.2722e-01, -1.2134e+00,  ...,  1.0461e+00,
           -1.0332e+00,  1.1853e-02],
          [ 1.5355e+00,  9.1631e-01,  2.9137e-01,  ..., -2.2114e+00,
            6.3102e-01,  1.4526e-01],
          [-5.4330e-01,  1.4354e-02,  1.7691e+00,  ...,  1.3342e+00,
           -2.0007e-01, -1.3380e+00]],

         [[ 6.8778e-01,  3.5977e-01,  1.2019e+00,  ...,  1.1744e+00,
           -1.0098e-01, -1.7514e+00],
          [ 1.2804e+00, -8.9257e-02,  1.8362e+00,  ...,  1.4048e+00,
            2.4415e-01,  9.5697e-01],
          [ 9.1319e-01, -8.5339e-01, -6.3060e-01,  ...,  9.9020e-01,
           -5.5876e-01, -5.3186e-01],
          ...,
     

In [5]:
def define_trt_network(network):
    """Populate `network` with one input and a native TensorRT MAX pooling layer.

    This is the reference network for the plugin comparison: a
    ``win_size`` x ``win_size`` MAX pooling with ``stride`` in both spatial
    dimensions (no padding is set explicitly). Its output tensor is named
    'output' and marked as the network output.
    """
    input_shape = (1, 3, input_size, input_size)
    source = network.add_input(name='input', dtype=trt.float32, shape=input_shape)

    window = (win_size, win_size)
    pool = network.add_pooling(input=source, type=trt.PoolingType.MAX, window_size=window)
    pool.stride = (stride, stride)

    pooled = pool.get_output(0)
    pooled.name = 'output'
    network.mark_output(tensor=pooled)

In [6]:
import trt_analyzer
import tensorrt as trt

EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def build_engine(logger, define_network=None):
    """Build a TensorRT engine for the reference (native pooling layer) network.

    Args:
        logger: ``trt.Logger`` handed to the builder.
        define_network: optional callable ``f(network)`` that populates the
            network definition; defaults to ``define_trt_network``.

    Returns:
        The engine returned by ``builder.build_cuda_engine``.

    Raises:
        RuntimeError: if ``build_cuda_engine`` returns ``None`` (build failure),
            instead of the opaque error ``with None as engine`` would raise.

    Side effect: stores ``trt_analyzer.network_dict(network)`` in the
    module-level ``net_dict``.
    """
    global net_dict
    if define_network is None:
        define_network = define_trt_network
    with trt.Builder(logger) as builder, builder.create_network(EXPLICIT_BATCH) as network:
        builder.max_workspace_size = 1 << 30  # 1 GiB of builder scratch space
        define_network(network)
        # Capture a per-layer summary for display.
        net_dict = trt_analyzer.network_dict(network)
        engine = builder.build_cuda_engine(network)
    if engine is None:
        raise RuntimeError('TensorRT failed to build the engine; see the logger output.')
    return engine

In [7]:
import tensorrt as trt
import common

# Build the reference engine (native TRT pooling) and run one inference on input_batch.
TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
net_dict = None  # filled in by build_engine() through `global net_dict`
with build_engine(TRT_LOGGER) as engine:
    # Host/device buffers and bindings derived from the engine.
    inputs, outputs, bindings, stream = common.allocate_buffers(engine)
    with engine.create_execution_context() as context:
        inputs[0].host = input_batch.numpy()
        trt_outputs = common.do_inference_v2(
            context,
            bindings=bindings,
            inputs=inputs,
            outputs=outputs,
            stream=stream,
        )

In [8]:
import pandas as pd

# Tabulate the per-layer summary that build_engine() captured via trt_analyzer.network_dict.
pd.DataFrame(net_dict)

Unnamed: 0,Name,Type,Inputs,Outputs,Type Specific Params
0,(Unnamed Layer* 0) [Pooling],LayerType.POOLING,"(1, 3, 225, 225)","(1, 3, 112, 112)","type=PoolingType.MAX wsize=(3, 3) stride=(2, 2..."


In [9]:
# Reference output from the native TRT pooling layer, reshaped to (3, output_size, output_size).
# Uses output_size rather than `input_size // 2`, which hard-codes the stride and
# ignores the window size, and matches how the plugin output is reshaped.
reference = trt_outputs[0].reshape((3, output_size, output_size))
print(reference)

[[[1.1424867  1.1424867  1.7960731  ... 0.26349372 0.43127626 1.5752807 ]
  [1.7005147  1.5448505  2.220654   ... 1.9787482  1.4446608  1.0176722 ]
  [1.7005147  2.7461123  2.7461123  ... 1.4446608  1.4446608  0.29846618]
  ...
  [1.9945829  1.9945829  1.1967089  ... 1.8578961  1.3442125  1.3442125 ]
  [2.419224   3.2397647  0.38520214 ... 1.8578961  1.5268359  1.0461092 ]
  [1.7690994  3.2397647  0.6958099  ... 1.6682177  1.6682177  1.3341585 ]]

 [[1.8362406  1.8362406  1.9174484  ... 0.725377   1.4048389  1.4048389 ]
  [0.9131926  0.2932047  1.9174484  ... 1.2848766  1.5044066  1.6129793 ]
  [1.5864489  0.7369177  0.32757738 ... 1.3945507  1.5044066  1.5044066 ]
  ...
  [1.4659975  1.0164951  2.1848512  ... 0.12717292 1.4331068  1.4331068 ]
  [1.6343282  1.6343282  1.1453598  ... 0.88451827 0.88451827 1.8136477 ]
  [1.6343282  1.6343282  1.1453598  ... 0.88451827 1.420002   1.8136477 ]]

 [[1.7271981  3.2363951  3.2363951  ... 2.112953   2.0569186  1.1088425 ]
  [2.5487862  1.749025

In [10]:
import sys
import os

# Locate the plugin directory relative to the kernel's working directory.
# os.getcwd() is what IPython's %pwd returns, but it also works outside IPython.
cur_path = os.getcwd()
plugin_path = os.path.join(cur_path, 'plugin')
# Guard the append so re-running this cell does not pile up duplicate sys.path entries.
if plugin_path not in sys.path:
    sys.path.append(plugin_path)
from trt_plugin_pb2 import copy_Message
from trt_plugin_pb2 import pooling_Message
import trt_plugin_pb2

In [11]:
import ctypes

# Load the compiled pooling plugin library into this process; the plugin creators
# listed in the next cell (presumably registered by this library) depend on it.
lib_file = os.path.join(plugin_path, 'build', 'libPoolingPlugin.so')
lib = ctypes.cdll.LoadLibrary(lib_file)

In [12]:
import tensorrt as trt

# Inspect which plugin creators TensorRT currently knows about, and their namespaces.
registry = trt.get_plugin_registry()
creators = registry.plugin_creator_list
print([creator.name for creator in creators])
print([creator.plugin_namespace for creator in creators])

['RnRes2Br2bBr2c_TRT', 'RnRes2Br2bBr2c_TRT', 'RnRes2Br1Br2c_TRT', 'RnRes2Br1Br2c_TRT', 'CustomSkipLayerNormPluginDynamic', 'CustomEmbLayerNormPluginDynamic', 'CustomGeluPluginDynamic', 'CustomQKVToContextPluginDynamic', 'CustomFCPluginDynamic', 'SingleStepLSTMPlugin', 'pooling', 'copy']
['', '', '', '', '', '', '', '', '', '', 'macnica_trt_plugins', 'macnica_trt_plugins']


In [13]:
# Register every creator belonging to the project's plugin namespace.
# NOTE(review): the listing in the previous cell already shows these creators under
# this namespace, so this may be redundant (TensorRT may reject duplicates) — confirm.
namespace = 'macnica_trt_plugins'
macnica_creators = list(filter(
    lambda creator: creator.plugin_namespace == namespace,
    registry.plugin_creator_list,
))
for creator in macnica_creators:
    registry.register_creator(creator, namespace)

In [14]:
def define_trt_plugin_network(network):
    """Populate `network` with one input feeding the custom 'pooling' plugin.

    The plugin (creator 'pooling', version '1', namespace 'macnica_trt_plugins')
    is configured by a serialized ``pooling_Message``:
    - mode: Maximum
    - window: ``win_size`` x ``win_size``
    - stride: ``stride`` in both dimensions
    - implementation: CUDA

    The plugin's first output is named 'output' and marked as the network output.
    """
    source = network.add_input(
        name='input', dtype=trt.float32, shape=(1, 3, input_size, input_size))

    # Custom pooling layer (CUDA or cuDNN implementation selectable via `impl`).
    creator = registry.get_plugin_creator(
        type='pooling', version='1', plugin_namespace='macnica_trt_plugins')
    # Plugin parameters are passed to the plugin library as a serialized protobuf message.
    params = pooling_Message(
        dims=source.shape,
        mode=trt_plugin_pb2.Maximum,
        window=[win_size, win_size],
        stride=[stride, stride],
        impl=trt_plugin_pb2.CUDA,
    )
    plugin = creator.deserialize_plugin('pooling', params.SerializeToString())
    plugin_layer = network.add_plugin_v2(inputs=[source], plugin=plugin)

    pooled = plugin_layer.get_output(0)
    pooled.name = 'output'
    network.mark_output(tensor=pooled)

In [15]:
import trt_analyzer
import tensorrt as trt

EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def build_engine2(logger, define_network=None):
    """Build a TensorRT engine for the custom-plugin pooling network.

    Args:
        logger: ``trt.Logger`` handed to the builder.
        define_network: optional callable ``f(network)`` that populates the
            network definition; defaults to ``define_trt_plugin_network``.

    Returns:
        The engine returned by ``builder.build_cuda_engine``.

    Raises:
        RuntimeError: if ``build_cuda_engine`` returns ``None`` (build failure),
            instead of the opaque error ``with None as engine`` would raise.

    Side effect: stores ``trt_analyzer.network_dict(network)`` in the
    module-level ``net_dict``.
    """
    global net_dict
    if define_network is None:
        define_network = define_trt_plugin_network
    with trt.Builder(logger) as builder, builder.create_network(EXPLICIT_BATCH) as network:
        builder.max_workspace_size = 1 << 30  # 1 GiB of builder scratch space
        define_network(network)
        # Capture a per-layer summary for display.
        net_dict = trt_analyzer.network_dict(network)
        engine = builder.build_cuda_engine(network)
    if engine is None:
        raise RuntimeError('TensorRT failed to build the engine; see the logger output.')
    return engine

In [16]:
# Build the plugin engine and run one inference on the same input_batch.
net_dict = None  # filled in by build_engine2() through `global net_dict`
with build_engine2(TRT_LOGGER) as engine:
    # Host/device buffers and bindings derived from the engine.
    inputs, outputs, bindings, stream = common.allocate_buffers(engine)
    with engine.create_execution_context() as context:
        inputs[0].host = input_batch.numpy()
        trt_outputs = common.do_inference_v2(
            context,
            bindings=bindings,
            inputs=inputs,
            outputs=outputs,
            stream=stream,
        )

In [17]:
# Tabulate the per-layer summary that build_engine2() captured via trt_analyzer.network_dict.
pd.DataFrame(net_dict)

Unnamed: 0,Name,Type,Inputs,Outputs,Type Specific Params
0,(Unnamed Layer* 0) [PluginV2Ext],LayerType.PLUGIN_V2,"(1, 3, 225, 225)","(1, 3, 112, 112)",


In [18]:
# Plugin output: reshape the first output buffer to (3, output_size, output_size).
result = trt_outputs[0].reshape(3, output_size, output_size)
print(result)

[[[1.1424867  1.1424867  1.7960731  ... 0.26349372 0.43127626 1.5752807 ]
  [1.7005147  1.5448505  2.220654   ... 1.9787482  1.4446608  1.0176722 ]
  [1.7005147  2.7461123  2.7461123  ... 1.4446608  1.4446608  0.29846618]
  ...
  [1.9945829  1.9945829  1.1967089  ... 1.8578961  1.3442125  1.3442125 ]
  [2.419224   3.2397647  0.38520214 ... 1.8578961  1.5268359  1.0461092 ]
  [1.7690994  3.2397647  0.6958099  ... 1.6682177  1.6682177  1.3341585 ]]

 [[1.8362406  1.8362406  1.9174484  ... 0.725377   1.4048389  1.4048389 ]
  [0.9131926  0.2932047  1.9174484  ... 1.2848766  1.5044066  1.6129793 ]
  [1.5864489  0.7369177  0.32757738 ... 1.3945507  1.5044066  1.5044066 ]
  ...
  [1.4659975  1.0164951  2.1848512  ... 0.12717292 1.4331068  1.4331068 ]
  [1.6343282  1.6343282  1.1453598  ... 0.88451827 0.88451827 1.8136477 ]
  [1.6343282  1.6343282  1.1453598  ... 0.88451827 1.420002   1.8136477 ]]

 [[1.7271981  3.2363951  3.2363951  ... 2.112953   2.0569186  1.1088425 ]
  [2.5487862  1.749025

In [19]:
# Compare the plugin output against the native TRT reference element by element.
# Explicit shape check: the comparison must not silently flatten or broadcast mismatched arrays.
assert result.shape == reference.shape, (result.shape, reference.shape)
# Vectorized mean absolute error (the builtin sum() iterated element by element in Python).
# Max pooling only selects existing input values, so 0.0 is the expected result.
mean_abs_err = np.abs(result - reference).mean()
print(mean_abs_err)

0.0
