In [None]:
!apt-get update && apt-get install -y graphviz
!pip install expecttest pydot

In [None]:
import torch
import torch.fx
import torch.nn as nn
print(torch.__version__)

In [None]:
import tensorrt
print(tensorrt.__version__)

In [None]:
import torch_tensorrt
print(torch_tensorrt.__version__)

In [None]:
from torch_tensorrt.fx.tracer.acc_tracer import acc_normalizer, acc_ops, acc_shape_prop, acc_utils  # noqa: F401
from torch.fx.experimental.normalize import NormalizeArgs

import torch_tensorrt
from torch_tensorrt.fx.utils import LowerPrecision
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

In [None]:
from torch.fx.node import Argument, Target
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from torch_tensorrt.fx.converters.converter_utils import SourceIR
from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.tracer.acc_tracer.acc_op_properties import AccOpProperty, register_acc_op_properties
from torch_tensorrt.fx.tracer.acc_tracer.acc_normalizer import (
    register_acc_op,
    register_acc_op_mapping,
    register_custom_acc_mapper_fn,
)
from torch_tensorrt.fx.types import (
    TRTNetwork,
    TRTTensor,
)
from torch_tensorrt.fx.converters.impl import activation

In [None]:
import tensorrt as trt
# Logger handed to TensorRT, constructed with the ERROR severity level.
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
# Initialize the libnvinfer plugin library; the empty string is the namespace argument.
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
print(f"Register libnvinfer plugins")
registry = trt.get_plugin_registry()
print(f"Registry: {registry}")
# Print the name of every plugin creator currently in the registry.
# NOTE: the loop variable `plugin` stays bound after the loop and is read again
# by later cells (`plugin.plugin_namespace`, `plugin.plugin_version`).
for plugin in registry.plugin_creator_list:
    print(plugin.name)

In [None]:
from torch_tensorrt.fx.converter_registry import CONVERTERS
from torch_tensorrt.fx.tracer.acc_tracer.acc_normalizer import _acc_ops, _normalization_dict
from torch_tensorrt.fx.tracer.acc_tracer import acc_ops

# Dump the keys of the three registries, each preceded by a banner and a title.
banner = ">>" * 40
for title, registered in (
    ("Converters", CONVERTERS),
    ("acc_ops", _acc_ops),
    ("_normalization_dict", _normalization_dict),
):
    print(banner)
    print(title)
    print(banner)
    for op in registered:
        print(op)

In [None]:
# Remove every GELU-related entry from the three registries (presumably so the
# cells below can register their own GELU op and converter). Matching entries are
# collected first so that no registry is mutated while it is being iterated.
for converter_key in [key for key in CONVERTERS.keys() if key == acc_ops.gelu]:
    CONVERTERS.pop(converter_key)
    print(f"removed converter {converter_key}")

for acc_fn in [fn for fn in _acc_ops if fn.__name__ == "gelu"]:
    _acc_ops.remove(acc_fn)
    print(f"removed acc_op: {acc_fn}")

gelu_mappings = [
    (op_kind, target)
    for (op_kind, target) in _normalization_dict.keys()
    if "gelu" in str(target) or "GELU" in str(target)
]
for op_kind, target in gelu_mappings:
    _normalization_dict.pop((op_kind, target))
    print(f"removed normalization_dict op: {op_kind}")

In [None]:
import ctypes
from pathlib import Path


def load_torchtrt_plugins(lib_dir=None, lib_names=("libtorchtrt_plugins.so",)):
    """Load Torch-TensorRT shared libraries into the current process via ctypes.

    Args:
        lib_dir: Directory containing the shared libraries. When ``None`` (default)
            the ``lib`` directory next to the installed ``torch_tensorrt`` package
            is used.
        lib_names: File names to load, in order. Defaults to the plugin library only.

    Raises:
        FileNotFoundError: if one of the requested libraries does not exist. This is a
            subclass of ``OSError``, the error ``ctypes.CDLL`` raised for a missing
            file before the explicit check was made effective.
        OSError: if the dynamic loader rejects an existing file.
    """
    if lib_dir is None:
        lib_path = Path(torch_tensorrt.__file__).parent / "lib"
        print(f"Using torch_tensorrt: {torch_tensorrt.__version__}, lib_path={lib_path}")
    else:
        lib_path = Path(lib_dir)
        print(f"Using plugin directory: lib_path={lib_path}")

    for lib in lib_names:
        path = lib_path / lib
        # Stop here with a clear message instead of reporting the problem and
        # then handing the missing path to the loader anyway.
        if not path.is_file():
            raise FileNotFoundError(f"Failed to load lib: {path}")
        ctypes.CDLL(str(path))
        print(f"Loaded {path}")

