In [1]:
import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt
import math

## Generate input and data shape

In [2]:
config = dict()

batch_size, seq_len, hidden_size = 4, 45, 4096
intermediate_size = 11008
num_attention_heads = 32
num_key_value_heads = 32
max_position_embeddings = 2048
rope_theta = 10000.0

config["hidden_size"] = hidden_size
config["intermediate_size"] = intermediate_size
config["num_heads"] = num_attention_heads
config["head_dim"] = config["hidden_size"] // config["num_heads"]
config["num_key_value_heads"] = num_key_value_heads
config["num_key_value_groups"] = config["num_heads"] // config["num_key_value_heads"]
config["max_position_embeddings"] = max_position_embeddings
config["rope_theta"] = rope_theta

In [3]:
data = torch.ones(batch_size, seq_len, hidden_size)
attention_mask=torch.ones(batch_size, 1, seq_len, seq_len)
position_ids = torch.arange(0, seq_len)
position_ids = position_ids.repeat(batch_size, 1)

## torch attention

In [4]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    
    repeat at the second dimension
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

In [116]:
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        print(torch.arange(0, 20, 2))
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(max_position_embeddings, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(torch.get_default_dtype()), persistent=False)

    def rotate_half(self, x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k, v, position_ids, seq_len=None):
        # v: [bs, num_attention_heads, seq_len, head_size]
        cos = self.cos_cached[:, :, :seq_len, ...].to(dtype=v.dtype)
        sin = self.sin_cached[:, :, :seq_len, ...].to(dtype=v.dtype)
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed


In [117]:
test_batch_size = 4
test_seq_len = 45
test_dim = 128
test_max_position_embeddings = 2048

rotary_emb = LlamaRotaryEmbedding(test_dim, test_max_position_embeddings, 10000.0)


query_states = torch.ones(test_batch_size, 32, test_seq_len, test_dim)
key_states = torch.ones(test_batch_size, 32, test_seq_len, test_dim)
value_states = torch.ones(test_batch_size, 32, test_seq_len, test_dim)
position_ids = torch.arange(test_seq_len).repeat(test_batch_size, 1)

query_states, key_states = rotary_emb(query_states, key_states, value_states, position_ids, seq_len=test_seq_len)

print(query_states[0])
print(query_states.shape)
print(key_states[0])
print(key_states.shape)

tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18])
tensor([[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
         [-0.3012, -0.1138,  0.0502,  ...,  1.0002,  1.0001,  1.0001],
         [-1.3254, -1.1475, -0.9265,  ...,  1.0003,  1.0003,  1.0002],
         ...,
         [ 0.5165,  1.2106,  0.9173,  ...,  1.0064,  1.0056,  1.0048],
         [ 1.3869,  1.3412, -0.0624,  ...,  1.0066,  1.0057,  1.0050],
         [ 0.9821,  0.5273, -1.0086,  ...,  1.0068,  1.0059,  1.0051]],

        [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
         [-0.3012, -0.1138,  0.0502,  ...,  1.0002,  1.0001,  1.0001],
         [-1.3254, -1.1475, -0.9265,  ...,  1.0003,  1.0003,  1.0002],
         ...,
         [ 0.5165,  1.2106,  0.9173,  ...,  1.0064,  1.0056,  1.0048],
         [ 1.3869,  1.3412, -0.0624,  ...,  1.0066,  1.0057,  1.0050],
         [ 0.9821,  0.5273, -1.0086,  ...,  1.0068,  1.0059,  1.0051]],

        [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.000

In [108]:
test_batch_size = 4
test_seq_len = 45
test_dim = 128
test_max_position_embeddings = 2048

rotary_emb = LlamaRotaryEmbedding(test_dim, test_max_position_embeddings, 10000.0)


query_states = torch.ones(test_batch_size, 32, test_seq_len, test_dim)
key_states = torch.ones(test_batch_size, 32, test_seq_len, test_dim)
value_states = torch.ones(test_batch_size, 32, test_seq_len, test_dim)
position_ids = torch.arange(test_seq_len).repeat(test_batch_size, 1)

cos, sin = rotary_emb(value_states, seq_len=test_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

print(query_states[0])
print(query_states.shape)
print(key_states[0])
print(key_states.shape)

torch.Size([4, 32, 45, 128]) torch.Size([4, 1, 45, 128])
tensor([[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
         [-0.3012, -0.1138,  0.0502,  ...,  1.0002,  1.0001,  1.0001],
         [-1.3254, -1.1475, -0.9265,  ...,  1.0003,  1.0003,  1.0002],
         ...,
         [ 0.5165,  1.2106,  0.9173,  ...,  1.0064,  1.0056,  1.0048],
         [ 1.3869,  1.3412, -0.0624,  ...,  1.0066,  1.0057,  1.0050],
         [ 0.9821,  0.5273, -1.0086,  ...,  1.0068,  1.0059,  1.0051]],

        [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
         [-0.3012, -0.1138,  0.0502,  ...,  1.0002,  1.0001,  1.0001],
         [-1.3254, -1.1475, -0.9265,  ...,  1.0003,  1.0003,  1.0002],
         ...,
         [ 0.5165,  1.2106,  0.9173,  ...,  1.0064,  1.0056,  1.0048],
         [ 1.3869,  1.3412, -0.0624,  ...,  1.0066,  1.0057,  1.0050],
         [ 0.9821,  0.5273, -1.0086,  ...,  1.0068,  1.0059,  1.0051]],

        [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000

## Breaking down LlamaRotaryEmbedding

## tensorRT MLP

In [47]:
# seq length is not specified, since it is a dynamic size
def trt_create(batch_size, hidden_size, intermediate_size):
    
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # input
    inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, (4, 4, 4))
    
#     network.mark_output(layer.get_output(0))

    # dynamic shape optimization
#     profile = builder.create_optimization_profile();
#     profile.set_shape("inputT0", (batch_size, 1,4), (batch_size, 1, 2, hidden_size), (batch_size, 1, 3, hidden_size))
#     config.add_optimization_profile(profile)

    slice_layer = network.add_slice(inputT0, start=(0, 0, 0), shape=inputT0.shape, stride=(1, 1, 1))

#     # self.up_proj(x)
#     up_proj_weight = model.up_proj.weight.clone().detach().cpu().numpy()
#     up_proj_layer = network.add_fully_connected(inputT0, model.intermediate_size, up_proj_weight)

#     # act_fn(self.gate_proj(x))
#     gate_proj_weight = model.gate_proj.weight.clone().detach().cpu().numpy()
#     gate_proj_layer = network.add_fully_connected(inputT0, model.intermediate_size, gate_proj_weight)

#     selu_sigmoid_layer = network.add_activation(gate_proj_layer.get_output(0), type=trt.ActivationType.SIGMOID)
#     selu_mult_layer = network.add_elementwise(gate_proj_layer.get_output(0), selu_sigmoid_layer.get_output(0), op=trt.ElementWiseOperation.PROD)

#     # act_fn(self.gate_proj(x)) * self.up_proj(x)
#     before_down_proj_layer = network.add_elementwise(selu_mult_layer.get_output(0), up_proj_layer.get_output(0), op=trt.ElementWiseOperation.PROD)

#     down_proj_weight = model.down_proj.weight.clone().detach().cpu().numpy()
#     down_proj_layer = network.add_fully_connected(before_down_proj_layer.get_output(0), hidden_size, down_proj_weight)

    # output
    network.mark_output(slice_layer.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [48]:
trt_engineStr = trt_create(batch_size, hidden_size, intermediate_size)

In [49]:
def trt_inference(batch_size, hidden_size, engineString, raw_data):
#     print(engineString)
#     print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    # dynamic shape configure
#     print("Set input shape", (batch_size, 1, hidden_size))
#     context.set_input_shape("inputT0", (batch_size, 1, hidden_size))
#     context.set_binding_shape(0, (batch_size, 1, hidden_size))
#     origin_inputshape = context.get_binding_shape(0)

#     print("Set input shape completed")

    data = np.array(raw_data)

    _, stream = cudart.cudaStreamCreate()
#     print("Reshaping")

    inputH0 = np.ascontiguousarray(data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
#     print("Reshaped")

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)

    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
#     print("execute")
    context.execute_async_v2([int(inputD0), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(outputD0)

    return outputH0

In [50]:
# up_proj_weight = model.up_proj.weight.clone().detach().cpu().numpy()
another_data = torch.arange(4 * 4 * 4, dtype=torch.float).reshape(4,4, 4)
print(another_data)

trt_output = trt_inference(batch_size, hidden_size, trt_engineStr, another_data)

# trt_output = trt_output.reshape(batch_size, seq_len, hidden_size)
print("output_trt :", trt_output.shape)
print(trt_output)

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]],

        [[16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.]],

        [[32., 33., 34., 35.],
         [36., 37., 38., 39.],
         [40., 41., 42., 43.],
         [44., 45., 46., 47.]],

        [[48., 49., 50., 51.],
         [52., 53., 54., 55.],
         [56., 57., 58., 59.],
         [60., 61., 62., 63.]]])
output_trt : (4, 4, 4)
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]
  [ 8.  9. 10. 11.]
  [12. 13. 14. 15.]]

 [[16. 17. 18. 19.]
  [20. 21. 22. 23.]
  [24. 25. 26. 27.]
  [28. 29. 30. 31.]]

 [[32. 33. 34. 35.]
  [36. 37. 38. 39.]
  [40. 41. 42. 43.]
  [44. 45. 46. 47.]]

 [[48. 49. 50. 51.]
  [52. 53. 54. 55.]
  [56. 57. 58. 59.]
  [60. 61. 62. 63.]]]


  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))


