In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import maskclip_onnx

class MaskclipBackbone(nn.Module):
    def __init__(self, model_name="ViT-B/16", convert_to_fp16=False):
        super().__init__()
        self.model_name = model_name
        self.model, _ = maskclip_onnx.clip.load(
            model_name,
            download_root=os.getenv('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.cache', 'torch')),
            convert_to_fp16=convert_to_fp16  # has to be false for ONNX export in torch>=2.0.0
        )

    def forward(self, img):
        features = self.model.get_patch_encodings(img)
        return features

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


In [2]:
torch.__version__

'2.4.1+cu118'

In [3]:
my_clip_backbone = MaskclipBackbone(convert_to_fp16=True).cuda().eval()

In [4]:
# set seed
torch.manual_seed(0)
test_tensor = torch.randn((64, 3, 240, 320), dtype=torch.float32).cuda()

In [5]:
# FP16 inference
my_clip_backbone = MaskclipBackbone(convert_to_fp16=True).cuda().eval()
import time

with torch.no_grad():
    # warm up
    for _ in range(10):
        output = my_clip_backbone(test_tensor)

    start_cp = time.time()
    TEST_TIME = 100
    for _ in range(TEST_TIME):
        output = my_clip_backbone(test_tensor)

    end_cp = time.time()
print(f"Maskclip inference time: {(end_cp - start_cp) / TEST_TIME * 1000} ms")

Maskclip inference time: 100.33645868301392 ms


In [6]:
# FP32 inference. Subsequent steps need to use FP32 version.
my_clip_backbone = MaskclipBackbone(convert_to_fp16=False).cuda().eval()
import time

with torch.no_grad():
    # warm up
    for _ in range(10):
        output = my_clip_backbone(test_tensor)

    start_cp = time.time()
    TEST_TIME = 100
    for _ in range(TEST_TIME):
        output = my_clip_backbone(test_tensor)

    end_cp = time.time()
print(f"Maskclip inference time: {(end_cp - start_cp) / TEST_TIME * 1000} ms")

Maskclip inference time: 256.88424348831177 ms


In [7]:
# onnx_program = torch.onnx.dynamo_export(my_export_wrapper, test_tensor)
torch.onnx.export(
    my_clip_backbone,                  # model to export
    test_tensor,        # inputs of the model,
    "test_model.onnx",        # filename of the ONNX model
    input_names = ['input'],   # the model's input names
    export_params=True,
)

# torch.onnx.export(my_export_model, input_tensor, 'exported_clip.onnx', export_params=True)

  if num_patches == num_og_patches and w == h:
  assert w0 * h0 == num_patches, "Number of patches does not match"
  patch_per_ax = int(np.sqrt(num_og_patches))
  scale_factor=(float(w0 / patch_per_ax), float(h0 / patch_per_ax)),
  int(w0) == patch_pos_embed_interp.shape[-2] and int(h0) == patch_pos_embed_interp.shape[-1]
  assert (


In [8]:
from maskclip_onnx.onnx_tensorrt import TensorRTBackend

In [9]:
trt_engine = TensorRTBackend.prepare("test_model.onnx",
                                        device='CUDA:0',
                                        serialize_engine=True,
                                        verbose=False,
                                        serialized_engine_path="test_model.trt")

FAST FP16 detected. Enabling precision to FP16...
Loading serialized engine from test_model_fp16.trt


In [10]:
output_trt = trt_engine.run(test_tensor, 'torch_cuda')
with torch.no_grad():
    output_vanilla = my_clip_backbone(test_tensor).cpu().numpy()

In [11]:
output_trt[0].device

device(type='cuda', index=0)

In [12]:
np.isclose(output_trt[0].cpu().numpy(), output_vanilla, rtol=1e-2, atol=1e-2)

array([[[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        ...,
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]],

       [[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        ...,
        [ True,  True,  True, ...,  True,  True,  True],
        [False,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True, False]],

       [[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True, False,  True, ...,  True,  True,  True],
        ...,
        [False,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  Tr

In [13]:
import time

# Profiling with device-host-device-host-device memory transfer
with torch.no_grad():
    # warm up
    for _ in range(10):
        arr = test_tensor.cpu().numpy()
        output = trt_engine.run(arr)
        output = torch.tensor(output[0]).cuda()

    start_cp = time.time()
    TEST_TIME = 100
    for _ in range(TEST_TIME):
        arr = test_tensor.cpu().numpy()
        output = trt_engine.run(arr)
        output = torch.tensor(output[0]).cuda()

    end_cp = time.time()
print(f"Maskclip inference time: {(end_cp - start_cp) / TEST_TIME * 1000} ms")

Maskclip inference time: 90.04132509231567 ms


In [14]:
import time

# Profiling with host-device-host memory transfer
with torch.no_grad():
    # warm up
    arr = test_tensor.cpu().numpy()
    for _ in range(10):
        output = trt_engine.run(arr)

    start_cp = time.time()
    TEST_TIME = 100
    for _ in range(TEST_TIME):
        output = trt_engine.run(arr)

    end_cp = time.time()
print(f"Maskclip inference time: {(end_cp - start_cp) / TEST_TIME * 1000} ms")

Maskclip inference time: 51.995224952697754 ms


In [15]:
import time

# Profiling with device-host memory transfer
with torch.no_grad():
    # warm up
    for _ in range(10):
        output = trt_engine.run(test_tensor, input_output_mode='torch_cuda')

    start_cp = time.time()
    TEST_TIME = 100
    for _ in range(TEST_TIME):
        output = trt_engine.run(test_tensor, input_output_mode='torch_cuda')

    end_cp = time.time()
print(f"Maskclip inference time: {(end_cp - start_cp) / TEST_TIME * 1000} ms")

Maskclip inference time: 47.76280879974365 ms
