In [1]:
import math
from typing import Optional, Tuple

import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt

## Generate input and data shape

In [2]:
# Primary sizes used throughout the notebook.
batch_size = 4
seq_len = 45
hidden_size = 4096
intermediate_size = 11008
num_attention_heads = 32
num_key_value_heads = 32
max_position_embeddings = 2048
rope_theta = 10000.0

# Everything the attention module reads, gathered in one dict. The derived
# entries (head_dim, num_key_value_groups) are computed from the primary sizes.
config = {
    "hidden_size": hidden_size,
    "intermediate_size": intermediate_size,
    "num_heads": num_attention_heads,
    "head_dim": hidden_size // num_attention_heads,
    "num_key_value_heads": num_key_value_heads,
    "num_key_value_groups": num_attention_heads // num_key_value_heads,
    "max_position_embeddings": max_position_embeddings,
    "rope_theta": rope_theta,
}

In [3]:
# Dummy inputs: all-ones activations and mask, plus one row of position
# indices 0 .. seq_len-1 per batch element.
data = torch.ones((batch_size, seq_len, hidden_size))
attention_mask = torch.ones((batch_size, 1, seq_len, seq_len))
position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)

## torch attention

In [4]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)``:
    the input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When ``n_rep`` is 1
    the input tensor itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a length-1 axis after the head axis and broadcast it to n_rep
    # (-1 keeps the existing size); reshape then folds it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(-1, -1, n_rep, -1, -1)
    return tiled.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