load_torchtrt_plugins()

In [None]:
# One line per registered plugin creator: name, version, namespace.
for creator_entry in registry.plugin_creator_list:
    print(
        creator_entry.name,
        creator_entry.plugin_version,
        creator_entry.plugin_namespace,
    )

In [None]:
# `plugin` is the variable left bound by the earlier `for plugin in registry.plugin_creator_list`
# loop, i.e. the last creator that loop visited.
plugin.plugin_namespace

In [None]:
# Version of the same leftover `plugin` creator (last one visited by the earlier listing loop).
plugin.plugin_version

In [None]:
import tensorrt as trt

trt.ITensor?

In [None]:
@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
@register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.gelu))
@register_acc_op_mapping(op_and_target=("call_method", "gelu"))
@register_custom_acc_mapper_fn(
    op_and_target=("call_module", torch.nn.GELU),
    arg_replacement_tuples=[
        ("input", "input"),
    ],
)
@register_acc_op
def gelu(*, input):
    """Acc op standing in for GELU calls (function, method and ``nn.GELU`` module).

    The op must compute the same values as the calls it replaces, so it evaluates
    GELU itself; computing ReLU here would silently change the traced model's output.

    Args:
        input: tensor to apply GELU to (keyword-only).

    Returns:
        ``torch.nn.functional.gelu(input)`` with the default (exact) formulation.
    """
    return nn.functional.gelu(input=input)


# @tensorrt_converter(torch.nn.functional.gelu)
# @tensorrt_converter(torch.nn.modules.activation.GELU)
# def _gelu(network, submod, args, kwargs, layer_name):
#     # args/kwargs should have already been normalized to kwargs
#     assert len(args) == 0

#     return activation.relu(
#         network=network,
#         target="torch.nn.functional.relu",
#         source_ir=SourceIR.NN,
#         name=layer_name,
#         input_val=kwargs["input"],
#     )


# @register_acc_op_mapping(
#     op_and_target=("call_function", torch.nn.modules.GroupNorm),
#     arg_replacement_tuples=[
#         ("input", "input"),
#         ("num_groups", "num_groups"),
#         ("weight", "weight"),
#         ("bias", "bias"),
#         ("eps", "eps"),
#     ],
# )
# @register_acc_op
# def group_norm(*, input, num_groups, weight=None, bias=None, eps=1e-05):
#     return GroupNormalizationPlugin.apply(input, self.weight, self.bias, self.num_groups, self.eps)
#     return torch.nn.functional.group_norm(
#         input, num_groups, weight=weight, bias=bias, eps=eps
#     )

In [None]:
# auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "torch_tensorrt");
# auto interpolate_plugin = creator->createPlugin(name, &fc);

# auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
# TORCHTRT_CHECK(resize_layer, "Unable to create interpolation plugin from node" << *n);

# resize_layer->setName(util::node_info(n).c_str());

# auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));

# Look up the creator for GroupNormalizationPlugin version "1". The converter further
# down requests this same plugin from the default ("") namespace, so use it here too;
# `get_plugin_creator` yields None when no creator matches name/version/namespace.
creator = trt.get_plugin_registry().get_plugin_creator("GroupNormalizationPlugin", "1", "")

In [None]:
from torch_tensorrt.fx.converter_registry import tensorrt_converter


def gelu_fn(x):
    """
    GELU written out with elementary tensor ops (tanh-based formula).

    Source:
    https://github.com/geohot/tinygrad/blob/18892242b006785d4e92abae7c792e7874c17df9/tinygrad/tensor.py#L522
    """
    # Argument of the tanh: x * 0.7978845608 * (1 + 0.044715 * x^2)
    tanh_arg = x * 0.7978845608 * (1 + 0.044715 * x * x)
    return 0.5 * x * (1 + tanh_arg.tanh())


