# 2-Custom Operators
This notebook demonstrates how to insert a NKI kernel as a custom operator into a PyTorch model.

## Using NKI kernels
To use a NKI kernel as a custom operator, you call a NKI function decorated with `nki.jit`.

Let’s examine a guiding example below where we randomly initialize two inputs, add them together, and then multiply the result by the two input tensors element-wise. This effectively calculates: `a * b * (a + b)`.

We define a common NKI kernel for addition. For more information on the kernel, see [SPMD Tensor Addition](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/tutorials/spmd_tensor_addition.html).

In [None]:
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl

@nki.jit
def nki_tensor_add_kernel_(a_input, b_input):
  """NKI kernel to compute element-wise addition of two input tensors
  
  This kernel assumes strict input/output sizes can be uniformly tiled to [128,512]

  Args:
      a_input: a first input tensor
      b_input: a second input tensor

  Returns:
      c_output: an output tensor
  """

  # Create output tensor shared between all SPMD instances as result tensor
  c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

  # Calculate tile offsets based on current 'program'
  offset_i_x = nl.program_id(0) * 128
  offset_i_y = nl.program_id(1) * 512

  # Generate tensor indices to index tensors a and b
  ix_, iy_ = nl.mgrid[0:128, 0:512]
  ix = offset_i_x + ix_
  iy = offset_i_y + iy_

  # Load input data from device memory (HBM) to on-chip memory (SBUF)
  # We refer to an indexed portion of a tensor as an intermediate tensor
  a_tile = nl.load(a_input[ix, iy])
  b_tile = nl.load(b_input[ix, iy])

  # compute a + b
  c_tile = a_tile + b_tile

  # store the addition results back to device memory (c_output)
  nl.store(c_output[ix, iy], value=c_tile)

  # Transfer the ownership of `c_output` to the caller
  return c_output
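
As a mental model for the tiling logic above, here is a minimal plain-Python sketch of how a 2D launch grid maps program IDs to tile offsets. It is not NKI code: `emulate_spmd_add` is a hypothetical helper, nested lists stand in for tensors, and the sequential loops only imitate kernel instances that run independently on the device. The tile size is a parameter so the example can use a tiny 2x3 tile instead of 128x512.

```python
def emulate_spmd_add(a, b, tile_rows, tile_cols):
    """Plain-Python stand-in for launching an add kernel on a 2D grid."""
    rows, cols = len(a), len(a[0])
    # Same assumption as the kernel: the shape tiles evenly.
    assert rows % tile_rows == 0 and cols % tile_cols == 0
    grid_x, grid_y = rows // tile_rows, cols // tile_cols

    # Shared output, playing the role of c_output in the kernel.
    c = [[None] * cols for _ in range(rows)]

    # Each (pid_x, pid_y) pair plays the role of one kernel instance.
    for pid_x in range(grid_x):
        for pid_y in range(grid_y):
            offset_x = pid_x * tile_rows  # nl.program_id(0) * 128 in the kernel
            offset_y = pid_y * tile_cols  # nl.program_id(1) * 512 in the kernel
            for ix in range(offset_x, offset_x + tile_rows):
                for iy in range(offset_y, offset_y + tile_cols):
                    c[ix][iy] = a[ix][iy] + b[ix][iy]
    return c


# Tiny 4x6 inputs with a 2x3 tile, so the emulated grid is 2x2.
a = [[r * 10 + k for k in range(6)] for r in range(4)]
b = [[1] * 6 for _ in range(4)]
c = emulate_spmd_add(a, b, tile_rows=2, tile_cols=3)
```

Every output element is written by exactly one emulated instance, which is why the kernel can share a single output tensor across the grid.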

## PyTorch
We can perform `(a + b) * a * b` using native PyTorch code.

In [None]:
import torch
from torch_xla.core import xla_model as xm

device = xm.xla_device()

a = torch.randn(256, 1024, dtype=torch.float32).to(device)
b = torch.randn(256, 1024, dtype=torch.float32).to(device)
c = a + b
out = a * b * c

print(out)

Now let’s replace the tensor addition (`c = a + b`) with a NKI kernel. To do this, we replace the `+` operator with a call to the NKI kernel caller (`nki_tensor_add`); everything else works as before.

In [None]:
def nki_tensor_add(a_input, b_input):
  """NKI kernel caller to compute element-wise addition of two input tensors

  This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs

  Args:
      a_input: a first input tensor, of shape [N*128, M*512]
      b_input: a second input tensor, of shape [N*128, M*512]

  Returns:
      a tensor of shape [N*128, M*512], the result of a_input + b_input
  """

  # The SPMD launch grid denotes the number of kernel instances.
  # In this case, we use a 2D grid where the size of each invocation is 128x512
  grid_x = a_input.shape[0] // 128
  grid_y = a_input.shape[1] // 512

  return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)

device = xm.xla_device()
a = torch.randn(256, 1024, dtype=torch.float32).to(device)
b = torch.randn(256, 1024, dtype=torch.float32).to(device)
c = nki_tensor_add(a, b) # calling a NKI kernel, instead of the built-in torch op
out = a * b * c
print(out)
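
Note that the caller computes the grid with integer division, so a shape that is not a multiple of 128x512 would leave some rows or columns outside every kernel instance. A minimal sketch of a guard, assuming you want to fail early instead: `launch_grid` is a hypothetical helper, not part of NKI.

```python
def launch_grid(shape, tile=(128, 512)):
    """Return (grid_x, grid_y) for a 2D shape, or raise if it does not tile evenly.

    Hypothetical helper for illustration; not part of the NKI API.
    """
    if len(shape) != 2:
        raise ValueError(f"expected a 2D shape, got {tuple(shape)}")
    rows, cols = shape
    tile_rows, tile_cols = tile
    if rows % tile_rows or cols % tile_cols:
        raise ValueError(
            f"shape {tuple(shape)} is not a multiple of tile {tuple(tile)}"
        )
    return rows // tile_rows, cols // tile_cols


grid = launch_grid((256, 1024))  # the shape used in this notebook
```

For example, a 200x1024 input gives `200 // 128 == 1`, so without the check only the first 128 rows would be covered.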

To understand what happens under the hood when we compile the above code, we can print the HLO IR graph generated by XLA by setting the `NEURON_FRAMEWORK_DEBUG` environment variable. For example, you may add the following lines to your code:

In [None]:
import os
os.environ['NEURON_FRAMEWORK_DEBUG'] = "1"

A `.pbtxt` file containing the corresponding human-readable HLO IR is then written to your run directory.

Let’s examine the XLA output of this example. In line #5 we can identify that the tensor addition is now mapped to an HLO `custom-call` instruction, with `AwsNeuronCustomNativeKernel` as `custom_call_target`. The output of that `custom-call` is then consumed by the next instruction in line #6 as usual.

```text
ENTRY %SyncTensorsGraph.22 (p0.2: f32[256,1024], p1.2: f32[256,1024]) -> (f32[256,1024]) {
 %p1.2 = f32[256,1024]{1,0} parameter(1), frontend_attributes={neff_input_name="input1"}
 %p0.2 = f32[256,1024]{1,0} parameter(0), frontend_attributes={neff_input_name="input0"}
 %multiply = f32[256,1024]{1,0} multiply(f32[256,1024]{1,0} %p1.2, f32[256,1024]{1,0} %p0.2)
 %custom-call.2 = f32[256,1024]{1,0} custom-call(f32[256,1024]{1,0} %p1.2, f32[256,1024]{1,0} %p0.2), custom_call_target="AwsNeuronCustomNativeKernel", api_version=API_VERSION_UNSPECIFIED, backend_config="...")
 %multiply.1 = f32[256,1024]{1,0} multiply(f32[256,1024]{1,0} %multiply, f32[256,1024]{1,0} %custom-call.2)
 ROOT %tuple = (f32[256,1024]{1,0}) tuple(f32[256,1024]{1,0} %multiply.1), frontend_attributes={neff_output_names="output0"}
}
```
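
To locate such instructions in a longer dump, a small text scan is usually enough. The sketch below assumes the HLO text looks like the excerpt above; `find_custom_calls` is a hypothetical helper, and the regular expression is tailored to this excerpt instead of the full HLO grammar.

```python
import re

# Matches lines of the form: %name = ... custom-call(...), custom_call_target="target"
_CUSTOM_CALL = re.compile(
    r'^\s*(?:ROOT\s+)?(%[\w.\-]+)\s*=.*\bcustom-call\(.*custom_call_target="([^"]+)"'
)


def find_custom_calls(hlo_text):
    """Return (1-based line number, instruction name, target) per custom-call."""
    hits = []
    for lineno, line in enumerate(hlo_text.splitlines(), start=1):
        match = _CUSTOM_CALL.match(line)
        if match:
            hits.append((lineno, match.group(1), match.group(2)))
    return hits


# The HLO excerpt from this notebook, pasted as a string.
sample = '''ENTRY %SyncTensorsGraph.22 (p0.2: f32[256,1024], p1.2: f32[256,1024]) -> (f32[256,1024]) {
 %p1.2 = f32[256,1024]{1,0} parameter(1), frontend_attributes={neff_input_name="input1"}
 %p0.2 = f32[256,1024]{1,0} parameter(0), frontend_attributes={neff_input_name="input0"}
 %multiply = f32[256,1024]{1,0} multiply(f32[256,1024]{1,0} %p1.2, f32[256,1024]{1,0} %p0.2)
 %custom-call.2 = f32[256,1024]{1,0} custom-call(f32[256,1024]{1,0} %p1.2, f32[256,1024]{1,0} %p0.2), custom_call_target="AwsNeuronCustomNativeKernel", api_version=API_VERSION_UNSPECIFIED, backend_config="...")
 %multiply.1 = f32[256,1024]{1,0} multiply(f32[256,1024]{1,0} %multiply, f32[256,1024]{1,0} %custom-call.2)
 ROOT %tuple = (f32[256,1024]{1,0}) tuple(f32[256,1024]{1,0} %multiply.1), frontend_attributes={neff_output_names="output0"}
}'''

hits = find_custom_calls(sample)
```

On this excerpt, `hits` holds a single entry for `%custom-call.2` on line 5. To scan a real dump, read the `.pbtxt` file into a string first.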

The Neuron compiler replaces the above custom-call with the corresponding NKI kernel implementation while optimizing the rest of the compute graph as usual. At the end of the compilation process, a single compiled binary NEFF file is generated representing the entire graph including the NKI kernel. For more information about NEFF files, see [Neuron Compiler](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/compiler/index.html).

## Using NKI in training graphs

If you are using NKI to implement a new operator in a training graph, you might need to make the new operator work with the framework’s `autograd` engine. In PyTorch, you do this by subclassing `torch.autograd.Function` and implementing both the `forward()` and `backward()` methods. The `autograd` engine then uses the `backward()` method when performing auto-differentiation. See [Extending torch.autograd](https://pytorch.org/docs/stable/notes/extending.html) in the PyTorch docs for instructions.

Let’s reuse the `nki_tensor_add` kernel caller from before and demonstrate how to train a simple compute graph `(a+b)*a*b` in PyTorch.

## PyTorch

We define a `NkiAddFunc` class, which calls the `nki_tensor_add` kernel caller in its `forward()` function. Because the partial derivatives of `y = a + b` with respect to both `a` and `b` are one, the `backward()` function passes the incoming gradient `dy` through unchanged as the gradient of each input.

In [None]:
import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()

class NkiAddFunc(torch.autograd.Function):
  @staticmethod
  def forward(ctx, a, b):
    return nki_tensor_add(a, b)

  @staticmethod
  def backward(ctx, dy, *args):
    # gradients for a and b
    return dy, dy

# now, let's define the compute graph
a = torch.randn(256, 1024, dtype=torch.float32).to(device).detach().requires_grad_()
b = torch.randn(256, 1024, dtype=torch.float32).to(device).detach().requires_grad_()
c = NkiAddFunc.apply(a, b)
out = a * b * c

# here we define a (dummy) loss-function, in prep for backward propagation
loss = out.sum()

# lastly, let's invoke the auto-grad engine
loss.backward()

xm.mark_step()
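
To sanity-check the gradients this graph should produce, the same computation can be worked through on scalars. Since the loss is a plain sum, each element contributes independently, so a per-element check is enough. The sketch below is plain Python and does not use PyTorch or NKI; the function names are illustrative only. It compares the chain-rule gradients, with `dy` passed through the addition unchanged as in `NkiAddFunc.backward()`, against central finite differences.

```python
def f(a, b):
    c = a + b  # the step NkiAddFunc.forward() computes
    return a * b * c


def analytic_grads(a, b):
    c = a + b
    dy = a * b           # d(out)/dc: the gradient backward() receives as dy
    grad_a = b * c + dy  # direct path through a, plus dy passed through the add
    grad_b = a * c + dy  # direct path through b, plus dy passed through the add
    return grad_a, grad_b


def numeric_grads(a, b, eps=1e-6):
    # Central finite differences as an independent reference.
    grad_a = (f(a + eps, b) - f(a - eps, b)) / (2 * eps)
    grad_b = (f(a, b + eps) - f(a, b - eps)) / (2 * eps)
    return grad_a, grad_b


a, b = 1.5, -0.5
ga, gb = analytic_grads(a, b)
na, nb = numeric_grads(a, b)
```

In closed form the gradients are `2*a*b + b**2` and `a**2 + 2*a*b`, which the two approaches should agree on up to finite-difference error.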

## Release the NeuronCore for the next notebook

Before moving to the next notebook, we need to release the NeuronCore. Otherwise, the next notebook will not be able to use it. You can also stop the kernel via the GUI.

> Running the command in the next cell causes the notebook to report an error. This is expected.

In [None]:
import IPython
IPython.Application.instance().kernel.do_shutdown(True)