In [114]:
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_heads"]
        self.head_dim = config["hidden_size"] // config["num_heads"]
        self.num_key_value_heads = config["num_key_value_heads"]
        self.num_key_value_groups = config["num_heads"] // config["num_key_value_heads"]
        self.max_position_embeddings = config["max_position_embeddings"]
        self.rope_theta = config["rope_theta"]

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        print(
            "init rope",
            self.head_dim,
            self.max_position_embeddings,
            self.rope_theta,
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: None,
        position_ids: None,
        past_key_value: None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        # bsz = batch size; q_len = query length; _ = hidden size
        bsz, q_len, _ = hidden_states.size()

        # do projection
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # reshape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        

        # in hugging face, they do have kv cache, however, they don't have other attention optimization
        # this could be done directly in tensorRT by using dynamic shape
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            print("past_key_value is not None, kv_seq_len")
            kv_seq_len += past_key_value[0].shape[-2]

        print(query_states.shape)
        print(key_states.shape)
        print(value_states.shape)


        print("kv_seq_len", kv_seq_len, position_ids)
        query_states, key_states = rotary_emb(query_states, key_states, value_states, position_ids, seq_len=test_seq_len)


        #####################################################

        if past_key_value is not None:
            print("past_key_value is not None, self_attention")
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        #####################################################

        past_key_value = (key_states, value_states) if use_cache else None

        print(self.num_key_value_groups)
        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

## Test torch

In [115]:
model = LlamaAttention(config)

device = torch.device("cuda")

# model.load("/home/fuchiang137/.cache/huggingface/hub/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348/pytorch_model-00019-of-00033.bin")
model = model.to(device)

data_D = data.to(device)
attention_mask_D = attention_mask.to(device)
position_ids_D = position_ids.to(device)
# output = model(data)

past_key_value = None

# attentiona mask
# position_ids
# specifies the position id of the corresponding hidden state tensor element
# e.g. hid = [3, 4, 6] => pos_id = [0, 1, 2]
# past_key_value
# if use cache, past key value will contain past kv values
output = model(hidden_states=data_D,
               attention_mask=attention_mask_D,
               position_ids=position_ids_D,
               past_key_value=past_key_value,
               output_attentions=False,
               use_cache=True,)

print(output)


init rope 128 2048 10000.0
tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18])
torch.Size([4, 32, 45, 128])
torch.Size([4, 32, 45, 128])
torch.Size([4, 32, 45, 128])
kv_seq_len 45 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

## tensorRT MLP

In [70]:
# seq length is not specified, since it is a dynamic size
def trt_create(batch_size, hidden_size, intermediate_size, model):
    
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # input
    inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, (batch_size, 1, -1, hidden_size))
    
#     network.mark_output(layer.get_output(0))

    # dynamic shape optimization
    profile = builder.create_optimization_profile();
    profile.set_shape("inputT0", (batch_size, 1, 1, hidden_size), (batch_size, 1, 2, hidden_size), (batch_size, 1, 3, hidden_size))
    config.add_optimization_profile(profile)

    slice_layer = network.add_slice(inputT0, start=(0, 0, 0, 0), shape=(2, 2, 2, 2), stride=(1, 1, 1, 1))

#     # self.up_proj(x)
#     up_proj_weight = model.up_proj.weight.clone().detach().cpu().numpy()
#     up_proj_layer = network.add_fully_connected(inputT0, model.intermediate_size, up_proj_weight)

#     # act_fn(self.gate_proj(x))
#     gate_proj_weight = model.gate_proj.weight.clone().detach().cpu().numpy()
#     gate_proj_layer = network.add_fully_connected(inputT0, model.intermediate_size, gate_proj_weight)