@tensorrt_converter(torch.nn.functional.gelu)
@tensorrt_converter(torch.nn.modules.activation.GELU)
def relu(network, submod, args, kwargs, layer_name):
    """Converter registered for the GELU function and module targets.

    It delegates to ``activation.relu`` and returns that call's result.
    NOTE(review): the registration keys are GELU but the emitted layer is ReLU,
    so the converted network is not numerically equivalent — confirm this
    substitution is intentional.

    Args:
        network: TensorRT network being built.
        submod: the matched target (unused here).
        args: positional arguments; must be empty.
        kwargs: normalized keyword arguments; ``kwargs["input"]`` is forwarded.
        layer_name: name forwarded to ``activation.relu``.
    """
    # Arguments are expected to have been normalized to kwargs already.
    assert len(args) == 0

    input_tensor = kwargs["input"]
    return activation.relu(
        network=network,
        target="torch.nn.functional.relu",
        source_ir=SourceIR.NN,
        name=layer_name,
        input_val=input_tensor,
    )



In [None]:
import torch
import torch.nn as nn
import torch_tensorrt as torchtrt


# Create a sample network with a conv and gelu node.
# Gelu layer in Torch-TensorRT is converted to CustomGeluPluginDynamic from TensorRT plugin registry.
class ConvGelu(torch.nn.Module):
    """A 3x3 convolution (3 -> 32 channels, stride 1) followed by functional GELU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1)

    def forward(self, x):
        """Run the convolution, then ``torch.nn.functional.gelu`` on its output."""
        features = self.conv(x)
        return torch.nn.functional.gelu(features)
    
# Instantiate the sample network, switch it to eval mode and move it to the CUDA device.
model = ConvGelu().eval().cuda()

In [None]:
torchtrt.Input?

In [None]:
from torch.fx import symbolic_trace, replace_pattern

# Symbolically trace `model`; the cell below rewrites `pattern` into `replacement` in `traced`.
traced = symbolic_trace(model)
print(traced)

In [None]:
# Define the pattern. The FX subgraph rewriter matches all
# non-overlapping instances of the pattern in the traced graph.
# Matching is done on data dependencies, not on node names, so the
# placeholder `x` below matches whatever value feeds
# `torch.nn.functional.gelu` in `traced`. Only operations that
# contribute to the single output value of the pattern are considered.
def pattern(x):
    """Subgraph to search for: a single ``torch.nn.functional.gelu`` call on ``x``."""
    matched = torch.nn.functional.gelu(x)
    return matched

# Define the replacement (same rules as the pattern)
def replacement(x):
    """Subgraph substituted for each match: the explicit formula in ``gelu_fn``."""
    rewritten = gelu_fn(x)
    return rewritten

# Apply the rewrite to `traced` and print the module again to inspect the result.
replace_pattern(traced, pattern, replacement)
print(traced)

In [None]:
# Input shape used by the GELU example cells (3 channels to match the Conv2d in ConvGelu).
shape = [1, 3, 5, 5]
# Keyword arguments later unpacked into `torchtrt.compile`: one float32 input of `shape`,
# float32 as the only enabled precision.
compile_settings = {
    "inputs": [torchtrt.Input(shape, dtype=torch.float32)],
    "enabled_precisions": {torch.float32},
}
# Script the rewritten FX module (not the original `model`).
with torch.inference_mode():
    scripted_model = torch.jit.script(traced)

In [None]:
scripted_model.graph

In [None]:
# Run the acc tracer over the rewritten FX module with one random CUDA sample of `shape`,
# then print the resulting graph as a table.
with torch.inference_mode():
    trt_traced = acc_tracer.trace(
        traced, [torch.rand(*shape, dtype=torch.float32, device="cuda")], 
    )
trt_traced.graph.print_tabular()

In [None]:
# Compile the scripted module with the settings defined above and save the result
# with torch.jit.save (so it has to be read back with torch.jit.load).
trt_ts_module = torchtrt.compile(scripted_model, **compile_settings)
torch.jit.save(trt_ts_module, "conv_gelu.jit")
print("Generated Torchscript-TRT GELU model.")

In [None]:
trt_ts_module.graph

In [None]:
# Random CUDA sample used both for tracing and as the splitter's sample input.
x = torch.rand(*shape, dtype=torch.float32, device="cuda")
with torch.inference_mode():
    # NOTE(review): this traces the original `model`, not the rewritten `traced`
    # module — confirm that is intended.
    trt_traced = acc_tracer.trace(model, [x])

    # Preview which nodes the splitter considers supported; dump_graph=True requests a graph dump.
    splitter = TRTSplitter(trt_traced, [x])
    splitter.node_support_preview(dump_graph=True)

In [None]:
from IPython.display import Image
import pydot

# Load "node_support.dot" (presumably written by the node_support_preview(dump_graph=True)
# call above — verify) and display its first graph as a PNG.
graphs = pydot.graph_from_dot_file("node_support.dot")
Image(graphs[0].create_png())

In [None]:
# "conv_gelu.jit" was written with torch.jit.save, so read it back with torch.jit.load
# rather than the pickle-based torch.load.
trt_ts_gelu_model = torch.jit.load("conv_gelu.jit")

In [None]:
%%timeit -n 10
trt_ts_gelu_model(x)

In [None]:
trt.float32

In [None]:
import numpy as np
import tensorrt as trt

from torch_tensorrt.fx.converters.acc_ops_converters import get_trt_plugin, TRTPluginFieldCollection, _LOGGER
from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor
from torch_tensorrt.fx.utils import torch_dtype_from_trt


#     if weight is None:
#         weight = torch.ones((input.shape[1],), dtype=input.dtype, device=input.device)
#     if bias is None:
#         bias = torch.zeros((input.shape[1],), dtype=input.dtype, device=input.device)



@register_acc_op_mapping(
    op_and_target=("call_function", torch.nn.functional.group_norm),
    # Every argument keeps its name when mapped onto the acc op.
    arg_replacement_tuples=[
        (arg_name, arg_name)
        for arg_name in ("input", "num_groups", "weight", "bias", "eps")
    ],
)
@register_acc_op
def group_norm(*, input, num_groups, weight=None, bias=None, eps=1e-05):
    """Acc op mapped from ``torch.nn.functional.group_norm``.

    All parameters are keyword-only and are forwarded unchanged to
    ``torch.nn.functional.group_norm``, whose result is returned.
    """
    normalized = torch.nn.functional.group_norm(
        input,
        num_groups,
        weight=weight,
        bias=bias,
        eps=eps,
    )
    return normalized


@tensorrt_converter(group_norm)
def acc_ops_group_norm(network, target, args, kwargs, name):
    """Lower the ``group_norm`` acc op to the ``GroupNormalizationPlugin`` TensorRT plugin.

    Args:
        network: TensorRT network the plugin layer is added to.
        target: the op being converted (unused).
        args: positional arguments (unused; all values are read from ``kwargs``).
        kwargs: must contain ``input``, ``num_groups`` and ``eps``; ``weight`` and
            ``bias`` may be absent or ``None``, in which case per-channel ones /
            zeros are used.
        name: name given to the plugin layer; also the prefix of the weight/bias
            constant names.

    Returns:
        The first output of the plugin layer.

    Raises:
        RuntimeError: if ``input`` is not a TensorRT tensor, if the channel
            dimension is not static, or if the plugin cannot be created.
    """
    input_val = kwargs["input"]

    # Validate first: the code below reads .shape/.dtype, which would otherwise fail
    # with an unrelated error before this check could report the real problem.
    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"GroupNorm received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    num_channels = input_val.shape[1]
    if num_channels < 0:
        raise RuntimeError(
            f"GroupNorm needs a static channel dimension to build default "
            f"weight/bias, got input shape {tuple(input_val.shape)}."
        )

    # Scale and shift are per-channel vectors of length C, the same shape a
    # user-supplied GroupNorm weight/bias has — not the full input shape.
    channel_shape = (num_channels,)
    param_dtype = torch_dtype_from_trt(input_val.dtype)

    weight = kwargs.get("weight")
    bias = kwargs.get("bias")
    _LOGGER.debug(
        f"group_norm input: {input_val.shape} {input_val.dtype}, weight: {weight}, bias: {bias}"
    )

    if weight is None:
        weight = torch.ones(channel_shape, dtype=param_dtype)
    weight = get_trt_tensor(network, weight, f"{name}_weight")

    if bias is None:
        bias = torch.zeros(channel_shape, dtype=param_dtype)
    bias = get_trt_tensor(network, bias, f"{name}_bias")

    num_groups_field = trt.PluginField(
        "num_groups", np.array([kwargs["num_groups"]], dtype=np.int32), trt.PluginFieldType.INT32
    )
    eps_field = trt.PluginField(
        "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
    )
    field_collection = trt.PluginFieldCollection([eps_field, num_groups_field])

    try:
        plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1", "")
    except AssertionError as err:
        # There is no alternative implementation to fall back to, so say so and
        # keep the original failure attached as the cause.
        _LOGGER.error("Unable to find group norm plugin.")
        raise RuntimeError("Failed to build group norm plugin.") from err

    _LOGGER.debug("adding plugin v2")
    layer = network.add_plugin_v2([input_val, weight, bias], plugin)
    layer.name = name
    return layer.get_output(0)

In [None]:
import torch
import torch.nn as nn
import torch_tensorrt as torchtrt

# create a simple norm layer.
# This norm layer uses NormalizePlugin from Torch-TensorRT

class Norm(torch.nn.Module):
    """Single ``nn.GroupNorm`` layer over ``C`` channels split into ``C // 2`` groups."""

    def __init__(self, C: int):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=C // 2, num_channels=C)

    def forward(self, x):
        """Return ``x`` passed through the wrapped GroupNorm layer."""
        normalized = self.gn(x)
        return normalized
