In [1]:
import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt
import math

## Generate input and data shape

In [2]:
config = dict()

batch_size, seq_len, hidden_size = 4, 45, 4096
intermediate_size = 11008
num_attention_heads = 32
num_key_value_heads = 32
max_position_embeddings = 2048
rope_theta = 10000.0

config["hidden_size"] = hidden_size
config["intermediate_size"] = intermediate_size
config["num_heads"] = num_attention_heads
config["head_dim"] = config["hidden_size"] // config["num_heads"]
config["num_key_value_heads"] = num_key_value_heads
config["num_key_value_groups"] = config["num_heads"] // config["num_key_value_heads"]
config["max_position_embeddings"] = max_position_embeddings
config["rope_theta"] = rope_theta

In [3]:
data = torch.ones(batch_size, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, 1, seq_len, seq_len)
position_ids = torch.arange(0, seq_len)
position_ids = position_ids.repeat(batch_size, 1)

## torch attention

In [4]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    
    repeat at the second dimension
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

In [5]:
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(max_position_embeddings, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(torch.get_default_dtype()), persistent=False)

    def rotate_half(self, x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k, v, position_ids, seq_len=None):
        # v: [bs, num_attention_heads, seq_len, head_size]
        cos = self.cos_cached[:, :, :seq_len, ...].to(dtype=v.dtype)
        sin = self.sin_cached[:, :, :seq_len, ...].to(dtype=v.dtype)
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed


In [6]:
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_heads"]
        self.head_dim = config["hidden_size"] // config["num_heads"]
        self.num_key_value_heads = config["num_key_value_heads"]
        self.num_key_value_groups = config["num_heads"] // config["num_key_value_heads"]
        self.max_position_embeddings = config["max_position_embeddings"]
        self.rope_theta = config["rope_theta"]

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        print(
            "init rope",
            self.head_dim,
            self.max_position_embeddings,
            self.rope_theta,
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()


    def load(self, dir):
        weights = torch.load(dir)
        self_attn_weights = dict()
        for key in weights.keys():
            print(key)
            if key == "model.layers.18.self_attn.rotary_emb.inv_freq":
                print(weights[key])
                continue
            if key.split(".")[3] == "self_attn":
                self_attn_weights[key[key.find(key.split(".")[4]):]] = weights[key]

        self.load_state_dict(self_attn_weights)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: None,
        position_ids: None,
        past_key_value: None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        # bsz = batch size; q_len = query length; _ = hidden size
        bsz, q_len, _ = hidden_states.size()

        # do projection
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # reshape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

#         #####################################################
#         # in hugging face, they do have kv cache, however, they don't have other attention optimization
#         # this could be done directly in tensorRT by using dynamic shape
#         kv_seq_len = key_states.shape[-2]
#         if past_key_value is not None:
#             kv_seq_len += past_key_value[0].shape[-2]

#         query_states, key_states = self.rotary_emb(query_states, key_states, value_states, position_ids, seq_len=q_len)

#         if past_key_value is not None:
#             # reuse k, v, self_attention
#             key_states = torch.cat([past_key_value[0], key_states], dim=2)
#             value_states = torch.cat([past_key_value[1], value_states], dim=2)

#         past_key_value = (key_states, value_states) if use_cache else None

#         print(self.num_key_value_groups)
#         # repeat k/v heads if n_kv_heads < n_heads
#         key_states = repeat_kv(key_states, self.num_key_value_groups)
#         value_states = repeat_kv(value_states, self.num_key_value_groups)
#         #####################################################

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        # attention_mask needs to be infered
        attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2) # .contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights, None

        # since normally it will be false
#         if not output_attentions:
#             attn_weights = None

#         return attn_output, attn_weights, past_key_value

## Test torch

In [7]:
model = LlamaAttention(config)

device = torch.device("cuda")

model.load("/home/fuchiang137/.cache/huggingface/hub/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348/pytorch_model-00019-of-00033.bin")
model = model.to(device)

data_D = data.to(device)
attention_mask_D = attention_mask.to(device)
position_ids_D = position_ids.to(device)
# output = model(data)

past_key_value = None

# attentiona mask
# position_ids
# specifies the position id of the corresponding hidden state tensor element
# e.g. hid = [3, 4, 6] => pos_id = [0, 1, 2]
# past_key_value
# if use cache, past key value will contain past kv values
output = model(hidden_states=data_D,
               attention_mask=attention_mask_D,
               position_ids=position_ids_D,
               past_key_value=past_key_value,
               output_attentions=False,
               use_cache=True,)

