In [None]:
!pip install transformers diffusers optimum

In [None]:
import os

# Show the CUDA-loading and cache-location environment variables for this kernel,
# one per line; a variable that is not set prints as None.
print(
    *(
        os.getenv(variable)
        for variable in ("CUDA_MODULE_LOADING", "HF_HOME", "TRANSFORMERS_CACHE", "TORCH_HOME")
    ),
    sep="\n",
)

In [None]:
import torch
import torch.fx
import torch.nn as nn
# Record the installed PyTorch version in the notebook output.
print(torch.__version__)

In [None]:
import tensorrt
# Record the installed TensorRT version in the notebook output.
print(tensorrt.__version__)

In [None]:
import torch_tensorrt
# Record the installed Torch-TensorRT version in the notebook output.
print(torch_tensorrt.__version__)

In [None]:
import torch_tensorrt
from torch_tensorrt.fx.utils import LowerPrecision
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter

In [None]:
from torch_tensorrt.fx.tracer.acc_tracer import acc_ops
import torch_tensorrt.fx.converter_registry as registry

In [None]:
# Unregister the fx2trt converter for `acc_ops.expand` and print whether it is
# registered before and after.
# NOTE(review): presumably this keeps `expand` out of the TensorRT segments so the
# splitter leaves it to PyTorch — confirm the motivation.
print(acc_ops.expand in registry.CONVERTERS)
# A default makes the cell safe to re-run: a plain pop() raises KeyError once the
# converter has already been removed.
registry.CONVERTERS.pop(acc_ops.expand, None)
print(acc_ops.expand in registry.CONVERTERS)
# To list every registered converter: for k in registry.CONVERTERS.keys(): print(k)

In [None]:
import io

import requests
from PIL import Image

# Download a sample image and display it as the cell output.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# A timeout stops the cell from hanging forever on a stalled connection, and
# raise_for_status() turns an HTTP error page into an immediate, readable failure
# instead of an obscure image-decoding error.
response = requests.get(url, timeout=30)
response.raise_for_status()
# Decode from the fully downloaded body rather than the raw socket stream, so the
# image does not depend on the connection staying open while PIL reads lazily.
image = Image.open(io.BytesIO(response.content))
image

In [None]:
import warnings
from transformers import CLIPProcessor, CLIPModel
from transformers import AutoProcessor, CLIPVisionModel


# Optional preprocessing pipeline (unused here; the inputs below are random tensors):
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Alternative, CLIPModel: text + vision towers
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").cuda().eval()

# CLIPVisionModel: vision tower only. Loaded onto the GPU and switched to eval mode.
# Note: the following cell rebinds `model` to `CLIPModel(...).vision_model`.
model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").cuda().eval()

In [None]:
from transformers import CLIPModel

# Load the full CLIP checkpoint on the GPU in eval mode, then keep only its
# vision sub-module as `model` for the rest of the notebook.
model = (
    CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    .cuda()
    .eval()
    .vision_model
)

In [None]:
# Display the module's repr to inspect its layer structure.
model

In [None]:
from PIL import Image
import torchvision.transforms.functional as F

# CLIP: a single random float32 tensor of shape (1, 3, 224, 224) on the GPU,
# wrapped in a list so it can be splatted as positional arguments.
inputs = [torch.randn(1, 3, 224, 224, dtype=torch.float32, device="cuda")]

# >> Inference: one eager forward pass with autograd tracking disabled.
with torch.inference_mode():
    output = model(*inputs)

In [None]:
from torch import fx
from transformers.utils.fx import HFTracer, get_concrete_args


# Trace the vision model into a torch.fx GraphModule with the Hugging Face tracer.
input_names = ["pixel_values"]

# Dummy tensor keyed by forward() argument name; the tracer runs it through the
# model to record the graph.
args = {
    "pixel_values": torch.randn((1, 3, 224, 224), dtype=torch.float32, device='cuda')
}
# forward() arguments pinned to constants for tracing. This replaces the earlier
# `get_concrete_args(model, input_names)` result, which was computed and then
# immediately overwritten without being used.
concrete_args = {
    "output_attentions": None,
    "output_hidden_states": None,
    "return_dict": True
}

# Positional inputs reused by the later tracing, lowering and benchmarking cells.
inputs = [
    args["pixel_values"],
]

tracer = HFTracer()
traced_graph = tracer.trace(model, concrete_args=concrete_args, dummy_inputs=args)
traced_model = torch.fx.GraphModule(model, traced_graph)