#     selu_sigmoid_layer = network.add_activation(gate_proj_layer.get_output(0), type=trt.ActivationType.SIGMOID)
#     selu_mult_layer = network.add_elementwise(gate_proj_layer.get_output(0), selu_sigmoid_layer.get_output(0), op=trt.ElementWiseOperation.PROD)

#     # act_fn(self.gate_proj(x)) * self.up_proj(x)
#     before_down_proj_layer = network.add_elementwise(selu_mult_layer.get_output(0), up_proj_layer.get_output(0), op=trt.ElementWiseOperation.PROD)

#     down_proj_weight = model.down_proj.weight.clone().detach().cpu().numpy()
#     down_proj_layer = network.add_fully_connected(before_down_proj_layer.get_output(0), hidden_size, down_proj_weight)

    # output
    network.mark_output(slice_layer.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [71]:
trt_engineStr = trt_create(batch_size, hidden_size, intermediate_size, model)

[09/28/2023-07:31:43] [TRT] [E] 4: (Unnamed Layer* 0) [Slice]: out of bounds slice, input dimensions = [4,1,-1,4096], start = [0,0,0,0], size = [2,2,2,2], stride = [1,1,1,1].
[09/28/2023-07:31:43] [TRT] [E] 4: [network.cpp::validate::3121] Error Code 4: Internal Error (Layer (Unnamed Layer* 0) [Slice] failed validation)


In [62]:
def trt_inference(batch_size, hidden_size, engineString, raw_data, up_proj):
#     print(engineString)
#     print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    # dynamic shape configure
#     print("Set input shape", (batch_size, 1, hidden_size))
#     context.set_input_shape("inputT0", (batch_size, 1, hidden_size))
#     context.set_binding_shape(0, (batch_size, 1, hidden_size))
#     origin_inputshape = context.get_binding_shape(0)

#     print("Set input shape completed")

    data = np.array(raw_data)

    _, stream = cudart.cudaStreamCreate()
#     print("Reshaping")

    inputH0 = np.ascontiguousarray(data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
#     print("Reshaped")

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)

    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
#     print("execute")
    context.execute_async_v2([int(inputD0), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(outputD0)

    return outputH0

In [63]:
up_proj_weight = model.up_proj.weight.clone().detach().cpu().numpy()

trt_output = trt_inference(batch_size, hidden_size, trt_engineStr, data, up_proj_weight)

trt_output = trt_output.reshape(batch_size, seq_len, hidden_size)
print("output_trt :", trt_output.shape)
print(trt_output)

output_trt : (4, 1, 4096)
[[[ 2.0781114  4.51184    3.0770564 ... -2.6064956 -2.116741  -2.7058382]]

 [[ 2.0781114  4.51184    3.0770564 ... -2.6064956 -2.116741  -2.7058382]]

 [[ 2.0781114  4.51184    3.0770564 ... -2.6064956 -2.116741  -2.7058382]]

 [[ 2.0781114  4.51184    3.0770564 ... -2.6064956 -2.116741  -2.7058382]]]


  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))


## Benchmark

In [None]:
import time

### Torch

In [74]:
torch_start = time.time_ns()

output = model(data_D)

torch_complete = time.time_ns()

print("torch memory exe", (torch_complete - torch_start) / 10e6, "ms")


torch memory exe 0.2631836 ms


### TensorRT

### profile CPU/GPU time for tensorRT

In [68]:
def profile_trt_inference(batch_size, hidden_size, engineString, raw_data, up_proj):
    trt_prep_start = time.time_ns()
    
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    trt_prep_complete = time.time_ns()

    data = np.array(raw_data)

    inputH0 = np.ascontiguousarray(data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))

    memory_alloc_complete = time.time_ns()

    _, stream = cudart.cudaStreamCreate()

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)

    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
    context.execute_async_v2([int(inputD0), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(outputD0)

    trt_complete = time.time_ns()
    
    print("trt_prep", (trt_prep_complete - trt_prep_start) / 10e6, "ms")
    print("memory_alloc CPU", (memory_alloc_complete - trt_prep_complete) / 10e6, "ms")
    print("trt memory alloc & mv & exe", (trt_complete - memory_alloc_complete) / 10e6, "ms")

    return outputH0

In [69]:
trt_output = profile_trt_inference(batch_size, hidden_size, trt_engineStr, data, up_proj_weight)

trt_prep 15.1114614 ms
memory_alloc CPU 0.0241049 ms
trt memory alloc & mv & exe 0.1599591 ms


  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