#         num_groups = 2
#         return torch.nn.functional.group_norm(
#             x, num_groups, eps=1e-05
#         )

In [None]:
# Channel count for the GroupNorm example; Norm(C) builds its GroupNorm with C // 2 groups.
C = 6
norm_model = Norm(C).eval().cuda()
# TorchScript version of the model, compiled with torchtrt in a later cell.
with torch.inference_mode():
    norm_ts_module = torch.jit.script(norm_model)

In [None]:
norm_ts_module.graph

In [None]:
# NOTE: `shape`, `x` and `trt_traced` are rebound here; from this point on they refer to
# the GroupNorm example rather than the GELU example above.
shape = [1, C, 64, 64]
x = torch.rand(*shape, dtype=torch.float32, device="cuda")
with torch.inference_mode():
    trt_traced = acc_tracer.trace(norm_model, [x])
print(trt_traced)

In [None]:
# Rebinds `compile_settings` for the GroupNorm input shape (float32 input, float32 precision).
compile_settings = {
    "inputs": [torchtrt.Input(shape, dtype=torch.float32)],
    "enabled_precisions": {torch.float32},
}

# Compile the scripted GroupNorm model and save the result with torch.jit.save.
norm_trt_ts = torchtrt.compile(norm_ts_module, **compile_settings)
torch.jit.save(norm_trt_ts, "norm_trt_ts.pt")
print("Generated Torchscript-TRT GroupNorm model.")

