# 0 - setup
In this guide, we implement a simple “Hello World” style NKI kernel and run it on a NeuronDevice (Trainium, Inferentia2, or a later device). We show how to invoke an NKI kernel standalone through NKI baremetal mode and through an ML framework (PyTorch). Before diving into the kernel implementation, let’s make sure your environment is set up correctly for running NKI kernels.

## Environment Setup
You need a [Trn1](https://aws.amazon.com/ec2/instance-types/trn1/) or [Inf2](https://aws.amazon.com/ec2/instance-types/inf2/) instance set up on AWS to run NKI kernels on a NeuronDevice. Once logged into the instance, follow the steps below to ensure that all required packages are installed in your Python environment.

NKI is shipped as part of the Neuron compiler package. To make sure you have the latest compiler package, follow the installation instructions in the [Setup Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/index.html).

You can verify that NKI is available in your compiler installation by running the following command:

In [None]:
import neuronxcc.nki

This attempts to import the NKI package. It will error out if NKI is not included in your Neuron compiler version or if the Neuron compiler is not installed. The import might take about a minute the first time you run it. Whenever possible, we recommend using local instance NVMe volumes instead of EBS for executable code.

If you intend to run NKI kernels without any ML framework for quick prototyping, you will also need NumPy installed.

To call NKI kernels from PyTorch, you also need `torch_neuronx` installed. For an installation guide, see PyTorch Neuron Setup. You can verify that `torch_neuronx` is installed by running the following command:

In [None]:
import torch_neuronx
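
The two import checks above can take a while on first run. As a lighter pre-check, the following minimal sketch looks up the packages named in this guide without importing them. The helper name `find_missing_packages` and the `REQUIRED_PACKAGES` table are illustrative assumptions, not part of the Neuron SDK.

```python
import importlib.util

# Top-level package names taken from this guide; the descriptions are informal.
REQUIRED_PACKAGES = {
    "neuronxcc": "Neuron compiler (ships NKI)",
    "numpy": "needed for NKI baremetal mode",
    "torch_neuronx": "needed to call NKI kernels from PyTorch",
}


def find_missing_packages(names):
    """Return the names that cannot be located, without importing them."""
    missing = []
    for name in names:
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            # Treat lookup failures the same as "not installed".
            spec = None
        if spec is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    missing = find_missing_packages(REQUIRED_PACKAGES)
    for name, role in REQUIRED_PACKAGES.items():
        status = "MISSING" if name in missing else "found"
        print(f"{name:15s} {status:8s} {role}")
```

Locating a package does not guarantee that `import neuronxcc.nki` succeeds, so the import cells above remain the definitive check.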

## Implementing your first NKI kernel
In the current NKI release, all input and output tensors must be passed into the kernel as device memory (HBM) tensors on a NeuronDevice. The body of the kernel typically consists of three main phases:

1. Load the inputs from device memory to on-chip memory (SBUF).
2. Perform the desired computation.
3. Store the outputs from on-chip memory to device memory.

For more details on the above terms, see [NKI Programming Model](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/programming_model.html).

In [None]:
from neuronxcc import nki
import neuronxcc.nki.language as nl

@nki.jit
def nki_tensor_add_kernel(a_input, b_input):
    """
    NKI kernel to compute element-wise addition of two input tensors
    """
    
    # Check all input/output tensor shapes are the same for element-wise operation
    assert a_input.shape == b_input.shape

    # Check size of the first dimension does not exceed on-chip memory tile size limit,
    # so that we don't need to tile the input to keep this example simple
    assert a_input.shape[0] <= nl.tile_size.pmax

    # Load the inputs from device memory to on-chip memory
    a_tile = nl.load(a_input)
    b_tile = nl.load(b_input)

    # Specify the computation (in our case: a + b)
    c_tile = nl.add(a_tile, b_tile)

    # Create a HBM tensor as the kernel output
    c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

    # Store the result to c_output from on-chip memory to device memory
    nl.store(c_output, value=c_tile)

    # Return kernel output as function output
    return c_output
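
The two `assert` statements in the kernel fire only once the kernel is traced. The following sketch mirrors them in plain Python so that shapes can be validated on the host first. The helper `check_tensor_add_shapes` is hypothetical, and `ASSUMED_PMAX = 128` is an assumed value for `nl.tile_size.pmax`; read the real value from `neuronxcc.nki.language` on your instance.

```python
# Assumed value of nl.tile_size.pmax (the limit on the first dimension of a tile).
ASSUMED_PMAX = 128


def check_tensor_add_shapes(a_shape, b_shape, pmax=ASSUMED_PMAX):
    """Raise ValueError if the shapes would trip the kernel's assertions."""
    a_shape, b_shape = tuple(a_shape), tuple(b_shape)
    if a_shape != b_shape:
        raise ValueError(f"shape mismatch: {a_shape} vs {b_shape}")
    if not a_shape:
        raise ValueError("expected tensors with at least one dimension")
    if a_shape[0] > pmax:
        raise ValueError(
            f"first dimension {a_shape[0]} exceeds the tile limit {pmax}; "
            "the input would need to be tiled"
        )
    return a_shape


if __name__ == "__main__":
    print(check_tensor_add_shapes((4, 3), (4, 3)))
```

Raising `ValueError` with a message, rather than using a bare `assert`, is a design choice of this sketch: it reports which of the two conditions failed.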

## Running the kernel
Next, we cover three ways to run the above NKI kernel on a NeuronDevice:

1. NKI baremetal: run NKI kernel with no ML framework involvement
2. PyTorch: run NKI kernel as a PyTorch operator
3. JAX: run NKI kernel as a JAX operator (not used in this workshop)

All three run modes can call the same kernel function decorated with the `nki.jit` decorator as discussed above:
```
@nki.jit
def nki_tensor_add_kernel(a_input, b_input):
```
The `nki.jit` decorator automatically chooses the correct run mode by checking the incoming tensor type:

1. NumPy arrays as input: run in NKI baremetal mode
2. PyTorch tensors as input: run in PyTorch mode
3. JAX tensors: run in JAX mode

See [nki.jit](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.jit.html) API doc for more details.
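
To make the type-based selection concrete, here is a minimal sketch of how a dispatcher could map an input tensor to a run mode. It is an illustration only and not how `nki.jit` is implemented: the function `guess_run_mode` is hypothetical, and the `'jax'` mode string is an assumption (only `'baremetal'` and `'torchxla'` appear in this guide).

```python
def guess_run_mode(tensor):
    """Guess a run mode from the top-level module that defines the tensor's type."""
    top_level = type(tensor).__module__.split(".")[0]
    modes = {
        "numpy": "baremetal",
        "torch": "torchxla",
        "jax": "jax",     # assumed mode name
        "jaxlib": "jax",  # JAX array types may be defined in jaxlib
    }
    try:
        return modes[top_level]
    except KeyError:
        raise TypeError(
            f"unsupported tensor type: {type(tensor).__name__}"
        ) from None
```

Inspecting `type(tensor).__module__` keeps the sketch free of imports, so it runs even when only one of the frameworks is installed.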

### NKI baremetal

Baremetal mode expects the input tensors of the NKI kernel to be NumPy arrays, and the kernel converts its NKI output tensors back to NumPy arrays. To invoke the kernel, we first initialize the two input tensors `a` and `b` as NumPy arrays. We then call the NKI kernel just like any other Python function:

In [None]:
import numpy as np

a = np.ones((4, 3), dtype=np.float16)
b = np.ones((4, 3), dtype=np.float16)

# Run NKI kernel on a NeuronDevice
c = nki_tensor_add_kernel(a, b)

print(c)

> Alternatively, we can decorate the kernel with [nki.baremetal](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.baremetal.html) or pass the `mode` parameter to the `nki.jit` decorator, `@nki.jit(mode='baremetal')`, to bypass the dynamic mode detection. See [nki.baremetal](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.baremetal.html) API doc for more available input arguments for the baremetal mode.

### PyTorch

To run the above `nki_tensor_add_kernel` kernel using PyTorch, we initialize the input tensors as PyTorch `device` tensors instead.

In [None]:
import torch
from torch_xla.core import xla_model as xm

nki_tensor_add_kernel_pytorch = nki.jit(nki_tensor_add_kernel, mode='torchxla')

device = xm.xla_device()

a = torch.ones((4, 3), dtype=torch.float16).to(device=device)
b = torch.ones((4, 3), dtype=torch.float16).to(device=device)

c = nki_tensor_add_kernel_pytorch(a, b)

print(c)  # an implicit XLA barrier/mark-step (triggers XLA compilation)

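> The cell above selects the mode by calling `nki.jit` as a function with `mode='torchxla'`. Alternatively, we can pass the same parameter in the decorator form, `@nki.jit(mode='torchxla')`, to bypass the dynamic mode detection.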
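
Both run modes should print a 4x3 tensor of twos. Rather than checking the output by eye, the following framework-agnostic sketch compares a result against the element-wise sum of its inputs using nested Python lists. The helper `matches_elementwise_sum` and its tolerance are illustrative assumptions; the tolerance is chosen loosely with `float16` precision in mind.

```python
import math


def matches_elementwise_sum(result, a, b, tol=1e-3):
    """Return True if every result[i][j] is close to a[i][j] + b[i][j]."""
    if not (len(result) == len(a) == len(b)):
        return False
    for r_row, a_row, b_row in zip(result, a, b):
        if not (len(r_row) == len(a_row) == len(b_row)):
            return False
        for r, x, y in zip(r_row, a_row, b_row):
            if not math.isclose(r, x + y, rel_tol=tol, abs_tol=tol):
                return False
    return True


# Stand-in data with the same shape and values as the examples above.
ones = [[1.0] * 3 for _ in range(4)]
twos = [[2.0] * 3 for _ in range(4)]
print(matches_elementwise_sum(twos, ones, ones))  # True
```

With the NumPy arrays from the baremetal example, the call would be `matches_elementwise_sum(c.tolist(), a.tolist(), b.tolist())`; with the PyTorch tensors, move each tensor to the host first, for example `c.cpu().tolist()`.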

## Release the NeuronCore for the next notebook

Before moving to the next notebook, we need to release the NeuronCore. Otherwise, the next notebook will not be able to use it. You can also stop the kernel via the GUI.

> Running the command in the next cell causes the notebook to report an error. This is expected.

In [None]:
import IPython
IPython.Application.instance().kernel.do_shutdown(True)