In [5]:
class LlamaRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding with cos/sin tables precomputed at construction.

    The tables cover positions ``0 .. max_position_embeddings - 1`` and are kept
    as non-persistent buffers ``cos_cached`` / ``sin_cached`` of shape
    ``[1, 1, max_position_embeddings, dim]``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # One inverse frequency per channel pair: base ** (-2i / dim).
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        # Outer product position x frequency -> rotation angle per (position, pair).
        positions = torch.arange(max_position_embeddings, device=device, dtype=self.inv_freq.dtype)
        angles = positions[:, None] * self.inv_freq[None, :]
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)

        default_dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", table.cos()[None, None, :, :].to(default_dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin()[None, None, :, :].to(default_dtype), persistent=False)

    def rotate_half(self, x):
        """Rotates half the hidden dims of the input."""
        half = x.shape[-1] // 2
        front, back = x[..., :half], x[..., half:]
        return torch.cat((-back, front), dim=-1)

    def _lookup(self, cached, position_ids, seq_len, dtype):
        """Slice a cached table to ``seq_len`` rows and gather the rows for ``position_ids``."""
        # The first two dimensions of the cached table are always 1, so index them away.
        rows = cached[0, 0, :seq_len].to(dtype=dtype)  # [seq_len, dim]
        return rows[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

    def forward(self, q, k, v, position_ids, seq_len=None):
        """Apply the rotation to ``q`` and ``k``; ``v`` is only consulted for its dtype.

        Returns the pair ``(q_embed, k_embed)``.
        """
        # v: [bs, num_attention_heads, seq_len, head_size]
        cos = self._lookup(self.cos_cached, position_ids, seq_len, v.dtype)
        sin = self._lookup(self.sin_cached, position_ids, seq_len, v.dtype)

        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed


In [6]:
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Projects the hidden states to queries/keys/values, applies rotary position
    embeddings to queries and keys, optionally extends keys/values with a cache,
    and returns the output-projected attention result.

    ``config`` must provide the keys ``hidden_size``, ``num_heads``,
    ``num_key_value_heads``, ``max_position_embeddings`` and ``rope_theta``.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_heads"]
        self.head_dim = config["hidden_size"] // config["num_heads"]
        self.num_key_value_heads = config["num_key_value_heads"]
        self.num_key_value_groups = config["num_heads"] // config["num_key_value_heads"]
        self.max_position_embeddings = config["max_position_embeddings"]
        self.rope_theta = config["rope_theta"]

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        """Create the rotary embedding module from the head size and RoPE settings."""
        print(
            "init rope",
            self.head_dim,
            self.max_position_embeddings,
            self.rope_theta,
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``[bsz, seq_len, hidden]`` to a contiguous ``[bsz, num_heads, seq_len, head_dim]``."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def load(self, dir):
        """Load this layer's attention weights from a checkpoint file.

        Args:
            dir: path of a ``torch.save``-d mapping whose attention entries are
                named ``<a>.<b>.<layer>.self_attn.<param path>`` (for example
                ``model.layers.18.self_attn.q_proj.weight``).

        Every key is printed as it is visited. Keys that do not belong to a
        ``self_attn`` sub-module are ignored, whatever their depth. The rotary
        ``inv_freq`` entry of any layer is printed and skipped, because the
        rotary buffers are non-persistent and therefore not part of this
        module's state dict.

        Raises:
            RuntimeError: from ``load_state_dict`` when the collected weights do
                not match this module's parameters.
        """
        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        weights = torch.load(dir)
        self_attn_weights = dict()
        for key in weights.keys():
            print(key)
            if key.endswith(".self_attn.rotary_emb.inv_freq"):
                print(weights[key])
                continue
            parts = key.split(".")
            # Guard on length so short keys (e.g. "lm_head.weight") are skipped
            # instead of raising IndexError.
            if len(parts) > 4 and parts[3] == "self_attn":
                self_attn_weights[".".join(parts[4:])] = weights[key]

        self.load_state_dict(self_attn_weights)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: ``[bsz, q_len, hidden_size]`` input activations.
            attention_mask: additive mask broadcastable to
                ``[bsz, num_heads, q_len, kv_seq_len]``; skipped when ``None``.
            position_ids: integer positions of the ``q_len`` new tokens, shape
                ``[bsz, q_len]``. When ``None`` the positions directly following
                the cached tokens are used.
            past_key_value: optional ``(key, value)`` cache, each of shape
                ``[bsz, num_key_value_heads, past_len, head_dim]``.
            output_attentions: return the attention probabilities when true.
            use_cache: return the updated ``(key, value)`` cache when true.

        Returns:
            ``(attn_output, attn_weights, past_key_value)`` where ``attn_output``
            is ``[bsz, q_len, hidden_size]``; ``attn_weights`` is ``None`` unless
            ``output_attentions``; ``past_key_value`` is ``None`` unless ``use_cache``.
        """
        # bsz = batch size; q_len = query length; _ = hidden size
        bsz, q_len, _ = hidden_states.size()

        # do projection
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # reshape to [bsz, heads, q_len, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total number of key/value positions once the cache is taken into account.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if position_ids is None:
            # New tokens occupy the positions right after the cached ones.
            position_ids = torch.arange(
                kv_seq_len - q_len, kv_seq_len, dtype=torch.long, device=hidden_states.device
            ).unsqueeze(0)

        # The cos/sin tables must cover every absolute position up to kv_seq_len.
        # Slicing them to q_len would make position_ids index out of range as
        # soon as a cache is supplied.
        query_states, key_states = self.rotary_emb(
            query_states, key_states, value_states, position_ids, seq_len=kv_seq_len
        )

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

## Test torch

In [7]:
# NOTE: machine-specific checkpoint location; adjust to wherever the shard
# lives locally.
CHECKPOINT_PATH = "/home/fuchiang137/.cache/huggingface/hub/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348/pytorch_model-00019-of-00033.bin"

model = LlamaAttention(config)

device = torch.device("cuda")

model.load(CHECKPOINT_PATH)

model = model.to(device).eval()

data_D = data.to(device)
attention_mask_D = attention_mask.to(device)
position_ids_D = position_ids.to(device)

past_key_value = None

# attention_mask
#   added to the raw attention scores before the softmax
# position_ids
#   specifies the position id of the corresponding hidden state tensor element
#   e.g. hid = [3, 4, 6] => pos_id = [0, 1, 2]
# past_key_value
#   if use cache, past key value will contain past kv values
# Inference only: skip autograd bookkeeping so no graph is kept alive.
with torch.no_grad():
    output = model(hidden_states=data_D,
                   attention_mask=attention_mask_D,
                   position_ids=position_ids_D,
                   past_key_value=past_key_value,
                   output_attentions=False,
                   use_cache=True,)

print(output)


init rope 128 2048 10000.0
model.layers.18.self_attn.q_proj.weight
model.layers.18.self_attn.k_proj.weight
model.layers.18.self_attn.v_proj.weight
model.layers.18.self_attn.o_proj.weight
model.layers.18.mlp.gate_proj.weight
model.layers.18.mlp.down_proj.weight
model.layers.18.mlp.up_proj.weight
model.layers.18.input_layernorm.weight
model.layers.18.post_attention_layernorm.weight
model.layers.18.self_attn.rotary_emb.inv_freq
tensor([1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
        4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
        1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
        7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
        3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
        1.3335e-02, 1.1548e-02, 1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03,
        5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03, 3.1623e-03, 2.7384e-03,
        2.3

## Breaking down LlamaRotaryEmbedding

## TensorRT reshape (as an approximation of squeeze or unsqueeze)

In [8]:
# seq length is not specified, since it is a dynamic size
def trt_create(batch_size, hidden_size, intermediate_size,
               input_shape=(1, 24), reshape_dims=(1, 2, 3, 4), permutation=(0, 2, 1, 3)):
    """Build a serialized TensorRT engine holding a single reshape + transpose.

    The network takes one FLOAT input named ``inputT0``, reshapes it to
    ``reshape_dims`` with a shuffle layer and then permutes the axes with
    ``permutation`` (the shuffle layer's second transpose).

    Args:
        batch_size, hidden_size, intermediate_size: accepted for signature
            compatibility; not used when building this network.
        input_shape: shape of ``inputT0``.
        reshape_dims: target dimensions of the reshape.
        permutation: axis order applied after the reshape; must be a
            permutation of ``range(len(reshape_dims))``.

    Returns:
        The serialized engine returned by ``build_serialized_network``.

    Raises:
        ValueError: if ``permutation`` is not a valid axis permutation.
        RuntimeError: if TensorRT fails to build the engine.
    """
    if sorted(permutation) != list(range(len(reshape_dims))):
        raise ValueError(
            f"permutation {tuple(permutation)} is not a permutation of the "
            f"{len(reshape_dims)} reshaped axes"
        )

    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    builder_config = builder.create_builder_config()

    # input
    inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, tuple(input_shape))

    # reshape first, then permute the reshaped axes
    shuffle_layer = network.add_shuffle(inputT0)
    shuffle_layer.reshape_dims = trt.Dims(list(reshape_dims))
    shuffle_layer.second_transpose = trt.Permutation(list(permutation))

    # output
    network.mark_output(shuffle_layer.get_output(0))

    engineString = builder.build_serialized_network(network, builder_config)
    # A failed build yields None; surface it here instead of at deserialization.
    if engineString is None:
        raise RuntimeError("TensorRT failed to build the serialized network.")

    return engineString

In [9]:
# Build the serialized engine once so the inference cells below can reuse it.
trt_engineStr = trt_create(batch_size, hidden_size, intermediate_size)

In [10]:
def trt_inference(batch_size, hidden_size, engineString, raw_data):
    """Run a one-input, one-output TensorRT engine on ``raw_data``.

    Args:
        batch_size, hidden_size: accepted for signature compatibility; not used.
        engineString: serialized engine to deserialize and execute.
        raw_data: host data convertible with ``np.array``. It is cast to the
            engine's input dtype and flattened before upload.

    Returns:
        A NumPy array with the shape and dtype of the engine's binding 1.

    Raises:
        RuntimeError: if the engine cannot be deserialized, a CUDA runtime call
            reports an error, or execution fails.
        ValueError: if ``raw_data`` does not hold as many elements as the
            engine's (fully specified) input shape.
    """

    def _cuda_call(result):
        # cudart functions return a tuple (error_code, *values); unwrap and check it.
        err, *values = result
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"CUDA runtime call failed: {err}")
        if not values:
            return None
        return values[0] if len(values) == 1 else tuple(values)

    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    if engine is None:
        raise RuntimeError("Failed to deserialize the TensorRT engine.")
    context = engine.create_execution_context()

    # Binding 0 is the input, binding 1 the output. Cast the host data to the
    # dtype the engine expects so the raw byte copy below is well-defined.
    input_dtype = trt.nptype(engine.get_binding_dtype(0))
    data = np.array(raw_data).astype(input_dtype, copy=False)
    inputH0 = np.ascontiguousarray(data.reshape(-1))

    input_shape = tuple(context.get_binding_shape(0))
    if all(dim >= 0 for dim in input_shape) and inputH0.size != int(np.prod(input_shape)):
        raise ValueError(
            f"raw_data has {inputH0.size} elements but the engine input shape "
            f"{input_shape} requires {int(np.prod(input_shape))}"
        )

    outputH0 = np.empty(tuple(context.get_binding_shape(1)), dtype=trt.nptype(engine.get_binding_dtype(1)))

    stream = _cuda_call(cudart.cudaStreamCreate())
    inputD0 = None
    outputD0 = None
    try:
        # allocate device buffers for input and output
        inputD0 = _cuda_call(cudart.cudaMallocAsync(inputH0.nbytes, stream))
        outputD0 = _cuda_call(cudart.cudaMallocAsync(outputH0.nbytes, stream))

        # move input to device
        _cuda_call(cudart.cudaMemcpyAsync(
            inputD0, inputH0.ctypes.data, inputH0.nbytes,
            cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream))

        # execute
        if not context.execute_async_v2([int(inputD0), int(outputD0)], stream):
            raise RuntimeError("TensorRT execution failed.")

        # move output back to host
        _cuda_call(cudart.cudaMemcpyAsync(
            outputH0.ctypes.data, outputD0, outputH0.nbytes,
            cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream))

        # wait for everything queued on the stream
        _cuda_call(cudart.cudaStreamSynchronize(stream))
    finally:
        # Best-effort cleanup, also on failure: free buffers, then the stream.
        for device_ptr in (inputD0, outputD0):
            if device_ptr is not None:
                cudart.cudaFree(device_ptr)
        cudart.cudaStreamDestroy(stream)

    return outputH0

In [11]:
# Torch reference for the reshape + transpose: start from a flat (1, 24) tensor,
# then show the tensor and its shape after each step.
query_states = torch.arange(24, dtype=torch.float).reshape(1, 24)
print(query_states.shape)

for step in (
    lambda t: t.view(1, 2, 3, 4),   # split the 24 values into (2, 3, 4)
    lambda t: t.transpose(1, 2),    # swap axes 1 and 2 -> (1, 3, 2, 4)
):
    query_states = step(query_states)
    print(query_states)
    print(query_states.shape)


torch.Size([1, 24])
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],

         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.]]]])
torch.Size([1, 2, 3, 4])
tensor([[[[ 0.,  1.,  2.,  3.],
          [12., 13., 14., 15.]],

         [[ 4.,  5.,  6.,  7.],
          [16., 17., 18., 19.]],

         [[ 8.,  9., 10., 11.],
          [20., 21., 22., 23.]]]])
torch.Size([1, 3, 2, 4])


In [12]:
query_states = torch.arange(24, dtype=torch.float).reshape(1, 24)
print(query_states)

trt_output = trt_inference(batch_size, hidden_size, trt_engineStr, query_states)

print("output_trt :", trt_output.shape)
print(trt_output)

# Verify instead of eyeballing: the engine must reproduce torch's
# view(1, 2, 3, 4) followed by transpose(1, 2) exactly (pure data movement).
torch_reference = query_states.view(1, 2, 3, 4).transpose(1, 2).numpy()
np.testing.assert_array_equal(trt_output, torch_reference)

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15., 16., 17., 18., 19., 20., 21., 22., 23.]])
output_trt : (1, 3, 2, 4)
[[[[ 0.  1.  2.  3.]
   [12. 13. 14. 15.]]

  [[ 4.  5.  6.  7.]
   [16. 17. 18. 19.]]

  [[ 8.  9. 10. 11.]
   [20. 21. 22. 23.]]]]


  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
