In [None]:
# Upgrade pip itself, then install the auxiliary packages listed below.
!pip install --upgrade pip
!pip install loguru opencv-python-headless tabulate netron

In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install diffusers transformers onnx
!pip install --upgrade colored polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
!pip install onnx==1.13.1 onnxruntime==1.14.1
!pip install git+https://github.com/microsoft/onnx-script

In [None]:
import torch
import torch.fx
import torch.nn as nn
# Show which torch build the kernel actually imported.
print(torch.__version__)

In [None]:
import onnxscript
from onnxscript.onnx_opset import opset17 as op

# Custom opset (domain "torch.onnx", version 1) under which the ONNX Script
# function below is defined.
custom_opset = onnxscript.values.Opset(domain="torch.onnx", version=1)


# ONNX Script function, built from opset-17 ops (`op`), that computes
#   Dropout(Softmax((query * sqrt(s)) @ (key^T * sqrt(s)), axis=-1), dropout_p) @ value
# where s = 1 / sqrt(size of query's last axis) and key^T swaps key's last two axes.
# It has no attention-mask input and applies no causal masking.
@onnxscript.script(custom_opset)
def ScaledDotProductAttention(
    query,
    key,
    value,
    dropout_p,
):
    # Swap the last two axes of key
    key_shape = op.Shape(key)
    # Length-1 slices of the shape, kept as 1-D tensors so they can be concatenated.
    key_last_dim = key_shape[-1:]
    key_second_last_dim = key_shape[-2:-1]
    key_first_dims = key_shape[:-2]
    # Contract the dimensions that are not the last two so we can transpose
    # with a static permutation.
    key_squeezed_shape = op.Concat(
        op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0
    )
    key_squeezed = op.Reshape(key, key_squeezed_shape)
    key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1])
    # Restore the leading dimensions, now with the last two axes exchanged.
    key_transposed_shape = op.Concat(key_first_dims, key_last_dim, key_second_last_dim, axis=0)
    key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape)

    # scale = 1 / sqrt(size of query's last axis), cast to query's element type.
    embedding_size = op.CastLike(op.Shape(query)[-1], query)
    scale = op.Div(1.0, op.Sqrt(embedding_size))

    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
    # Each operand is multiplied by sqrt(scale), so their product carries `scale` once.
    query_scaled = op.Mul(query, op.Sqrt(scale))
    key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
    attn_weight = op.Softmax(
        op.MatMul(query_scaled, key_transposed_scaled),
        axis=-1,
    )
    # Dropout's second output is discarded; only the first output is used.
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return op.MatMul(attn_weight, value)


def custom_scaled_dot_product_attention(g, query, key, value, attn_mask, dropout, is_causal, scale=None):
    """Symbolic function used to export ``aten::scaled_dot_product_attention``.

    Emits the ``ScaledDotProductAttention`` ONNX Script function on the graph
    context ``g`` and gives the resulting value the same type as ``value``.

    Only ``query``, ``key``, ``value`` and ``dropout`` are forwarded to the
    ONNX Script function; ``attn_mask``, ``is_causal`` and ``scale`` are
    accepted to match the operator's signature but are not used.
    """
    attention_node = g.onnxscript_op(ScaledDotProductAttention, query, key, value, dropout)
    output_type = value.type()
    return attention_node.setType(output_type)


# Register the symbolic above for the ATen attention op at opset 17.
torch.onnx.register_custom_op_symbolic(
    "aten::scaled_dot_product_attention",
    custom_scaled_dot_product_attention,
    17,
)


In [None]:
from torch_tensorrt.fx.utils import LowerPrecision
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter

In [None]:
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_320_fpn

# Faster R-CNN (MobileNetV3-Large 320 FPN) with its default weights, moved to the
# GPU and put in eval mode. Note: `model` is reassigned by later cells.
model = fasterrcnn_mobilenet_v3_large_320_fpn(weights="DEFAULT").cuda().eval()

In [None]:
# from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline
from diffusers import DDIMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline

# Repository id and revision shared by the scheduler, the pipeline and its cache folder.
SD_REPO = "stabilityai/stable-diffusion-2-1"
SD_REVISION = "fp16"

# Use the DDIMScheduler scheduler here instead
scheduler = DDIMScheduler.from_pretrained(SD_REPO, subfolder="scheduler")

# Other pipelines tried previously:
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16)
# pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
pipeline_kwargs = dict(
    custom_pipeline="stable_diffusion_tensorrt_txt2img",
    revision=SD_REVISION,
    torch_dtype=torch.float16,
    scheduler=scheduler,
)
pipe = StableDiffusionPipeline.from_pretrained(SD_REPO, **pipeline_kwargs)
pipe.set_cached_folder(SD_REPO, revision=SD_REVISION)
pipe = pipe.to("cuda")

In [None]:
# Load "yolox_l" from the Megvii-BaseDetection/YOLOX hub repo, move it to the GPU
# and switch it to eval mode. This rebinds `model`, replacing any earlier value.
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_l").cuda().eval()

In [None]:
model

In [None]:
model.backbone

In [None]:
pipe

In [None]:
dir(pipe)

In [None]:
pipe.unet.conv_in

In [None]:
batch_size = 1
H = W = 512

# Both dummy inputs use 1/8 of the image resolution along each spatial axis.
latent_h, latent_w = H // 8, W // 8