# Optional metadata that could be attached to the traced module (left disabled):
# traced_model.config = model.config
# # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
# # _generate_dummy_input, where the model class is needed.
# traced_model.class_for_deserialization = model.__class__
# # traced_model.device = model.device

In [None]:
# TRANSFORMERS FX Tracing
from transformers.utils.fx import symbolic_trace

# Alternative tracing path through the high-level `symbolic_trace` helper.
# The result is bound to `traced`; the later visible cells use `traced_model`
# (built with HFTracer in the previous cell), not `traced`.
# For CLIPModel: input_names=["input_ids", "pixel_values"]
# For CLIPVisionModel: input_names=["pixel_values"]
with torch.inference_mode():
    traced = symbolic_trace(
        model, input_names=["pixel_values"],
        disable_check=True,
    )
# Show the class of the traced object.
type(traced)

In [None]:
# Display the fx GraphModule produced by the HFTracer cell.
traced_model

In [None]:
from transformers.models.clip.modeling_clip import CLIPVisionTransformer, CLIPAttention, CLIPVisionEmbeddings, CLIPEncoder

# acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
# to acc ops. It is applied here to the HF-traced module, using `inputs` as the
# sample inputs, with autograd tracking disabled.
with torch.inference_mode():
    trt_traced = acc_tracer.trace(
        traced_model, inputs, 
    )

In [None]:
# Display the acc-traced module.
trt_traced

In [None]:
# The splitter splits the model into several submodules, named either
# `run_on_acc_{}` or `run_on_gpu_{}`. Submodules named `run_on_acc_{}` can be fully
# lowered to TensorRT via fx2trt, while submodules named `run_on_gpu_{}` contain
# unsupported ops and cannot be lowered by fx2trt. The `run_on_gpu_{}` submodules can
# still run on the GPU if their ops have CUDA implementations; the naming is a bit
# confusing and is expected to be improved.
splitter = TRTSplitter(trt_traced, inputs)
splitter

In [None]:
# The preview shows which ops are supported and which are not. It can optionally
# dump a dot graph that colors supported and unsupported ops differently; that is
# turned off here (dump_graph=False). The return value is discarded.
_ = splitter.node_support_preview(dump_graph=False)

In [None]:
# Print the splitter's `non_acc_submodule_name` attribute under a heading.
# NOTE(review): the attribute name suggests this is the naming prefix used for
# non-accelerated submodules rather than a list of nodes — confirm the heading
# matches what is printed.
print("Non Acc Nodes")
print(splitter.non_acc_submodule_name)

# Other splitter members that can be inspected while debugging:
# splitter.sample_input
# splitter.split_preview()
# dir(splitter)
# print("Acc Nodes")
# print(splitter.acc_nodes)

In [None]:
# Split: calling the splitter returns the split module, whose children are the
# per-segment submodules lowered further below.
split_mod = splitter()

In [None]:
# Show the class of the split module.
type(split_mod)

In [None]:
# After the split, the top-level graph calls the generated submodules (for example
# `_run_on_acc_0` and `_run_on_gpu_1`); print it to see how many segments were
# produced and in which order they run.
print(split_mod.graph)

In [None]:
def get_submod_inputs(_mod, _submod, _inputs):
    """Capture the positional arguments `_submod` receives during a forward pass of `_mod`.

    A temporary forward pre-hook is registered on `_submod`, `_mod` is called once
    with `_inputs`, and the hook is removed again.

    Args:
        _mod: Callable module that contains `_submod`; invoked as `_mod(*_inputs)`.
        _submod: Module whose incoming arguments should be recorded.
        _inputs: Iterable of positional arguments for `_mod`.

    Returns:
        The argument tuple passed to `_submod` on its most recent call during the
        forward pass, or None if `_submod` was never called.

    Raises:
        Whatever `_mod(*_inputs)` raises; the hook is removed before the exception
        propagates.
    """
    acc_inputs = None

    def get_input(self, hook_args):
        # Forward pre-hooks receive (module, positional_args); keep the latest args.
        nonlocal acc_inputs
        acc_inputs = hook_args

    handle = _submod.register_forward_pre_hook(get_input)
    try:
        # NOTE(review): the forward pass runs with autograd enabled (an
        # inference_mode wrapper was commented out originally) — confirm intent.
        _mod(*_inputs)
    finally:
        # Always detach the hook, even if the forward pass fails; otherwise it stays
        # registered on the submodule and fires on every later call.
        handle.remove()
    return acc_inputs