init rope 128 2048 10000.0
model.layers.18.self_attn.q_proj.weight
model.layers.18.self_attn.k_proj.weight
model.layers.18.self_attn.v_proj.weight
model.layers.18.self_attn.o_proj.weight
model.layers.18.mlp.gate_proj.weight
model.layers.18.mlp.down_proj.weight
model.layers.18.mlp.up_proj.weight
model.layers.18.input_layernorm.weight
model.layers.18.post_attention_layernorm.weight
model.layers.18.self_attn.rotary_emb.inv_freq
tensor([1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
        4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
        1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
        7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
        3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
        1.3335e-02, 1.1548e-02, 1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03,
        5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03, 3.1623e-03, 2.7384e-03,
        2.3

In [8]:
attn_output, attn_weights, past_key_value = output

In [9]:
print(attn_output.shape)
print(attn_output[0])

torch.Size([4, 45, 4096])
tensor([[-0.3422, -1.3077,  3.2849,  ..., -1.2631, -0.5749,  1.4960],
        [-0.3422, -1.3077,  3.2849,  ..., -1.2631, -0.5749,  1.4960],
        [-0.3422, -1.3077,  3.2849,  ..., -1.2631, -0.5749,  1.4960],
        ...,
        [-0.3422, -1.3077,  3.2849,  ..., -1.2631, -0.5749,  1.4960],
        [-0.3422, -1.3077,  3.2849,  ..., -1.2631, -0.5749,  1.4960],
        [-0.3422, -1.3077,  3.2849,  ..., -1.2631, -0.5749,  1.4960]],
       device='cuda:0', grad_fn=<SelectBackward0>)


In [10]:
print(attn_weights.shape)
print(attn_weights[0])

torch.Size([4, 32, 45, 45])
tensor([[[0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         ...,
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222]],

        [[0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         ...,
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222]],

        [[0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.0222,  ..., 0.0222, 0.0222, 0.0222],
         [0.0222, 0.0222, 0.02

In [11]:
print(past_key_value.shape)
print(past_key_value[0])

AttributeError: 'NoneType' object has no attribute 'shape'

## Breaking down LlamaRotaryEmbedding

## tensorRT Attention

In [12]:
# seq length is not specified, since it is a dynamic size
def trt_create(batch_size, hidden_size, intermediate_size, model):
    
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # input
    hidden_states = network.add_input('hidden_states', trt.DataType.FLOAT, (batch_size, -1, hidden_size))
    attention_mask = network.add_input('attention_mask', trt.DataType.FLOAT, (batch_size, 1, -1, -1))

    # dynamic shape optimization
    profile = builder.create_optimization_profile();
    profile.set_shape("hidden_states", (batch_size, 1, hidden_size), (batch_size, 1, hidden_size), (batch_size, 45, hidden_size))
    profile.set_shape("attention_mask", (batch_size, 1, 1, 1), (batch_size, 1, 1, 1), (batch_size, 1, 45, 45))
    config.add_optimization_profile(profile)

    # self.q_proj(hidden_states)
    q_proj_weight = model.q_proj.weight.clone().detach().cpu().numpy()
    q_proj_weight = np.expand_dims(q_proj_weight, 0)
    q_proj_weight_shape = list(q_proj_weight.shape)
    q_proj_weight_layer = network.add_constant(shape=q_proj_weight_shape, weights=trt.Weights(q_proj_weight))

    q_proj_layer = network.add_matrix_multiply(hidden_states,
                                               trt.MatrixOperation.NONE,
                                               q_proj_weight_layer.get_output(0),
                                               trt.MatrixOperation.TRANSPOSE)

    # query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    q_proj_shuffle_layer = network.add_shuffle(q_proj_layer.get_output(0))
    q_proj_shuffle_layer.reshape_dims = trt.Dims([batch_size, -1, model.num_heads, model.head_dim])
    q_proj_shuffle_layer.second_transpose = trt.Permutation([0, 2, 1, 3])

    
    # self.k_proj(hidden_states)
    k_proj_weight = model.k_proj.weight.clone().detach().cpu().numpy()
    k_proj_weight = np.expand_dims(k_proj_weight, 0)
    k_proj_weight_shape = list(k_proj_weight.shape)
    k_proj_weight_layer = network.add_constant(shape=k_proj_weight_shape, weights=trt.Weights(k_proj_weight))

    k_proj_layer = network.add_matrix_multiply(hidden_states,
                                               trt.MatrixOperation.NONE,
                                               k_proj_weight_layer.get_output(0),
                                               trt.MatrixOperation.TRANSPOSE)

    # key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    k_proj_shuffle_layer = network.add_shuffle(k_proj_layer.get_output(0))
    k_proj_shuffle_layer.reshape_dims = trt.Dims([batch_size, -1, model.num_heads, model.head_dim])
    k_proj_shuffle_layer.second_transpose = trt.Permutation([0, 2, 3, 1])


    # self.v_proj(hidden_states)
    v_proj_weight = model.v_proj.weight.clone().detach().cpu().numpy()
    v_proj_weight = np.expand_dims(v_proj_weight, 0)
    v_proj_weight_shape = list(v_proj_weight.shape)
    v_proj_weight_layer = network.add_constant(shape=v_proj_weight_shape, weights=trt.Weights(v_proj_weight))

    v_proj_layer = network.add_matrix_multiply(hidden_states,
                                               trt.MatrixOperation.NONE,
                                               v_proj_weight_layer.get_output(0),
                                               trt.MatrixOperation.TRANSPOSE)

    # value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    v_proj_shuffle_layer = network.add_shuffle(v_proj_layer.get_output(0))
    v_proj_shuffle_layer.reshape_dims = trt.Dims([batch_size, -1, model.num_heads, model.head_dim])
    v_proj_shuffle_layer.second_transpose = trt.Permutation([0, 2, 1, 3])

    # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    attn_weights_mult_layer = network.add_matrix_multiply(q_proj_shuffle_layer.get_output(0),
                                                          trt.MatrixOperation.NONE,
                                                          k_proj_shuffle_layer.get_output(0),
                                                          trt.MatrixOperation.NONE)

    sqrt_head_dim = np.array([1 / math.sqrt(model.head_dim)], np.float32).reshape(-1)
    attn_weights_layer = network.add_scale(attn_weights_mult_layer.get_output(0),
                                           trt.ScaleMode.UNIFORM,
                                           scale=sqrt_head_dim)

    # attn_weights = attn_weights + attention_mask
    attn_mix_mask = network.add_elementwise(attn_weights_layer.get_output(0),
                                            attention_mask,
                                            op=trt.ElementWiseOperation.SUM)

    # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    softmax_layer = network.add_softmax(attn_mix_mask.get_output(0))
    softmax_layer.axes = 1 << 3

    # attn_output = torch.matmul(attn_weights, value_states)
    attn_output_layer = network.add_matrix_multiply(softmax_layer.get_output(0),
                                                    trt.MatrixOperation.NONE,
                                                    v_proj_shuffle_layer.get_output(0),
                                                    trt.MatrixOperation.NONE)

    # attn_output = attn_output.transpose(1, 2).contiguous()
    # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output_shuffle_layer = network.add_shuffle(attn_output_layer.get_output(0))
    attn_output_shuffle_layer.first_transpose = trt.Permutation([0, 2, 1, 3])
    attn_output_shuffle_layer.reshape_dims = trt.Dims([batch_size, -1, hidden_size])

    # attn_output = self.o_proj(attn_output)
    o_proj_weight = model.o_proj.weight.clone().detach().cpu().numpy()
    o_proj_weight = np.expand_dims(o_proj_weight, 0)
    o_proj_weight_shape = list(o_proj_weight.shape)
    o_proj_weight_layer = network.add_constant(shape=o_proj_weight_shape, weights=trt.Weights(o_proj_weight))
    
    o_proj_layer = network.add_matrix_multiply(attn_output_shuffle_layer.get_output(0),
                                               trt.MatrixOperation.NONE,
                                               o_proj_weight_layer.get_output(0),
                                               trt.MatrixOperation.TRANSPOSE)

    # output
    # the order of output will be related to the order of the tensor creation
    network.mark_output(softmax_layer.get_output(0))
    network.mark_output(o_proj_layer.get_output(0))
    network.mark_output(v_proj_shuffle_layer.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [13]:
trt_engineStr = trt_create(batch_size, hidden_size, intermediate_size, model)

In [14]:
def trt_inference(batch_size, hidden_size, engineString, raw_data, raw_attn_mask):
#     print(engineString)
#     print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    # dynamic shape configure
    print("Set input shape", (batch_size, seq_len, hidden_size))

    context.set_input_shape("hidden_states", (batch_size, seq_len, hidden_size))
    context.set_binding_shape(0, (batch_size, seq_len, hidden_size))

    context.set_input_shape("attention_mask", (batch_size, 1, seq_len, seq_len))
    context.set_binding_shape(1, (batch_size, 1, seq_len, seq_len))
    print("Set input shape completed")

    data = np.array(raw_data)
    attention_mask = np.array(raw_attn_mask)

    _, stream = cudart.cudaStreamCreate()
#     print("Reshaping")

    inputH0 = np.ascontiguousarray(data.reshape(-1))
    inputH1 = np.ascontiguousarray(attention_mask.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(2), dtype=trt.nptype(engine.get_binding_dtype(2)))
    outputH1 = np.empty(context.get_binding_shape(3), dtype=trt.nptype(engine.get_binding_dtype(3)))
    outputH2 = np.empty(context.get_binding_shape(4), dtype=trt.nptype(engine.get_binding_dtype(4)))
#     print("Reshaped")

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, inputD1 = cudart.cudaMallocAsync(inputH1.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)
    _, outputD1 = cudart.cudaMallocAsync(outputH1.nbytes, stream)
    _, outputD2 = cudart.cudaMallocAsync(outputH2.nbytes, stream)

    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
    cudart.cudaMemcpyAsync(inputD1, inputH1.ctypes.data, inputH1.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
#     print("execute")
    context.execute_async_v2([int(inputD0), int(inputD1), int(outputD0), int(outputD1), int(outputD2)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)
    cudart.cudaMemcpyAsync(outputH1.ctypes.data, outputD1, outputH1.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)
    cudart.cudaMemcpyAsync(outputH2.ctypes.data, outputD2, outputH2.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(outputD0)
    cudart.cudaFree(outputD1)
    cudart.cudaFree(outputD2)

    return outputH0, outputH1, outputH2

In [15]:
trt_output = trt_inference(batch_size, hidden_size, trt_engineStr, data, attention_mask)

trt_query_states, trt_key_states, trt_value_states = trt_output

Set input shape (4, 45, 4096)
Set input shape completed


  context.set_binding_shape(0, (batch_size, seq_len, hidden_size))
  context.set_binding_shape(1, (batch_size, 1, seq_len, seq_len))
  outputH0 = np.empty(context.get_binding_shape(2), dtype=trt.nptype(engine.get_binding_dtype(2)))
  outputH0 = np.empty(context.get_binding_shape(2), dtype=trt.nptype(engine.get_binding_dtype(2)))
  outputH1 = np.empty(context.get_binding_shape(3), dtype=trt.nptype(engine.get_binding_dtype(3)))
  outputH1 = np.empty(context.get_binding_shape(3), dtype=trt.nptype(engine.get_binding_dtype(3)))
  outputH2 = np.empty(context.get_binding_shape(4), dtype=trt.nptype(engine.get_binding_dtype(4)))
  outputH2 = np.empty(context.get_binding_shape(4), dtype=trt.nptype(engine.get_binding_dtype(4)))


In [16]:
print(trt_query_states.shape)
print(trt_query_states[0])

(4, 32, 45, 128)
[[[-0.7616343  -0.32004273  0.5074545  ...  0.8154828   1.3458583
    1.7691885 ]
  [-0.7616343  -0.32004273  0.5074545  ...  0.8154828   1.3458583
    1.7691885 ]
  [-0.7616343  -0.32004273  0.5074545  ...  0.8154828   1.3458583
    1.7691885 ]
  ...
  [-0.7616343  -0.32004273  0.5074545  ...  0.8154828   1.3458583
    1.7691885 ]
  [-0.7616343  -0.32004273  0.5074545  ...  0.8154828   1.3458583
    1.7691885 ]
  [-0.7616343  -0.32004273  0.5074545  ...  0.8154828   1.3458583
    1.7691885 ]]

 [[ 2.1507761   0.6692885   3.2076879  ...  1.4208608   0.12534922
    0.04899704]
  [ 2.1507761   0.6692885   3.2076879  ...  1.4208608   0.12534922
    0.04899704]
  [ 2.1507761   0.6692885   3.2076879  ...  1.4208608   0.12534922
    0.04899704]
  ...
  [ 2.1507761   0.6692885   3.2076879  ...  1.4208608   0.12534922
    0.04899704]
  [ 2.1507761   0.6692885   3.2076879  ...  1.4208608   0.12534922
    0.04899704]
  [ 2.1507761   0.6692885   3.2076879  ...  1.4208608   0.1253

In [17]:
print(trt_key_states.shape)
print(trt_key_states[0])

(4, 32, 45, 45)
[[[0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  ...
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]]

 [[0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  ...
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  [0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]]

 [[0.02222222 0.02222222 0.02222222 ... 0.02222222 0.02222222 0.02222222]
  [0.0

In [18]:
print(trt_value_states.shape)
print(trt_value_states[0])

(4, 45, 4096)
[[-0.342221   -1.3077192   3.2848744  ... -1.263129   -0.57490236
   1.4959995 ]
 [-0.342221   -1.3077192   3.2848744  ... -1.263129   -0.57490236
   1.4959995 ]
 [-0.342221   -1.3077192   3.2848744  ... -1.263129   -0.57490236
   1.4959995 ]
 ...
 [-0.342221   -1.3077192   3.2848744  ... -1.263129   -0.57490236
   1.4959995 ]
 [-0.342221   -1.3077192   3.2848744  ... -1.263129   -0.57490236
   1.4959995 ]
 [-0.342221   -1.3077192   3.2848744  ... -1.263129   -0.57490236
   1.4959995 ]]