# NOTE(review): these tensors use torch.randn's default dtype, while the pipeline is
# loaded with torch_dtype=torch.float16 — confirm the dtypes match before use.
vae_input = [torch.randn((batch_size, 3, latent_h, latent_w)).cuda()]
unet_input = [torch.randn((batch_size, 4, latent_h, latent_w)).cuda()]

In [None]:
# From here on `model` is the pipeline's UNet; this replaces whatever earlier
# cells assigned to `model`.
model = pipe.unet
# model

In [None]:
from PIL import Image
import torchvision.transforms.functional as F

# Load the test image, resize it to 640x480 (width x height) and convert it into a
# batched tensor on the GPU. The context manager closes the underlying file as soon
# as the resized copy exists, instead of leaving the handle open.
with Image.open("test.jpg") as source_image:
    resized_image = source_image.resize((640, 480))
img = F.to_tensor(resized_image).unsqueeze(0).cuda()
inputs = [img]

In [None]:
# acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
# to acc ops.
# NOTE(review): `unet_input` holds a single tensor — confirm that it covers every
# argument the traced model's forward requires.
traced = acc_tracer.trace(model, unet_input)

In [None]:
traced

In [None]:
# The splitter will split the model into several submodules. The names of the submodules
# will be either `run_on_acc_{}` or `run_on_gpu_{}`. Submodules named `run_on_acc_{}` can
# be fully lowered to TensorRT via fx2trt, while submodules named `run_on_gpu_{}` have
# unsupported ops and can't be lowered by fx2trt. We can still run `run_on_gpu_{}`
# submodules on the GPU if the ops there have CUDA implementations. The naming is a bit
# confusing and we'll improve it.
splitter = TRTSplitter(traced, unet_input)
splitter

In [None]:
# The preview functionality lets us see which ops are supported and which are
# unsupported. We can optionally dump the dot graph, which colors supported and
# unsupported ops differently.
_ = splitter.node_support_preview(dump_graph=False)

In [None]:
splitter.non_acc_submodule_name

In [None]:
# Perform the split: calling the splitter returns the split module.
split_mod = splitter()

In [None]:
# After the split, the graph calls submodules with names like `_run_on_acc_0` and
# `_run_on_gpu_1`.
print(split_mod.graph)

In [None]:
def get_submod_inputs(mod, submod, inputs):
    """Capture the positional inputs that ``submod`` receives when ``mod`` runs.

    A temporary forward pre-hook on ``submod`` records its positional
    arguments while ``mod(*inputs)`` executes.

    Args:
        mod: Parent module whose forward pass calls ``submod``.
        submod: Submodule whose inputs should be captured.
        inputs: Sequence of positional arguments passed to ``mod``.

    Returns:
        The positional arguments of the most recent call of ``submod`` as
        delivered to the hook, or ``None`` if ``submod`` was never called.

    Raises:
        Whatever ``mod(*inputs)`` raises; the hook is removed before the
        exception propagates.
    """
    acc_inputs = None

    def get_input(self, inputs):
        nonlocal acc_inputs
        acc_inputs = inputs

    handle = submod.register_forward_pre_hook(get_input)
    try:
        mod(*inputs)
    finally:
        # Always detach the hook, even when the forward pass fails, so a failed
        # run does not leave a stale hook on the submodule.
        handle.remove()
    return acc_inputs


# The model has been split into several segments, so each TRT-eligible segment must be
# lowered separately. If we know the model can be fully lowered, we can skip the
# splitter part.
# Iterate over a snapshot of the children because the loop replaces them in place.
for name, _ in list(split_mod.named_children()):
    print(f"Splitting {name}")
    if "_run_on_acc" in name:
        submod = getattr(split_mod, name)
        # Get submodule inputs for fx2trt by running the split model on the same
        # sample inputs that were used for tracing and splitting.
        acc_inputs = get_submod_inputs(split_mod, submod, unet_input)

        # fx2trt replacement: build a TensorRT module from the captured input specs
        # and swap it in under the same attribute name.
        interp = TRTInterpreter(
            submod,
            InputTensorSpec.from_tensors(acc_inputs),
            explicit_batch_dimension=True,
        )
        r = interp.run(lower_precision=LowerPrecision.FP32)
        trt_mod = TRTModule(*r)
        setattr(split_mod, name, trt_mod)

In [None]:
# Run the lowered model on the same sample inputs it was traced and split with.
lowered_model_output = split_mod(*unet_input)

# Save and load model
torch.save(split_mod, "trt.pt")
# The file holds a whole pickled module, not just a state dict, so weights-only
# loading has to be disabled explicitly. Only do this for files you created yourself:
# unpickling can execute arbitrary code.
reload_trt_mod = torch.load("trt.pt", weights_only=False)
reload_model_output = reload_trt_mod(*unet_input)

In [None]:
reload_model_output

In [None]:
# Reference output from the un-lowered model, for comparison with the lowered and
# reloaded outputs. `model` is the module that was traced, so it is called directly
# with the sample inputs used for tracing.
regular_model_output = model(*unet_input)

In [None]:
regular_model_output

In [None]:
%%timeit -n 100
_ = model.backbone(*inputs)

In [None]:
%%timeit -n 100
_ = reload_trt_mod(*inputs)