# The split model may consist of several segments, and every TRT-eligible segment
# (name containing `_run_on_acc`) has to be lowered on its own. If the model is
# known to be fully lowerable, the splitter step can be skipped altogether.
# Iterate over a snapshot of the children because they are replaced in the loop.
for name, submod in list(split_mod.named_children()):
    print(f"Splitting {name}")
    if "_run_on_acc" not in name:
        # Non-accelerated segments stay as regular PyTorch modules.
        continue

    # Run the whole split model once to record what this segment receives as input.
    acc_inputs = get_submod_inputs(split_mod, submod, inputs)

    # fx2trt replacement: build an engine from the recorded input specs (FP32),
    # wrap it in a TRTModule and swap it in under the same attribute name.
    interp = TRTInterpreter(
        submod,
        InputTensorSpec.from_tensors(acc_inputs),
        explicit_batch_dimension=True,
    )
    r = interp.run(lower_precision=LowerPrecision.FP32)
    trt_mod = TRTModule(*r)
    setattr(split_mod, name, trt_mod)

In [None]:
# Serialize the whole lowered module with torch.save so the next cell can load it
# back. The file is written relative to the current working directory.
#
# The commented-out lines sketch an alternative that derives a cache path from
# NOS_MODELS_DIR / MODEL_NAME (neither is defined in the visible cells). Note that
# the sketched file id says "fp16" while the lowering above uses LowerPrecision.FP32.
# from pathlib import Path
# from nos.constants import NOS_MODELS_DIR

# model_dir = Path(NOS_MODELS_DIR, f"cache/{MODEL_NAME}")
# model_dir.mkdir(parents=True, exist_ok=True)

# lowered_model_output = split_mod(*inputs)

# Save and load model
# W, H = 224, 224
# model_id = MODEL_NAME.replace("/", "-") + "_" + f"{W}x{H}" + "_" + "fp16"
# filename = f"{model_dir}/{model_id}.torchtrt.pt"
filename = "clip_224x224.torchtrt.pt"
torch.save(split_mod, filename)

In [None]:
# Load the module that the previous cell saved.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code — only
# load files produced by this notebook or another trusted source.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects fully pickled modules — confirm whether weights_only=False must be
# passed for the PyTorch version in use.
print(f"Loading torchtrt filename: {filename}")
reload_trt_mod = torch.load(filename)

In [None]:
# Run the reloaded module on the same `inputs` used for tracing (autograd disabled)
# and display its output.
with torch.inference_mode():
    reload_model_output = reload_trt_mod(*inputs)
reload_model_output

In [None]:
# Compute the reference output from the original PyTorch model on the same inputs
# and display it as a plain dict. The numerical comparison against the reloaded
# module happens in the final cell.
with torch.inference_mode():
    regular_model_output = model(*inputs)
dict(regular_model_output)

In [None]:
%%timeit -n 100
# Benchmark the original PyTorch model.
with torch.inference_mode():
    _ = model(*inputs)
# CUDA work is launched asynchronously, so wait for the GPU to finish inside the
# timed body; otherwise the measurement can stop before the forward pass completes.
torch.cuda.synchronize()

In [None]:
%%timeit -n 100
# Benchmark the reloaded TensorRT-lowered module.
with torch.inference_mode():
    _ = reload_trt_mod(*inputs)
# CUDA work is launched asynchronously, so wait for the GPU to finish inside the
# timed body; otherwise the measurement can stop before the forward pass completes.
torch.cuda.synchronize()

In [None]:
def _output_tensors(output):
    """Return the tensors of a model output as a tuple, in field order.

    Accepts a single tensor, a dict-like output, or any other iterable of tensors.
    """
    if isinstance(output, torch.Tensor):
        return (output,)
    if isinstance(output, dict):
        # Iterating a dict-like output yields its keys (strings), not its tensors,
        # so take the values explicitly.
        return tuple(output.values())
    return tuple(output)


# Compare the reloaded TensorRT module against the PyTorch reference, field by field.
reloaded_tensors = _output_tensors(reload_model_output)
reference_tensors = _output_tensors(regular_model_output)
# zip() silently stops at the shorter sequence, so check the field counts first.
assert len(reloaded_tensors) == len(reference_tensors), (
    f"output count mismatch: {len(reloaded_tensors)} vs {len(reference_tensors)}"
)

for o1, o2 in zip(reloaded_tensors, reference_tensors):
    torch.testing.assert_close(
        o1.cpu().float(), o2.cpu().float(), rtol=2e-02, atol=2e-02, equal_nan=True
    )