# 1-Integrate Prebuilt Kernel
This notebook demonstrates how to integrate a prebuilt NKI kernel into our code and compares the result with an existing PyTorch implementation.

> Special thanks to Hanno Bever (bevhanno@amazon.de) for providing the custom attention code

## Available NKI Kernels
The source code of the kernels in the `neuronxcc.nki.kernels` namespace is available in the GitHub repository [nki-samples](https://github.com/aws-neuron/nki-samples). They are optimized kernels from the Neuron team that serve as samples. The repository also contains numeric tests, performance benchmarks, and scripts for using the kernels in real models.

You are welcome to customize them to fit your own workloads and to contribute to the repository by opening a PR. Note that these kernels are already deployed as part of the Neuron stack.

In this notebook we're going to use the [nki.kernels.flash_fwd](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.kernels.flash_fwd.html) kernel.

In [None]:
from neuronxcc import nki
from neuronxcc.nki.kernels import flash_fwd as FlashAttentionForward

_flash_fwd_call = nki.jit()(FlashAttentionForward)
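
The kernel source is not reproduced in this notebook. As rough intuition for what a flash-attention forward pass computes, the sketch below compares a naive single-head attention with a tiled variant that visits keys and values one tile at a time and keeps a running softmax. It is plain Python for illustration only and is not the NKI implementation. The function names and the tiny inputs are made up for this example, and `tile_size` is only presumably analogous to the `seq_tile_size` setting used later in this notebook.

```python
import math


def naive_attention(q, k, v):
    """Reference for one head: softmax(q @ k^T / sqrt(d)) @ v on nested lists."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append(
            [sum(w * vj[c] for w, vj in zip(weights, v)) / total for c in range(len(v[0]))]
        )
    return out


def tiled_attention(q, k, v, tile_size):
    """Same result, but keys/values are processed tile by tile (online softmax)."""
    d = len(q[0])
    out = []
    for qi in q:
        running_max = -math.inf
        running_sum = 0.0
        acc = [0.0] * len(v[0])
        for start in range(0, len(k), tile_size):
            k_tile = k[start:start + tile_size]
            v_tile = v[start:start + tile_size]
            scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k_tile]
            new_max = max(running_max, max(scores))
            # Rescale everything accumulated under the previous maximum.
            scale = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * scale + sum(weights)
            acc = [
                a * scale + sum(w * vj[c] for w, vj in zip(weights, v_tile))
                for c, a in enumerate(acc)
            ]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out


# Hypothetical toy inputs: 2 queries, 4 keys/values, head dimension 2.
q = [[0.1, 0.2], [0.3, -0.4]]
k = [[0.5, 0.1], [-0.2, 0.7], [0.9, -0.3], [0.0, 0.4]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

reference = naive_attention(q, k, v)
tiled = tiled_attention(q, k, v, tile_size=2)
max_diff = max(abs(a - b) for ra, rb in zip(reference, tiled) for a, b in zip(ra, rb))
print(max_diff)
```

The tiled variant never materializes the full score row, which is the property that lets a flash-attention kernel handle long sequences within limited on-chip memory.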

To compare the NKI kernel with a standard attention implementation, we're going to use the following `PyTorch` module, which wraps `nn.MultiheadAttention`:

In [None]:
from torch import nn

class AttentionOrginal(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dropout=0.0):
        super().__init__()
        context_dim = query_dim if context_dim is None else context_dim
        self.mha = nn.MultiheadAttention(
            embed_dim=query_dim,
            num_heads=heads,
            kdim=context_dim,
            vdim=context_dim,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, x, context=None, mask=None):
        context = x if context is None else context
        out = self.mha(x, context, context, need_weights=False)
        return out[0]

Integrating the NKI kernel into this module looks as follows:

In [None]:
from typing import List, Optional
import torch

from torch.nn.functional import linear
from neuronxcc.nki.kernels.attention import FlashConfig


class AttentionNki(nn.Module):
    def __init__(self, mha: nn.MultiheadAttention):
        super().__init__()
        self.mha = mha

    def _in_projection_packed(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        w: torch.Tensor,
        b: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:
        r"""Perform the in-projection step of the attention operation, using packed weights.

        Output is a triple containing projection tensors for query, key and value.

        Args:
            q, k, v: query, key and value tensors to be projected. For self-attention,
                these are typically the same tensor; for encoder-decoder attention,
                k and v are typically the same tensor. (We take advantage of these
                identities for performance if they are present.) Regardless, q, k and v
                must share a common embedding dimension; otherwise their shapes may vary.
            w: projection weights for q, k and v, packed into a single tensor. Weights
                are packed along dimension 0, in q, k, v order.
            b: optional projection biases for q, k and v, packed into a single tensor
                in q, k, v order.

        Shape:
            Inputs:
            - q: :math:`(..., E)` where E is the embedding dimension
            - k: :math:`(..., E)` where E is the embedding dimension
            - v: :math:`(..., E)` where E is the embedding dimension
            - w: :math:`(E * 3, E)` where E is the embedding dimension
            - b: :math:`E * 3` where E is the embedding dimension

            Output:
            - in output list :math:`[q', k', v']`, each output tensor will have the
                same shape as the corresponding input tensor.
        """
        E = q.size(-1)
        if k is v:
            if q is k:
                # self-attention
                proj = linear(q, w, b)
                # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
                proj = (
                    proj.unflatten(-1, (3, E))
                    .unsqueeze(0)
                    .transpose(0, -2)
                    .squeeze(-2)
                    .contiguous()
                )
                return proj[0], proj[1], proj[2]
            else:
                # encoder-decoder attention
                w_q, w_kv = w.split([E, E * 2])
                if b is None:
                    b_q = b_kv = None
                else:
                    b_q, b_kv = b.split([E, E * 2])
                q_proj = linear(q, w_q, b_q)
                kv_proj = linear(k, w_kv, b_kv)
                # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
                kv_proj = (
                    kv_proj.unflatten(-1, (2, E))
                    .unsqueeze(0)
                    .transpose(0, -2)
                    .squeeze(-2)
                    .contiguous()
                )
                return (q_proj, kv_proj[0], kv_proj[1])
        else:
            w_q, w_k, w_v = w.chunk(3)
            if b is None:
                b_q = b_k = b_v = None
            else:
                b_q, b_k, b_v = b.chunk(3)
            return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

    def forward(self, x, context=None, mask=None):
        context = x if context is None else context

        query, key, value = x, context, context

        if self.mha.batch_first:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        # set up shape vars
        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape

        head_dim = self.mha.head_dim
        num_heads = self.mha.num_heads
        assert head_dim * self.mha.num_heads == embed_dim
        assert key.shape == value.shape

        q, k, v = self._in_projection_packed(
            query, key, value, self.mha.in_proj_weight, self.mha.in_proj_bias
        )

        # reshape q, k, v for multihead attention and make them batch first
        #
        q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

        # update source sequence length after adjustments
        src_len = k.size(1)

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)

        # *************************************************************************************************
        # NKI Kernel replacement START
        # *************************************************************************************************

        q = q.permute(0, 1, 3, 2)
        k = k.permute(0, 1, 3, 2)
        v = v.permute(0, 1, 3, 2)

        config = FlashConfig(
            **{"seq_tile_size": 2048, "training": False, "should_transpose_v": True}
        )
        attn_output = _flash_fwd_call[bsz, self.mha.num_heads](
            q, k, v, seed=None, logit_bias=None, use_causal_mask=False, config=config,
        )
        
        attn_output = attn_output.reshape(bsz, num_heads, tgt_len, head_dim)

        # *************************************************************************************************
        # NKI Kernel replacement END
        # *************************************************************************************************

        attn_output = (
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )

        attn_output = torch.nn.functional.linear(
            attn_output, self.mha.out_proj.weight, self.mha.out_proj.bias
        )

        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        if self.mha.batch_first:
            return attn_output.transpose(1, 0)
        # Without batch_first the output stays in (tgt_len, bsz, embed_dim) layout.
        return attn_output
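
The `forward` method above is mostly shape bookkeeping around the kernel call. The sketch below traces those shapes with plain tuples, following the `view`, `permute` and `reshape` calls in the cell. `attention_shapes` is a hypothetical helper written for this walkthrough and is not part of PyTorch or NKI. The kernel input and output layouts are read off the code above, not from the kernel documentation.

```python
def attention_shapes(bsz, tgt_len, src_len, embed_dim, num_heads):
    """Trace tensor shapes through AttentionNki.forward (shapes only, no tensor data)."""
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    head_dim = embed_dim // num_heads
    return {
        # After the transpose for batch_first and the packed in-projection.
        "q_projected": (tgt_len, bsz, embed_dim),
        "k_projected": (src_len, bsz, embed_dim),
        # After splitting the embedding into heads and making batch the first axis.
        "q_per_head": (bsz, num_heads, tgt_len, head_dim),
        "k_per_head": (bsz, num_heads, src_len, head_dim),
        # After permute(0, 1, 3, 2): the layout handed to the kernel in the cell above.
        "q_kernel_input": (bsz, num_heads, head_dim, tgt_len),
        "k_kernel_input": (bsz, num_heads, head_dim, src_len),
        "v_kernel_input": (bsz, num_heads, head_dim, src_len),
        # The kernel output is reshaped to this layout before the output projection.
        "kernel_output": (bsz, num_heads, tgt_len, head_dim),
        # Final result when batch_first is True.
        "module_output": (bsz, tgt_len, embed_dim),
    }


# Values used later in this notebook: batch 1, sequence length 8192, 256 dims, 4 heads.
shapes = attention_shapes(bsz=1, tgt_len=8192, src_len=8192, embed_dim=256, num_heads=4)
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

With these values each head has dimension 64, and the sequence length of 8192 is a multiple of the `seq_tile_size` of 2048 passed to `FlashConfig`.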

## Creating the modules
Both the `AttentionOrginal` module and the `AttentionNki` module use `query_dim=256`, `heads=4`, `context_dim=256` and `dropout=0.0`.

In [None]:
# MHA: 256 4 256 256 0.0
mha_module_org = AttentionOrginal(256, heads=4, context_dim=256, dropout=0.0)
mha_module_nki = AttentionNki(mha_module_org.mha)

## Compiling the NKI module

Using `torch-neuronx`, there are two ways that a model can be executed for inference:

- **XLA LazyTensor Inference**: A model is executed on Neuron by calling `to()` to move `Parameter` and `Tensor` data to the device returned by `xm.xla_device()`. Operations are then executed through torch [Lazy Tensor](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/inference/trace-vs-xla-lazytensor.html#xla-lazytensor), which records, compiles, and executes the graph.

- **(Recommended) Traced Inference**: A model is traced prior to inference using the [trace()](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-trace.html#torch_neuronx.trace) API. This trace is similar to `torch.jit.trace()` but instead creates a Neuron-specific [TorchScript](https://pytorch.org/docs/stable/jit.html) artifact. This artifact provides improved performance and portability compared to XLA [Lazy Tensor](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/inference/trace-vs-xla-lazytensor.html#xla-lazytensor) inference.

To learn more about both methods, see [Comparison of Traced Inference versus XLA Lazy Tensor Inference (torch-neuronx)](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/inference/trace-vs-xla-lazytensor.html).


In [None]:
import os
import torch_neuronx

os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = ""
os.environ["NEURON_TRANSFER_WITH_STATIC_RING_OPS"] = ""

os.environ["NEURON_RT_NUM_CORES"] = "1"
os.environ["NEURON_CC_FLAGS"] = (
    "--log_level=INFO --cache_dir=./neuron_cache --model-type=generic -O1"
)
os.environ["XLA_USE_BF16"] = "1"

COMPILER_WORKDIR_ROOT = "compiler"

x = torch.randn([1, 8192, 256])
context = torch.randn([1, 8192, 256])

example_inputs = x, context

mha_compiled_neuron = torch_neuronx.trace(
    mha_module_nki,
    example_inputs,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, "MHA"),
    compiler_args=["--model-type", "transformer", "--auto-cast", "all", "--auto-cast-type", "bf16", "--optlevel", "1"],
)

## Looking at the results
When running both examples, the resulting output tensors are almost identical. Any differences can be explained by the different data types used (`float32` for the original module, `bfloat16` for the compiled NKI module) and are insignificant if they occur at all.

In [None]:
DTYPE = torch.float

x = torch.randn([1, 8192, 256], dtype=DTYPE)
context = torch.randn_like(x)

org_result = mha_module_org(x, context=context)
print(org_result.shape)
print(org_result)

In [None]:
DTYPE = torch.bfloat16

x = x.bfloat16()
context = context.bfloat16()

neuron_result = mha_compiled_neuron(x, context)
print(neuron_result.shape)
print(neuron_result.float())
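
To get a feel for how large a data-type-related difference can be, the sketch below rounds a few Python floats to the nearest `bfloat16` value using only the standard library and reports the worst relative error. `to_bfloat16` is a hypothetical helper written for this illustration and the sample values are arbitrary. It models only the rounding of individual numbers, not the error accumulated across the matrix multiplications inside attention.

```python
import struct


def to_bfloat16(value):
    """Round a finite Python float to the nearest bfloat16 value, returned as a float."""
    # bfloat16 keeps the upper 16 bits of the float32 bit pattern.
    bits = struct.unpack(">I", struct.pack(">f", value))[0]
    # Round to nearest, ties to even, on the 16 low bits that are dropped.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack(">f", struct.pack(">I", bits))[0]


# Arbitrary sample values for illustration.
values = [0.1, -1.2345, 3.14159, 100.0, 1e-3]
rounded = [to_bfloat16(v) for v in values]
max_rel_error = max(abs(r - v) / abs(v) for v, r in zip(values, rounded))

for v, r in zip(values, rounded):
    print(f"{v!r} -> {r!r}")
print("max relative error:", max_rel_error)
```

`bfloat16` stores 8 significant bits, so a single rounding changes a value by a relative amount of at most about 2^-8 (roughly 0.4%). In this notebook, one way to quantify the actual gap would be `(org_result - neuron_result.float()).abs().max()`.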

## Release the NeuronCore for the next notebook

Before moving to the next notebook, we need to release the NeuronCore. Otherwise the next notebook will not be able to use it. Alternatively, you can stop the kernel via the GUI.

> When running the command in the next cell the notebook will give an error - this is to be expected.

In [None]:
import IPython
IPython.Application.instance().kernel.do_shutdown(True)