In [None]:
print(norm_trt_ts.graph)

In [None]:
# Build a splitter over the acc-traced GroupNorm graph with `x` as sample input and
# preview node support (no graph dump this time).
with torch.inference_mode():
    splitter = TRTSplitter(trt_traced, [x])
    splitter.node_support_preview(dump_graph=False)
    
# Calling the splitter yields the split module whose children are lowered below.
split_mod = splitter()
# Sample inputs fed to `split_mod` when capturing each submodule's inputs.
inputs = [x]

def get_submod_inputs(_mod, _submod, _inputs):
    """Capture the positional inputs ``_submod`` receives when ``_mod`` runs on ``_inputs``.

    A forward pre-hook is attached to ``_submod`` for the duration of one forward
    pass of ``_mod`` under ``torch.inference_mode()``.

    Args:
        _mod: module to execute.
        _submod: module (normally a child of ``_mod``) whose inputs are recorded.
        _inputs: iterable of positional arguments passed to ``_mod``.

    Returns:
        The tuple of positional arguments ``_submod`` was last called with, or
        ``None`` if ``_submod`` was never invoked during the pass.
    """
    acc_inputs = None

    def get_input(self, inputs):
        # Pre-hook: record the inputs; returning None leaves them unmodified.
        nonlocal acc_inputs
        acc_inputs = inputs

    handle = _submod.register_forward_pre_hook(get_input)
    try:
        with torch.inference_mode():
            _mod(*_inputs)
    finally:
        # Always detach the hook, even if the forward pass raises; otherwise it
        # would stay on the submodule and fire on every later call.
        handle.remove()
    return acc_inputs


