In [1]:
import torch
from transformers import AutoImageProcessor, AutoModelForZeroShotImageClassification, AutoTokenizer, ZeroShotImageClassificationPipeline, SiglipProcessor
from modeling_siglip import SiglipModel
from torch2trt import torch2trt

model = SiglipModel.from_pretrained('siglip-large-epoch5-augv2-upscale_0.892_cont_5ep_0.905', torch_dtype=torch.float16).cuda()

In [2]:
text_model = model.text_model
vision_model = model.vision_model

In [11]:
dummy = torch.ones(1, 3, 384, 384, dtype=torch.float16, device='cuda')
vision_model(dummy)

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 2.9453, -0.6533,  0.4304,  ...,  0.6572, -0.6221, -1.0801],
         [-2.0645, -0.8403, -1.4346,  ..., -1.4717, -0.5347, -0.4978],
         [-0.4597, -0.1000, -1.9961,  ..., -1.2549, -1.2227, -0.5923],
         ...,
         [-1.4844, -0.2078,  0.4170,  ...,  0.6338, -0.3677, -0.3564],
         [ 1.7197, -1.0742,  0.5806,  ...,  1.1084,  0.3052, -0.2898],
         [ 1.5098, -0.5830,  0.6035,  ..., -1.5576, -1.0205, -0.1233]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 1.1611,  0.6851, -0.2966,  ...,  0.6460,  0.0237,  0.5977]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>), hidden_states=None, attentions=None)

In [None]:
model_trt = torch2trt(vision_model, [dummy], fp16_mode=True, min_shapes=[(1, 3, 384, 384)], opt_shapes=[(4, 3, 384, 384)], max_shapes=[(10, 3, 384, 384)], use_onnx=True)

In [12]:
y = vision_model(dummy).pooler_output
y_trt = model_trt(dummy)['pooler_output']
print(torch.max(torch.abs(y - y_trt)))

tensor(0.0156, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)


In [13]:
torch.save(model_trt.state_dict(), 'vision_trt.pth')

In [2]:
from torch2trt import TRTModule
import torch

model_trt = TRTModule()

model_trt.load_state_dict(torch.load('vision_trt.pth'))

<All keys matched successfully>

In [3]:
dummy = torch.ones(1, 64, dtype=torch.long, device='cuda')
text_model(dummy)

input_ids tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
attention_mask None
position_ids None
output_attentions None
output_hidden_states None
return_dict None


BaseModelOutputWithPooling(last_hidden_state=tensor([[[-0.2983, -0.0971, -0.3064,  ...,  0.0174, -0.1897, -0.2632],
         [-0.2986, -0.1025, -0.3005,  ...,  0.0073, -0.1958, -0.2559],
         [ 0.3943, -0.1912,  0.0609,  ..., -1.0537, -0.3950, -0.9185],
         ...,
         [-0.6899, -0.2805,  0.0827,  ..., -0.5435, -0.0458, -0.5972],
         [-0.3210, -0.5361,  0.6953,  ..., -0.3293,  0.3630, -0.8301],
         [-0.9072, -1.2393,  0.5645,  ..., -1.2764, -0.6094, -0.6641]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-0.0044,  0.5698, -0.6294,  ...,  1.7988, -0.3899, -0.4961]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [None]:
model_trt = torch2trt(text_model, [dummy], fp16_mode=True, min_shapes=[(1, 64)], opt_shapes=[(1, 64)], max_shapes=[(1, 64)], use_onnx=True)

In [10]:
y = text_model(dummy).pooler_output
y_trt = model_trt(dummy)['pooler_output']
print(torch.max(torch.abs(y - y_trt)))

input_ids tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
attention_mask None
position_ids None
output_attentions None
output_hidden_states None
return_dict None
tensor(0.0156, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)


In [11]:
torch.save(model_trt.state_dict(), 'text_trt.pth')

In [3]:
torch.onnx.export(
      vision_model,
      torch.ones(1, 3, 384, 384, dtype=torch.float32),
      "vision.onnx",
      opset_version=17,
      input_names=["input"],
      output_names=["output"],
      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)

  if interpolate_pos_encoding:
  if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
  if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):


In [4]:
torch.onnx.export(
      text_model,
      torch.zeros(1, 64, dtype=torch.int64),
      "text.onnx",
      opset_version=17,
      input_names=["input"],
      output_names=["output"],
      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)

In [None]:
!trtexec --onnx=vision.onnx --saveEngine=vision.engine --minShapes=input:1x3x384x384 --maxShapes=input:10x3x384x384 --optShapes=input:4x3x384x384 --memPoolSize=workspace:8000 --fp16
!trtexec --onnx=text.onnx --saveEngine=text.engine --minShapes=input:1x64 --maxShapes=input:1x64 --optShapes=input:1x64 --memPoolSize=workspace:8000 --fp16