In [1]:
import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt
import math

In [2]:
config = dict()

batch_size, seq_len, hidden_size = 4, 45, 4096
intermediate_size = 11008
num_attention_heads = 32
num_key_value_heads = 32
max_position_embeddings = 2048
rope_theta = 10000.0

config["hidden_size"] = hidden_size
config["intermediate_size"] = intermediate_size
config["num_heads"] = num_attention_heads
config["head_dim"] = config["hidden_size"] // config["num_heads"]
config["num_key_value_heads"] = num_key_value_heads
config["num_key_value_groups"] = config["num_heads"] // config["num_key_value_heads"]
config["max_position_embeddings"] = max_position_embeddings
config["rope_theta"] = rope_theta

In [3]:
data = torch.ones(batch_size, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, 1, seq_len, seq_len)
position_ids = torch.arange(0, seq_len)
position_ids = position_ids.repeat(batch_size, 1)

## tensorRT dynamic shape

In [4]:
# seq length is not specified, since it is a dynamic size
def trt_create(batch_size, hidden_size, intermediate_size, model):
    
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # input
    inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, (batch_size, -1, hidden_size))

    # dynamic shape optimization
    profile = builder.create_optimization_profile();
    profile.set_shape("inputT0", (batch_size, 1, hidden_size), (batch_size, 1, hidden_size), (batch_size, 45, hidden_size))
    config.add_optimization_profile(profile)
    
    q_proj_weight = np.ones((hidden_size, model['head_dim']), dtype=np.float32)
    q_proj_weight = q_proj_weight.reshape(1, hidden_size, model['head_dim'])

    input_shape = [1, hidden_size, model['head_dim']]
    
    q_proj_weight_layer = network.add_constant(shape=input_shape, weights=trt.Weights(q_proj_weight))

    q_proj_layer = network.add_matrix_multiply(inputT0, trt.MatrixOperation.NONE, q_proj_weight_layer.get_output(0), trt.MatrixOperation.NONE)

    # output
    network.mark_output(q_proj_layer.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [5]:
trt_engineStr = trt_create(batch_size, hidden_size, intermediate_size, config)

In [6]:
def trt_inference(batch_size, hidden_size, engineString, raw_data):
#     print(engineString)
#     print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

#     dynamic shape configure
    print("Set input shape", (batch_size, 15, hidden_size))
    context.set_input_shape("inputT0", (4, 15, 4096))
    context.set_binding_shape(0, (batch_size, 15, hidden_size))
    origin_inputshape = context.get_binding_shape(0)

    print("Set input shape completed")

    data = np.array(raw_data)

    _, stream = cudart.cudaStreamCreate()
#     print("Reshaping")

    inputH0 = np.ascontiguousarray(data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
#     print("Reshaped")

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)

    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
#     print("execute")
    context.execute_async_v2([int(inputD0), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(outputD0)

    return outputH0

In [7]:
data = data.reshape(batch_size, seq_len, hidden_size)
print(data.shape)

trt_output = trt_inference(batch_size, hidden_size, trt_engineStr, data)

# trt_output = trt_output.reshape(batch_size, seq_len, hidden_size)
print("output_trt :", trt_output.shape)
print(trt_output)

torch.Size([4, 45, 4096])
Set input shape (4, 15, 4096)
Set input shape completed
output_trt : (4, 15, 128)
[[[4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  ...
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]]

 [[4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  ...
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]]

 [[4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  ...
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4096. 4096.]]

 [[4096. 4096. 4096. ... 4096. 4096. 4096.]
  [4096. 4096. 4096. ... 4096. 4

  context.set_binding_shape(0, (batch_size, 15, hidden_size))
  origin_inputshape = context.get_binding_shape(0)
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