# The splitter has partitioned the model into child segments; each TRT-eligible
# segment (a child whose name contains "_run_on_acc") is lowered separately.
# If the model is known to be fully lowerable, the splitter step can be skipped.
for name, _ in split_mod.named_children():
    print(f"Splitting {name}")
    if "_run_on_acc" in name:
        submod = getattr(split_mod, name)

        # Record the inputs this segment receives when the whole split module runs on `inputs`.
        acc_inputs = get_submod_inputs(split_mod, submod, inputs)

        # Build a TensorRT interpreter for the segment from specs derived from those inputs.
        interp = TRTInterpreter(
            submod,
            InputTensorSpec.from_tensors(acc_inputs),
            explicit_batch_dimension=True,
        )
        # Run the interpreter at FP32 and wrap its result in a TRTModule.
        r = interp.run(lower_precision=LowerPrecision.FP32)
        trt_mod = TRTModule(*r)
        # Swap the original segment for the TRT-backed module under the same attribute name.
        setattr(split_mod, name, trt_mod)

In [None]:
# Round-trip the lowered split module through torch.save / torch.load.
# NOTE(review): this serializes the whole module object, so only load files you trust.
# Despite the `trt_ts_` prefix, the loaded object comes from the FX splitter path
# (`split_mod`), not from the TorchScript compile above.
filename = "norm_trt_engine.pt"
torch.save(split_mod, filename)
trt_ts_norm_model = torch.load(filename)

In [None]:
# Fresh normally-distributed CUDA input for the two model-call cells that follow.
x = torch.randn(1, C, 64, 64, dtype=torch.float32, device="cuda")

In [None]:
# %%timeit -n 500
norm_model(x)

In [None]:
# %%timeit -n 500
trt_ts_norm_model(x)

In [None]:
import numpy as np
import tensorrt as trt


def create_groupnorm_plugin(
    layer_name,
    num_groups,
    eps=1e-5,
    plugin_name='GroupNormPluginDynamic',
    plugin_version='1',
    plugin_namespace='',
):
    """Create a group-norm TensorRT plugin instance from the global plugin registry.

    Args:
        layer_name: name given to the created plugin.
        num_groups: number of groups, passed as an INT32 ``num_groups`` field.
        eps: epsilon, passed as a FLOAT32 ``eps`` field.
        plugin_name: registry name of the plugin creator (default unchanged from
            the original hard-coded value).
        plugin_version: registry version of the plugin creator.
        plugin_namespace: registry namespace of the plugin creator.

    Returns:
        Whatever ``creator.create_plugin`` returns for the given fields.

    Raises:
        RuntimeError: if no creator with the requested name/version/namespace is
            registered (for example because its library was never loaded).
    """
    creator = trt.get_plugin_registry().get_plugin_creator(
        plugin_name, plugin_version, plugin_namespace)
    # The lookup yields None for an unknown plugin; report that explicitly instead
    # of failing later with "'NoneType' object has no attribute 'create_plugin'".
    if creator is None:
        raise RuntimeError(
            f"TensorRT plugin creator not found: name={plugin_name!r}, "
            f"version={plugin_version!r}, namespace={plugin_namespace!r}"
        )

    pfc = trt.PluginFieldCollection([
        trt.PluginField('num_groups',
                        np.array([num_groups], dtype=np.int32),
                        trt.PluginFieldType.INT32),
        trt.PluginField('eps',
                        np.array([eps], dtype=np.float32),
                        trt.PluginFieldType.FLOAT32),
    ])
    return creator.create_plugin(layer_name, pfc)


@tensorrt_converter(torch.nn.GroupNorm.forward)
def convert_GroupNorm(network, submod, args, kwargs, layer_name):
    """Add a group-norm plugin layer for a GroupNorm call and return its output.

    The previous body referenced names that do not exist in this converter
    signature (``ctx``, ``module``, ``input_trt``, ...) and therefore raised
    ``NameError`` when invoked; it now uses only the arguments it receives.

    NOTE(review): the registration key is ``torch.nn.GroupNorm.forward`` — confirm
    the FX interpreter ever looks converters up under this key.

    Args:
        network: TensorRT network the plugin layer is added to.
        submod: the matched target; when it carries ``num_groups``, ``eps``,
            ``weight`` or ``bias`` attributes they are used as fallbacks.
        args: positional arguments (unused).
        kwargs: must contain ``input``; ``num_groups``, ``eps``, ``weight`` and
            ``bias`` take precedence over the attributes of ``submod`` when given.
        layer_name: name for the plugin and the layer.

    Returns:
        The first output of the plugin layer.

    Raises:
        RuntimeError: if ``input`` is not a TensorRT tensor, ``num_groups`` cannot
            be determined, or the plugin cannot be created.
    """
    input_val = kwargs["input"]
    if not isinstance(input_val, trt.ITensor):
        raise RuntimeError(
            f"GroupNorm received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    def _lookup(key, default=None):
        # Explicit kwargs win; otherwise fall back to an attribute on `submod`.
        value = kwargs.get(key)
        return getattr(submod, key, default) if value is None else value

    num_groups = _lookup("num_groups")
    if num_groups is None:
        raise RuntimeError("GroupNorm converter could not determine num_groups.")
    eps = _lookup("eps", 1e-5)

    # Affine parameters are per-channel; substitute identity scale / zero shift
    # when none are available.
    channel_shape = (input_val.shape[1],)
    param_dtype = torch_dtype_from_trt(input_val.dtype)
    weight = _lookup("weight")
    if weight is None:
        weight = torch.ones(channel_shape, dtype=param_dtype)
    bias = _lookup("bias")
    if bias is None:
        bias = torch.zeros(channel_shape, dtype=param_dtype)

    weight_trt = get_trt_tensor(network, weight, f"{layer_name}_weight")
    bias_trt = get_trt_tensor(network, bias, f"{layer_name}_bias")

    plugin = create_groupnorm_plugin(layer_name, num_groups=num_groups, eps=eps)
    if plugin is None:
        raise RuntimeError("Failed to build group norm plugin.")

    custom_layer = network.add_plugin_v2(
        inputs=[input_val, weight_trt, bias_trt], plugin=plugin)
    custom_layer.name = layer_name
    return custom_layer.get_output(0)
    


# @tensorrt_converter('torch.nn.functional.group_norm')
# def convert_group_norm(ctx):

#     input = get_arg(ctx, 'input', pos=0, default=None)
#     num_groups = get_arg(ctx, 'num_groups', pos=1, default=None)
#     weight = get_arg(ctx, 'weight', pos=2, default=None)
#     bias = get_arg(ctx, 'bias', pos=3, default=None)
#     eps = get_arg(ctx, 'eps', pos=4, default=1e-5)
#     output = ctx.method_return


#     input_trt, eps_trt = add_missing_trt_tensors(ctx.network, [input, eps])
    
#     shape = list(input.shape)
#     split_shape = [shape[0]] + [num_groups, shape[1] // num_groups] + shape[2:]
#     split_shape = tuple(split_shape)
#     keepdim = True

#     # split into groups
#     layer = ctx.network.add_shuffle(input_trt)
#     layer.reshape_dims = split_shape
#     a = layer.get_output(0)


#     # compute mean over groups
#     reduce_dims = tuple(range(2, len(split_shape)))
#     axes = torch_dim_to_trt_axes(reduce_dims)
#     layer = ctx.network.add_reduce(a, trt.ReduceOperation.AVG, axes, keepdim)
#     a_mean = layer.get_output(0)

#     # compute stdev over groups
#     a_diff = ctx.network.add_elementwise(a, a_mean, trt.ElementWiseOperation.SUB).get_output(0)
#     a_dist = ctx.network.add_elementwise(a_diff, a_diff, trt.ElementWiseOperation.PROD).get_output(0)
#     a_var = ctx.network.add_reduce(a_dist, trt.ReduceOperation.AVG, axes, keepdim).get_output(0)


#     a_var, eps_trt = broadcast_trt_tensors(ctx.network, [a_var, eps_trt], len(split_shape))

#     a_var_eps = ctx.network.add_elementwise(a_var, eps_trt, trt.ElementWiseOperation.SUM).get_output(0)
#     a_std = ctx.network.add_unary(a_var_eps, trt.UnaryOperation.SQRT).get_output(0)

#     # divide by stdev
#     b = ctx.network.add_elementwise(a_diff, a_std, trt.ElementWiseOperation.DIV).get_output(0)

#     # reshape
#     layer = ctx.network.add_shuffle(b)
#     layer.reshape_dims = shape

#     c = layer.get_output(0)

#     # handle affine version
#     if weight is not None or bias is not None:
#         if weight is not None:
#             scale = weight.detach().cpu().numpy()
#         else:
#             scale = np.ones(input.shape[1])

#         if bias is not None:
#             bias = bias.detach().cpu().numpy()
#         else:
#             bias = np.zeros(input.shape[1])

#         power = np.ones_like(scale)

#         layer = ctx.network.add_scale_nd(c, trt.ScaleMode.CHANNEL, bias, scale, power, 1)
#         c = layer.get_output(0)

#     output._trt = c


In [None]:
# from pytorch/TensorRT
# def common_batchnorm(network, mod, input_val, layer_name, is_quantized):
#     scale = to_numpy(mod.weight) / np.sqrt(to_numpy(mod.running_var) + mod.eps)
#     bias = to_numpy(mod.bias) - to_numpy(mod.running_mean) * scale
#     power = np.ones_like(scale)

#     layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, bias, scale, power)
#     layer.name = layer_name

#     if is_quantized:
#         mark_as_int8_layer(
#             layer, get_dyn_range(mod.scale, mod.zero_point, torch.quint8)
#         )

#     return layer.get_output(0)


# @tensorrt_converter(torch.nn.modules.batchnorm.BatchNorm2d)
# def batchnorm2d(network, submod, args, kwargs, layer_name):
#     # args/kwargs should have already been normalized to kwargs
#     assert len(args) == 0
#     input_val = kwargs["input"]

#     if not isinstance(input_val, trt.tensorrt.ITensor):
#         raise RuntimeError(
#             f"BatchNorm2d received input {input_val} that is not part "
#             "of the TensorRT region!"
#         )

#     return common_batchnorm(network, submod, input_val, layer_name, is_quantized=False)

In [None]:
# https://github.com/THUDM/FastLDM/blob/7b5f5ff44551dc44daf938ba007d7827e9ac8c6b/fastldm/modules.py#L307

# class BaseApply:
#     @classmethod
#     def apply(cls, *inputs, **kw_args):
#         return cls.forward(None, *inputs, **kw_args)

# BasePlugin = torch.autograd.Function # BaseApply if ONNX_ONLY else torch.autograd.Function

# class GroupNormalizationPlugin(BasePlugin):
#     # https://github.com/NVIDIA/TensorRT/tree/release/8.5/plugin/groupNormalizationPlugin
#     @staticmethod
#     def forward(ctx, x, scale, bias, num_groups, eps):
#         return F.group_norm(x, num_groups, weight=scale, bias=bias, eps=eps)

#     @staticmethod
#     def symbolic(g, x, scale, bias, num_groups, eps):
#         return g.op("GroupNormalizationPlugin", x, scale, bias, plugin_version_s='1', eps_f=eps, num_groups_i=num_groups)


# class GroupNorm(nn.Module):
#     def __init__(self, num_groups, num_channels, eps, affine=True):
#         super().__init__()
#         assert num_channels % num_groups == 0
#         self.eps = eps
#         self.num_groups = num_groups
#         self.weight = nn.Parameter(torch.ones(num_channels))
#         self.bias = nn.Parameter(torch.zeros(num_channels))

#     def forward(self, x):
#         return GroupNormalizationPlugin.apply(x, self.weight, self.bias, self.num_groups, self.eps)

In [None]:
# BreadcrumbsTensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py

# @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
# @register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.gelu))
# @register_acc_op_mapping(op_and_target=("call_method", "gelu"))
# @register_custom_acc_mapper_fn(
#     op_and_target=("call_module", torch.nn.GELU),
#     arg_replacement_tuples=[
#         ("input", "input"),
#         ("approximate", "approximate"),
#     ],
# )
# @register_acc_op
# def gelu(*, input, approximate="none"):
#     return torch.nn.functional.gelu(input=input, approximate=approximate)
