# Add New Operator
Hidet is designed to be extensible, and adding new operators to it is straightforward. There are two ways to add and schedule an operator:
1. **Rule-based Scheduling**: Define the mathematical computation of the operator, and Hidet's rule-based scheduler will automatically schedule the computation into a parallel tensor program.
2. **Template-based Scheduling**: Besides the computation, users can also provide a concrete implementation of the operator to achieve better performance for complex operators.

In this tutorial, we will walk through how to define the computation of an operator and schedule it using the two methods.

```python
import hidet
```

## 1. Define Operator Computation
Each operator takes a list of input tensors and produces a list of output tensors:

```
inputs: List[Tensor]
outputs: List[Tensor] = operator(inputs)
```
The precise mathematical definition of each operator in Hidet is specified with a domain-specific language (DSL) provided by the [hidet.ir.compute](https://docs.hidet.org/stable/python_api/ir/compute.html#module-hidet.ir.compute) module. In this tutorial, we show how to use this DSL to define the computation of a new operator.

### 1.1 Compute Primitives

Hidet provides compute primitives to define the mathematical computation of an operator.
#### 1.1.1 tensor_input
```python
tensor_input(name: str, dtype: str, shape: List[int])
```
The [`tensor_input()`](https://docs.hidet.org/stable/python_api/ir/compute.html#hidet.ir.compute.tensor_input) primitive defines a tensor input by specifying the name hint, scalar data type, and shape of the tensor.

<sub>Examples:</sub>
```python
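from hidet.ir.compute import tensor_input

a = tensor_input('a', dtype='float32', shape=[10, 10])  # a 10x10 matrix
b = tensor_input('b', dtype='float32', shape=[])  # a scalar (0-d tensor)
c = tensor_input('data', dtype='float16', shape=[1, 3, 224, 224])  # e.g., a batch of one 3x224x224 image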
```

#### 1.1.2 compute
```python
compute(name: str, shape: List[int], fcompute: Callable[[Var, ...], Expr])
```

The [`compute()`](https://docs.hidet.org/stable/python_api/ir/compute.html#hidet.ir.compute.compute) primitive defines a tensor by specifying
* the name of the tensor, just a hint for what the tensor represents,
* the shape of the tensor, and
* a function that maps an index to the expression that computes the value of the tensor at that index.

The computation of each element of the tensor is *independent* of the others, so all elements can be computed in parallel.

<sub>Semantics:</sub>
```python
# compute primitive
out = compute(
    name='hint_name',
    shape=[n1, n2, ..., nk],
    fcompute=lambda i1, i2, ..., ik: f(i1, i2, ..., ik)
)

# semantics
for i1 in range(n1):
  for i2 in range(n2):
    ...
      for ik in range(nk):
        out[i1, i2, ..., ik] = f(i1, i2, ..., ik)
```

<sub>Examples:</sub>
```python
from hidet.ir.compute import tensor_input, compute

# define an input tensor
a = tensor_input('a', dtype='float32', shape=[10, 10])

# example 1: slice the first column of a
b = compute('slice', shape=[10], fcompute=lambda i: a[i, 0])

# example 2: reverse the rows of matrix a
c = compute('reverse', shape=[10, 10], fcompute=lambda i, j: a[9 - i, j])

# example 3: add 1 to the diagonal elements of a
from hidet.ir.expr import if_then_else
d = compute(
  name='diag_add',
  shape=[10, 10],
  fcompute=lambda i, j: if_then_else(i == j, then_expr=a[i, j] + 1.0, else_expr=a[i, j])
)
```
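To make the loop-nest semantics concrete without installing Hidet, here is a minimal pure-Python reference interpreter for `compute()`, applied to the three examples above. It is a sketch only: `reference_compute` is a hypothetical helper (not part of Hidet) that materializes the result as a dictionary keyed by index tuples, and a plain nested list stands in for the `tensor_input` tensor `a`.

```python
import itertools
from typing import Callable, Dict, List, Tuple

Index = Tuple[int, ...]


def reference_compute(shape: List[int], fcompute: Callable[..., float]) -> Dict[Index, float]:
    """Hypothetical helper: evaluate fcompute at every index in the given shape.

    Each element is computed independently of the others, so the loop order is irrelevant.
    """
    return {idx: fcompute(*idx) for idx in itertools.product(*(range(n) for n in shape))}


# plain nested list standing in for tensor_input('a', ...): a[i][j] == 10 * i + j
a = [[10.0 * i + j for j in range(10)] for i in range(10)]

# example 1: slice the first column of a
b = reference_compute([10], lambda i: a[i][0])

# example 2: reverse the rows of a
c = reference_compute([10, 10], lambda i, j: a[9 - i][j])

# example 3: add 1 to the diagonal (a Python conditional plays the role of if_then_else)
d = reference_compute([10, 10], lambda i, j: a[i][j] + 1.0 if i == j else a[i][j])

print(b[(3,)], c[(0, 5)], d[(2, 2)], d[(3, 2)])  # 30.0 95.0 23.0 32.0
```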

#### 1.1.3 reduce
```python
reduce(shape: List[int], fcompute: Callable[[Var, ...], Expr], reduce_type='sum')
```

The [`reduce()`](https://docs.hidet.org/stable/python_api/ir/compute.html#hidet.ir.compute.reduce) primitive performs a reduction over a domain with the given shape. It returns a scalar expression rather than a tensor, so it is typically used inside the `fcompute` function of the [`compute()`](https://docs.hidet.org/stable/python_api/ir/compute.html#hidet.ir.compute.compute) primitive.

<sub>Semantics:</sub>
```python
# reduce primitive
out = reduce(
    shape=[n1, n2, ..., nk],
    fcompute=lambda i1, i2, ..., ik: f(i1, i2, ..., ik),
    reduce_type='sum' | 'max' | 'min' | 'avg'
)

# semantics
values = []
for i1 in range(n1):
  for i2 in range(n2):
    ...
      for ik in range(nk):
        values.append(f(i1, i2, ..., ik))
out = reduce_type(values)
```

<sub>Examples:</sub>
```python
from hidet.ir.compute import tensor_input, compute, reduce

# define an input tensor
a = tensor_input('a', dtype='float32', shape=[10, 10])

# example 1: sum all elements of a
c = reduce(shape=[10, 10], fcompute=lambda i, j: a[i, j], reduce_type='sum')

# example 2: sum the first column of a
d = reduce(shape=[10], fcompute=lambda i: a[i, 0], reduce_type='sum')

# example 3: matrix multiplication
b = tensor_input('b', dtype='float32', shape=[10, 10])
e = compute(
    name='e',
    shape=[10, 10],
    fcompute=lambda i, j: reduce(
        shape=[10],
        fcompute=lambda k: a[i, k] * b[k, j],
        reduce_type='sum'
    )
)
```
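The same idea extends to `reduce()`. In the sketch below, `reference_reduce` is a hypothetical pure-Python helper (not Hidet's implementation) that collects `fcompute` over the whole reduction domain and combines the values according to `reduce_type`. Calling it inside a per-element computation reproduces the matrix multiplication pattern of example 3 on a small 2x2 case.

```python
import itertools
from typing import Callable, List


def reference_reduce(shape: List[int], fcompute: Callable[..., float], reduce_type: str = 'sum') -> float:
    """Hypothetical helper: gather fcompute over the reduction domain and combine into one scalar."""
    values = [fcompute(*idx) for idx in itertools.product(*(range(n) for n in shape))]
    if reduce_type == 'sum':
        return sum(values)
    if reduce_type == 'max':
        return max(values)
    if reduce_type == 'min':
        return min(values)
    if reduce_type == 'avg':
        return sum(values) / len(values)
    raise ValueError(f'unsupported reduce_type: {reduce_type!r}')


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]

# example 1 analogue: reduce over every element of a
total = reference_reduce([2, 2], lambda i, j: a[i][j], 'sum')
largest = reference_reduce([2, 2], lambda i, j: a[i][j], 'max')

# example 3 analogue: the reduction over k runs once for every output index (i, j)
e = [[reference_reduce([2], lambda k: a[i][k] * b[k][j], 'sum') for j in range(2)] for i in range(2)]

print(total, largest, e)  # 10.0 4.0 [[19.0, 22.0], [43.0, 50.0]]
```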



## 2. Define a Computation Task
The computation of each operator can be described as a directed acyclic graph (DAG). The DAG is composed of tensor nodes. Both [`tensor_input()`](https://docs.hidet.org/stable/python_api/ir/compute.html#hidet.ir.compute.tensor_input) and [`compute()`](https://docs.hidet.org/stable/python_api/ir/compute.html#hidet.ir.compute.compute) primitives create tensor nodes. The edges of the DAG are the dependencies between the tensor nodes. Such a DAG is stored in a [`Task`](https://docs.hidet.org/stable/python_api/ir/task.html#hidet.ir.task.Task) object.
```python
class Task(name: str, inputs: List[TensorNode], outputs: List[TensorNode])
```
Each task has a name, a list of inputs, and a list of outputs, corresponding to the inputs and outputs of the operator. The following example shows how to create a task.

```python
def demo_task():
    from hidet.ir.compute import tensor_input, compute
    from hidet.ir.task import Task

    # define the computation DAG through the compute primitives
    a = tensor_input('a', dtype='float32', shape=[10])
    b = tensor_input('b', dtype='float32', shape=[10])
    c = compute('c', [10], lambda i: a[i] + i)
    d = compute('d', [10], lambda i: c[9 - i])
    e = compute('e', [10], lambda i: a[i] + b[i])

    # create a task object
    task = Task(name='task', inputs=[a, b], outputs=[d, e])
    print(task)


demo_task()
```

```
Task(
  name: task
  parameters:
    a: tensor(float32, [10])
    b: tensor(float32, [10])
    d: tensor(float32, [10])
    e: tensor(float32, [10])
  inputs: [a, b]
  outputs: [d, e]
  computations:
    b: tensor(float32, [10])
    e: float32[10] where e[v] = (a[v] + b[v])
    a: tensor(float32, [10])
    c: float32[10] where c[v_1] = (a[v_1] + v_1)
    d: float32[10] where d[v_2] = c[(9 - v_2)]
  attributes: {}
)
```


In the above example, there are 5 tensor nodes: `a` and `b` are inputs, and `d` and `e` are outputs. Node `c` depends on node `a`, node `d` depends on node `c`, and node `e` depends on both nodes `a` and `b`.
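The dependency structure just described can be checked mechanically. The sketch below does not use Hidet: it writes the DAG of `demo_task()` as a plain dictionary (node to the nodes it reads) and uses Python's standard `graphlib` module to find a valid evaluation order, the source nodes (no incoming edges) and the sink nodes (read by no other node). In this particular example the sources and sinks coincide with the task's inputs and outputs.

```python
from graphlib import TopologicalSorter  # Python 3.9+

# node -> nodes it reads from, mirroring the compute() definitions in demo_task()
dependencies = {
    'a': set(),
    'b': set(),
    'c': {'a'},       # c[i] = a[i] + i
    'd': {'c'},       # d[i] = c[9 - i]
    'e': {'a', 'b'},  # e[i] = a[i] + b[i]
}

# any valid evaluation order; the exact order may vary between runs
order = list(TopologicalSorter(dependencies).static_order())

sources = {node for node, deps in dependencies.items() if not deps}  # no incoming edges
sinks = set(dependencies) - set().union(*dependencies.values())     # nothing reads them
print(order, sorted(sources), sorted(sinks))
```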

## 3. Build and Run a Task
Hidet provides the driver function [`hidet.driver.build_task()`](https://docs.hidet.org/stable/python_api/driver.html#hidet.driver.build_task) to build a task into a callable function. [`build_task()`](https://docs.hidet.org/stable/python_api/driver.html#hidet.driver.build_task) performs the following steps to lower the task into a callable function:
1. Dispatch the task to a **scheduler** according to the target device and the task.
2. The scheduler lowers the task into a tensor program, represented as an [`IRModule`](https://docs.hidet.org/stable/python_api/ir/func.html#hidet.ir.func.IRModule).
3. Lower and optimize the `IRModule`.
4. Generate code: translate the `IRModule` into target source code (e.g., `source.cu`).
5. Call the compiler (e.g., `nvcc`) to compile the source code into a dynamic library (e.g., `lib.so`).
6. Load the dynamic library and wrap it into a [`CompiledFunction`](https://docs.hidet.org/stable/python_api/runtime/index.html#hidet.runtime.CompiledFunction) that can be called directly.
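To give a feel for what step 4 produces, the toy below turns a 1-D element-wise computation into C source text. It is not Hidet's code generator: `emit_elementwise_c` is a hypothetical function that only handles `out[i] = <expr>` over two float inputs, whereas Hidet generates code from a full `IRModule`. A real pipeline would then hand such source to a compiler (step 5) and load the resulting library (step 6).

```python
def emit_elementwise_c(name: str, size: int, expr: str) -> str:
    """Hypothetical toy code generator for a 1-D element-wise task.

    Emits a C function computing out[i] = <expr> for i in [0, size),
    where <expr> may refer to the inputs a[i] and b[i].
    """
    return '\n'.join([
        f'void {name}(const float *a, const float *b, float *out) {{',
        f'    for (int i = 0; i < {size}; i++) {{',
        f'        out[i] = {expr};',
        '    }',
        '}',
    ])


source = emit_elementwise_c('add', 5, 'a[i] + b[i]')
print(source)
```

For the `add` task used below, this prints a single loop over five elements whose body is `out[i] = a[i] + b[i];`.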

We can define the following function to build and run a task.

```python
from typing import List
from hidet.ir.task import Task


def run_task(task: Task, inputs: List[hidet.Tensor], outputs: List[hidet.Tensor]):
    """Run given task and print inputs and outputs"""
    from hidet.runtime import CompiledFunction

    # build the task
    func: CompiledFunction = hidet.driver.build_task(task, target_device='cpu')
    params = inputs + outputs

    # run the compiled task
    func(*params)

    print('Task:', task.name)
    print('Inputs:')
    for tensor in inputs:
        print(tensor)
    print('Output:')
    for tensor in outputs:
        print(tensor)
    print()
```

The following code shows how to 1) define the computation, 2) define the task, and 3) build and run the task.

```python
from hidet.ir.compute import tensor_input, reduce, compute, arg_reduce, TensorNode

def add_example():
    a: TensorNode = tensor_input(name='a', dtype='float32', shape=[5])
    b: TensorNode = tensor_input(name='b', dtype='float32', shape=[5])
    c: TensorNode = compute(name='c', shape=[5], fcompute=lambda i: a[i] + b[i])
    task = Task(name='add', inputs=[a, b], outputs=[c])
    run_task(task, [hidet.randn([5]), hidet.randn([5])], [hidet.empty([5])])

add_example()
```

```
Compiling cpu task add(a=float32[5], b=float32[5])...

Task: add
Inputs:
Tensor(shape=(5,), dtype='float32', device='cpu')
[ 0.5170048  -0.8175022  -0.6692999  -0.2707757  -0.36673257]
Tensor(shape=(5,), dtype='float32', device='cpu')
[-0.8175891   0.4673212  -0.54976064 -1.0559387   0.30592343]
Output:
Tensor(shape=(5,), dtype='float32', device='cpu')
[-0.30058432 -0.350181   -1.2190605  -1.3267144  -0.06080914]
```

### 3.1 More Examples
We show more examples of using the compute primitives to define operator computation.

#### 3.1.1 ReduceSum

```python
def reduce_sum_example():
    a = tensor_input('a', dtype='float32', shape=[4, 3])
    b = compute(
        'b',
        shape=[4],
        fcompute=lambda i: reduce(
            shape=[3], fcompute=lambda j: a[i, j], reduce_type='sum'
        ),
    )
    task = Task('reduce_sum', inputs=[a], outputs=[b])
    run_task(task, [hidet.randn([4, 3])], [hidet.empty([4])])


reduce_sum_example()
```

```
Compiling cpu task reduce_sum(a=float32[4, 3])...

Task: reduce_sum
Inputs:
Tensor(shape=(4, 3), dtype='float32', device='cpu')
[[-0.21989535 -0.6286531  -0.52672076]
 [ 0.7672621   0.82575005 -1.0160285 ]
 [ 1.0468827  -0.6883719   0.29560193]
 [ 0.9096517   0.968135    0.9219604 ]]
Output:
Tensor(shape=(4,), dtype='float32', device='cpu')
[-1.3752692  0.5769836  0.6541128  2.799747 ]
```
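A quick way to confirm that the compiled task computes what its definition says is to recompute the printed output by hand. The snippet below copies the printed values (a fresh run prints different random numbers) and recomputes each row sum in plain Python; any tiny difference comes from float32 rounding and the number of printed digits.

```python
# values copied from the printed tensors above (a fresh run prints different random numbers)
a = [
    [-0.21989535, -0.6286531, -0.52672076],
    [0.7672621, 0.82575005, -1.0160285],
    [1.0468827, -0.6883719, 0.29560193],
    [0.9096517, 0.968135, 0.9219604],
]
printed_b = [-1.3752692, 0.5769836, 0.6541128, 2.799747]

# b[i] = reduce(shape=[3], fcompute=lambda j: a[i, j], reduce_type='sum')
b = [sum(a[i][j] for j in range(3)) for i in range(4)]
max_err = max(abs(x - y) for x, y in zip(b, printed_b))
print(f'max abs difference: {max_err:.1e}')
```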



#### 3.1.2 MatMul

```python
def matmul_example():
    a = tensor_input('a', dtype='float32', shape=[3, 3])
    b = tensor_input('b', dtype='float32', shape=[3, 3])
    c = compute(
        'c',
        shape=[3, 3],
        fcompute=lambda i, j: reduce(
            shape=[3], fcompute=lambda k: a[i, k] * b[k, j], reduce_type='sum'
        ),
    )
    task = Task('matmul', inputs=[a, b], outputs=[c])
    run_task(task, [hidet.randn([3, 3]), hidet.randn([3, 3])], [hidet.empty([3, 3])])


matmul_example()
```

```
Compiling cpu task matmul(a=float32[3, 3], b=float32[3, 3])...

Task: matmul
Inputs:
Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[-0.12335024  1.1787595  -1.1675321 ]
 [ 0.26967815  0.23652509  1.9818728 ]
 [-1.5455004   0.08275169  0.76819605]]
Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[ 0.55907583  1.2184253  -0.12690002]
 [ 0.65917647  0.9455997  -0.7075036 ]
 [ 0.6402905  -0.6257311   1.5101963 ]]
Output:
Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[-0.03951138  1.6949027  -2.581526  ]
 [ 1.5756567  -0.6878787   2.7914524 ]
 [-0.31763536 -2.285511    1.2977037 ]]
```



## 4. Add an Operator with Rule-based Scheduling
So far, we have learned how to define a computation using compute primitives and wrap it into a [`Task`](https://docs.hidet.org/stable/python_api/ir/task.html#hidet.ir.task.Task). In this section, we will learn how to add an [`Operator`](https://docs.hidet.org/stable/python_api/graph/index.html#hidet.graph.Operator) with a given computation definition, and use Hidet's rule-based scheduler to automatically schedule the computation into a tensor program.

### 4.1 Three steps to define a new operator
There are three steps to define a new operator in Hidet.
1. Define the computation task class by inheriting [`Task`](https://docs.hidet.org/stable/python_api/ir/task.html#hidet.ir.task.Task).
2. Define the operator class by inheriting [`Operator`](https://docs.hidet.org/stable/python_api/graph/index.html#hidet.graph.Operator).
3. Define a function to create the operator instance.

### 4.2 Batch Matrix Multiplication Example
We will take the batch matrix multiplication as an example to illustrate the three steps.

#### 4.2.1 Define the computation task class
We define the computation task class `BatchMatmulTask` by inheriting the [`Task`](https://docs.hidet.org/stable/python_api/ir/task.html#hidet.ir.task.Task) class. The constructor of `BatchMatmulTask` takes two arguments, `a` and `b`, which are the input tensor nodes of the batch matrix multiplication, with shapes `[batch_size, m_size, k_size]` and `[batch_size, k_size, n_size]`.

```python
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task


class BatchMatmulTask(Task):
    def __init__(self, a: TensorNode, b: TensorNode):
        # get the input sizes
        batch_size, m_size, k_size = a.const_shape()
        batch_size, k_size, n_size = b.const_shape()

        # define the computation
        c = compute(
            name='c',
            shape=[batch_size, m_size, n_size],
            fcompute=lambda p, i, j: reduce(
                shape=[k_size],
                fcompute=lambda k: a[p, i, k] * b[p, k, j],
                reduce_type='sum',
            ),
        )

        # call the parent class constructor to initialize the task
        super().__init__(
            name='batch_matmul',  # the name of the task
            inputs=[a, b],  # the input tensor nodes
            outputs=[c],  # the output tensor nodes
        )
```

#### 4.2.2 Define the operator class
Our next step is to define the operator class `BatchMatmulOp` by inheriting [`Operator`](https://docs.hidet.org/stable/python_api/graph/index.html#hidet.graph.Operator) class.

```python
from hidet.graph import Operator, Tensor
from hidet.graph.ops.definitions.utils import input_like


class BatchMatmulOp(Operator):
    def __init__(self, a: Tensor, b: Tensor):
        # call the parent class constructor to initialize the operator
        super().__init__(
            inputs=[a, b],  # the input tensors
            task=BatchMatmulTask(  # the task of the operator
                # create tensor nodes (TensorNode) with the same shape and dtype as the tensors (Tensor)
                input_like(a, 'a'),
                input_like(b, 'b'),
            ),
        )
```

#### 4.2.3 Define a function to create the operator instance
We define a function `batch_matmul` to create the operator instance `BatchMatmulOp` and return the output tensor.

```python
def batch_matmul(a: Tensor, b: Tensor) -> Tensor:
    # get_output(0) returns the first output tensor of the operator
    return BatchMatmulOp(a, b).get_output(0)
```

#### 4.2.4 Use the defined operator
The new operator is no different from the operators Hidet provides, because Hidet's built-in operators are defined in exactly the same way. For example, when we optimize the flow graph, this new operator can also be fused with its surrounding operators.

```python
def demo_usage():
    a = hidet.randn([2, 2, 3])
    b = hidet.randn([2, 3, 2])
    c = batch_matmul(a, b)
    print(a)
    print(b)
    print(c)

demo_usage()
```

```
Compiling cpu task batch_matmul(a=float32[2, 2, 3], b=float32[2, 3, 2])...

Tensor(shape=(2, 2, 3), dtype='float32', device='cpu')
[[[-0.1188942  -0.03512365 -1.1575822 ]
  [-0.60656774 -0.7368701  -0.81118315]]

 [[-1.2190988  -1.5492649   0.4305561 ]
  [-1.2007065   1.2391245  -1.319195  ]]]
Tensor(shape=(2, 3, 2), dtype='float32', device='cpu')
[[[-1.4443336   0.86954886]
  [-2.3942506   0.97157407]
  [-1.4177141   0.6442079 ]]

 [[ 0.20606913 -1.0127158 ]
  [ 0.02667027  1.5317795 ]
  [-1.2534379   0.34543064]]]
Tensor(shape=(2, 2, 2), dtype='float32', device='cpu')
[[[ 1.8969383  -0.8832331 ]
  [ 3.7903638  -1.7659348 ]]

 [[-0.8322132  -0.98980427]
  [ 1.4391483   2.6583495 ]]]
```
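As a sanity check of the operator, the sketch below is a plain-Python reference for the computation defined in `BatchMatmulTask` (`c[p, i, j] = sum over k of a[p, i, k] * b[p, k, j]`), applied to the values printed above. `batch_matmul_reference` is a hypothetical helper written for this check, not a Hidet function.

```python
from typing import List

Tensor3D = List[List[List[float]]]


def batch_matmul_reference(a: Tensor3D, b: Tensor3D) -> Tensor3D:
    """c[p][i][j] = sum over k of a[p][i][k] * b[p][k][j], as in BatchMatmulTask."""
    batch_size, m_size, k_size = len(a), len(a[0]), len(a[0][0])
    assert len(b) == batch_size and len(b[0]) == k_size, 'shapes [B, M, K] and [B, K, N] required'
    n_size = len(b[0][0])
    return [[[sum(a[p][i][k] * b[p][k][j] for k in range(k_size)) for j in range(n_size)]
             for i in range(m_size)]
            for p in range(batch_size)]


# values copied from the printed tensors above
a = [[[-0.1188942, -0.03512365, -1.1575822], [-0.60656774, -0.7368701, -0.81118315]],
     [[-1.2190988, -1.5492649, 0.4305561], [-1.2007065, 1.2391245, -1.319195]]]
b = [[[-1.4443336, 0.86954886], [-2.3942506, 0.97157407], [-1.4177141, 0.6442079]],
     [[0.20606913, -1.0127158], [0.02667027, 1.5317795], [-1.2534379, 0.34543064]]]
printed_c = [[[1.8969383, -0.8832331], [3.7903638, -1.7659348]],
             [[-0.8322132, -0.98980427], [1.4391483, 2.6583495]]]

c = batch_matmul_reference(a, b)
max_err = max(abs(c[p][i][j] - printed_c[p][i][j])
              for p in range(2) for i in range(2) for j in range(2))
print(f'max abs difference: {max_err:.1e}')
```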


### 4.3 Two Scheduling Mechanisms
In this example, we only define the computation of the operator and leave the scheduling to the rule-based scheduler provided by Hidet. We call this method **rule-based scheduling**. Most Hidet operators use the same rule-based scheduler as this example. In our experience, the rule-based scheduler achieves good performance for operators without a large amount of reduction. However, for operators such as matrix multiplication and convolution, it may not achieve the best performance, because it does not use shared memory to cache loaded data for reuse. For this reason, Hidet also provides a second scheduling mechanism: **template-based scheduling**.
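A back-of-the-envelope count shows why shared-memory caching matters so much for matrix multiplication. The sketch below uses a simplified model (it counts element loads from global memory and ignores hardware caches such as L1/L2, which absorb some redundant loads in practice): without reuse, every output element reads a full row of A and a full column of B; with tiling, each `block_m x block_n` output tile loads its slices of A and B once into shared memory. The tile size 128x128 matches `block_m` and `block_n` in the template of Section 5; both functions are hypothetical helpers.

```python
def global_loads_naive(m: int, n: int, k: int) -> int:
    # every output element reads one row of A (k values) and one column of B (k values)
    return 2 * m * n * k


def global_loads_tiled(m: int, n: int, k: int, block_m: int, block_n: int) -> int:
    # each block_m x block_n output tile loads a block_m x k slice of A and a k x block_n slice
    # of B once and reuses them from shared memory; assumes m, n are multiples of the tile sizes
    tiles = (m // block_m) * (n // block_n)
    return tiles * (block_m * k + k * block_n)


m = n = k = 1024
naive = global_loads_naive(m, n, k)
tiled = global_loads_tiled(m, n, k, block_m=128, block_n=128)
print(naive // tiled)  # 128
```

In this model the reduction in global-memory traffic is `2 / (1/block_m + 1/block_n)`, so larger tiles give more reuse, limited in practice by shared-memory and register capacity.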

## 5. Add an Operator with Template-based Scheduling
Template-based scheduling allows us to define a tensor program template, and the template will be instantiated for different input shapes and tunable hyper-parameters.

### 5.1 Override `implement_cuda()` method
The [`Task`](https://docs.hidet.org/stable/python_api/ir/task.html#hidet.ir.task.Task) class has two methods, `implement_cpu()` and `implement_cuda()`, that we can override when we define a new task.

```python
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule

class BatchMatmulFp16Task(Task):
    def __init__(self, a: TensorNode, b: TensorNode):
        batch_size, m_size, k_size = a.const_shape()
        batch_size, k_size, n_size = b.const_shape()
        c = compute(
            name='c',
            shape=[batch_size, m_size, n_size],
            fcompute=lambda p, i, j: reduce(
                shape=[k_size],
                fcompute=lambda k: a[p, i, k] * b[p, k, j],
                reduce_type='sum',
            ),
        )
        super().__init__(
            name='batch_matmul_fp16',
            inputs=[a, b],
            outputs=[c],
            attributes={
                'batch_size': batch_size,
                'm_size': m_size,
                'n_size': n_size,
                'k_size': k_size,
            },
        )

    def allow_epilogue(self) -> bool:
        return False

    def implement_cuda(self, working_dir: str) -> IRModule:
        # override this method to use template-based scheduling
        return batch_matmul_mma_fp16_schedule(self)
```

In the above task definition, we override the `implement_cuda()` method to use template-based scheduling. Inside `implement_cuda()`, we call the `batch_matmul_mma_fp16_schedule()` function, which we write next, to get a tensor program that implements the computation defined in the task. The workload sizes are stored in `attributes` so that the schedule function can read them, and `allow_epilogue()` returns `False` so that Hidet does not fuse epilogue operators into this task.

### 5.2 Implement the tensor program
We can implement the `batch_matmul_mma_fp16_schedule()` function as follows. The function is written in Hidet Script, a DSL for writing tensor programs that we will explore in detail in the next section. Understanding the code below requires familiarity with Hidet Script and efficient CUDA programming, so feel free to skip the details of this implementation for now.

```python
def batch_matmul_mma_fp16_schedule(task: BatchMatmulFp16Task) -> IRModule:
    from hidet.lang import f16, tensor, attr, grid
    from hidet.lang.mapping import repeat, spatial
    from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
    from hidet.lang.cuda import MmaConfig, mma_sync
    from hidet.transforms.tools import add_packed_func

    # get the workload size
    bs = task.attributes['batch_size']
    m_size = task.attributes['m_size']
    n_size = task.attributes['n_size']
    k_size = task.attributes['k_size']

    # define the template hyper-parameters
    mma_config = MmaConfig.m16n8k8_f16_f16()
    block_m, block_n, block_k = 128, 128, 8
    warp_m, warp_n, warp_k = 64, 64, 8
    warp_count_m, warp_count_n, warp_count_k = 2, 2, 1
    mma_m, mma_n, mma_k = mma_config.m, mma_config.n, mma_config.k  # 16, 8, 8
    mma_count_m, mma_count_n, mma_count_k = 4, 8, 1
    threads = warp_count_m * warp_count_n * warp_count_k * 32

    # define the tensor program
    with hidet.script_module() as module:

        @hidet.script
        def load_regs_a(
            smem_a: f16[block_m, block_k], regs_a: f16[4, mma_config.a_elements]
        ):
            """Load A registers from shared memory."""
            warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
            for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(
                warp_id
            ):
                for mi in range(mma_count_m):
                    p = 0
                    for i, k in mma_config.a_load_map.on(lane_id):
                        regs_a[mi, p] = smem_a[
                            wi * warp_m + mi * mma_m + i, wk * warp_k + k
                        ]
                        p += 1

        @hidet.script
        def load_regs_b(
            smem_b: f16[block_k, block_n], regs_b: f16[8, mma_config.b_elements]
        ):
            """Load B registers from shared memory."""
            warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
            for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(
                warp_id
            ):
                for mj in range(mma_count_n):
                    p = 0
                    for k, j in mma_config.b_load_map.on(lane_id):
                        regs_b[mj, p] = smem_b[
                            wk * warp_k + k, wj * warp_n + mj * mma_n + j
                        ]
                        p += 1

        @hidet.script
        def warp_mma(
            regs_a: f16[4, mma_config.a_elements],
            regs_b: f16[8, mma_config.b_elements],
            regs_c: f16[4, 8, mma_config.c_elements],
        ):
            """Perform warp-level matrix multiplication."""
            for mi, mj in repeat(mma_count_m, mma_count_n).on(0):
                mma_sync(mma_config, ~regs_a[mi, 0], ~regs_b[mj, 0], ~regs_c[mi, mj, 0])

        @hidet.script
        def store_c(regs_c: f16[4, 8, mma_config.c_elements], c: f16[bs, m_size, n_size]):
            """Store C registers to global memory."""
            warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
            offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
            gmem_c = c[blockIdx.z, offset_m:, offset_n:]
            for k_round in range(warp_count_k):
                for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(
                    warp_id
                ):
                    if wk == k_round:
                        for mi, mj in repeat(mma_count_m, mma_count_n).on(0):
                            p = 0
                            for i, j in mma_config.c_store_map.on(lane_id):
                                gmem_c.write(
                                    [
                                        wi * warp_m + mi * mma_m + i,
                                        wj * warp_n + mj * mma_n + j,
                                    ],
                                    regs_c[mi, mj, p],
                                    protected=True,
                                )
                                p += 1

        @hidet.script
        def batch_matmul_kernel(
            a: f16[bs, m_size, k_size],
            b: f16[bs, k_size, n_size],
            c: f16[bs, m_size, n_size],
        ):
            """Batch matrix multiplication kernel."""
            attr.cuda_grid_dim = (
                (m_size + block_m - 1) // block_m,
                (n_size + block_n - 1) // block_n,
                bs,
            )
            attr.cuda_block_dim = threads
            offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
            smem_a = tensor('shared', 'float16', [block_m, block_k])
            smem_b = tensor('shared', 'float16', [block_k, block_n])
            regs_a = tensor('register', 'float16', [4, mma_config.a_elements])
            regs_b = tensor('register', 'float16', [8, mma_config.b_elements])
            regs_c = tensor('register', 'float16', [4, 8, mma_config.c_elements])

            for i, j, p in grid(4, 8, mma_config.c_elements):
                regs_c[i, j, p] = 0.0

            for k0 in range((k_size + block_k - 1) // block_k):
                offset_k = k0 * block_k
                gmem_a = a[blockIdx.z, offset_m:, offset_k:]
                gmem_b = b[blockIdx.z, offset_k:, offset_n:]
                for i, k in repeat(8, 1).spatial(16, 8).on(threadIdx.x):
                    smem_a[i, k] = gmem_a.read([i, k], protected=True)
                for k, j in repeat(8, 1).spatial(1, 128).on(threadIdx.x):
                    smem_b[k, j] = gmem_b.read([k, j], protected=True)
                syncthreads()
                load_regs_a(smem_a, regs_a)
                load_regs_b(smem_b, regs_b)
                warp_mma(regs_a, regs_b, regs_c)
                syncthreads()
            store_c(regs_c, c)

    ir_module = module.ir_module()
    # conduct the fusion (when the task has prologue or epilogue) and generate the packed function
    # ir_module = fuse_and_pack(ir_module, kernel_func=batch_matmul_kernel, task=task)
    add_packed_func(ir_module, func=batch_matmul_kernel, pack_func_name=task.name)
    return ir_module
```
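The kernel relies on task mappings such as `repeat(8, 1).spatial(16, 8)` to spread the 128x8 tile of `smem_a` over the 128 threads of a block. The pure-Python sketch below emulates that composition under an assumed semantics: `spatial(s0, s1)` gives each worker one coordinate in row-major order, `repeat(r0, r1)` makes a worker visit `r0 x r1` positions, and the composed index is `outer_index * inner_shape + inner_index`. `repeat_spatial_indices` and `grid_dim` are hypothetical helpers, not part of `hidet.lang`. The sketch checks that every element of both shared-memory tiles is loaded by exactly one thread, and reproduces the launch grid declared through `attr.cuda_grid_dim`.

```python
import itertools
from collections import Counter
from typing import List, Tuple


def repeat_spatial_indices(repeat_shape: Tuple[int, int], spatial_shape: Tuple[int, int],
                           worker: int) -> List[Tuple[int, int]]:
    """Hypothetical emulation of repeat(r0, r1).spatial(s0, s1).on(worker) (assumed semantics)."""
    s0, s1 = spatial_shape
    si, sj = worker // s1, worker % s1  # row-major spatial coordinate of this worker
    return [(ri * s0 + si, rj * s1 + sj)
            for ri, rj in itertools.product(range(repeat_shape[0]), range(repeat_shape[1]))]


threads = 2 * 2 * 1 * 32  # warp_count_m * warp_count_n * warp_count_k * 32
block_m, block_n, block_k = 128, 128, 8

# smem_a (block_m x block_k) is filled with repeat(8, 1).spatial(16, 8)
hits_a = Counter(idx for t in range(threads) for idx in repeat_spatial_indices((8, 1), (16, 8), t))
# smem_b (block_k x block_n) is filled with repeat(8, 1).spatial(1, 128)
hits_b = Counter(idx for t in range(threads) for idx in repeat_spatial_indices((8, 1), (1, 128), t))

# every tile element is covered, and no element is loaded twice
assert set(hits_a) == set(itertools.product(range(block_m), range(block_k)))
assert set(hits_b) == set(itertools.product(range(block_k), range(block_n)))
assert all(v == 1 for v in hits_a.values()) and all(v == 1 for v in hits_b.values())


def grid_dim(bs: int, m_size: int, n_size: int) -> Tuple[int, int, int]:
    # same ceiling divisions as attr.cuda_grid_dim in batch_matmul_kernel
    return ((m_size + block_m - 1) // block_m, (n_size + block_n - 1) // block_n, bs)


print(grid_dim(1, 2, 2), grid_dim(4, 1024, 1000))  # (1, 1, 1) (8, 8, 4)
```

For the tiny `1x2x2` demo below, a single thread block is launched; the `protected=True` reads and writes keep the 128x128 tile from touching memory outside the 2x2 matrices.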

### 5.3 Define the operator
The remaining steps are the same as in the rule-based scheduling method.

```python
from hidet.graph import Operator, Tensor
from hidet.graph.ops.definitions.utils import input_like

class BatchMatmulFp16Op(Operator):
    def __init__(self, a: Tensor, b: Tensor):
        assert a.dtype == hidet.float16 and b.dtype == hidet.float16
        super().__init__(
            inputs=[a, b],
            task=BatchMatmulFp16Task(input_like(a, 'a'), input_like(b, 'b')),
        )


def batch_matmul_fp16(a: Tensor, b: Tensor) -> Tensor:
    return BatchMatmulFp16Op(a, b).get_output(0)


def demo_usage():
    a = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
    b = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
    c = batch_matmul_fp16(a, b)
    print(a)
    print(b)
    print(c)


demo_usage()
```

```
Compiling cpu task cast(x=float32[1, 2, 2])...
Compiling cuda task batch_matmul_fp16(a=float16[1, 2, 2], b=float16[1, 2, 2], batch_size=1, m_size=2, n_size=2, k_size=2)...

Tensor(shape=(1, 2, 2), dtype='float16', device='cuda:0')
[[[0.0829  1.327  ]
  [0.12366 1.723  ]]]
Tensor(shape=(1, 2, 2), dtype='float16', device='cuda:0')
[[[-0.05673  0.521  ]
  [-0.934   -0.864  ]]]
Tensor(shape=(1, 2, 2), dtype='float16', device='cuda:0')
[[[-1.244 -1.104]
  [-1.616 -1.424]]]
```
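Because this operator works in float16, its output should be compared against a reference with a tolerance that reflects float16 precision rather than float32. The sketch below recomputes the single 2x2 product from the printed values in float64 and compares it with the printed result. The tolerance is a judgment call: float16 has a 10-bit mantissa, so adjacent representable values in [1, 2) are `2**-10` apart, the inputs themselves are printed with only about four significant digits, and the template accumulates in float16 registers (`regs_c`).

```python
# values copied from the printed float16 tensors above
a = [[0.0829, 1.327], [0.12366, 1.723]]
b = [[-0.05673, 0.521], [-0.934, -0.864]]
printed_c = [[-1.244, -1.104], [-1.616, -1.424]]

# float64 reference for c = a @ b on the single batch
c = [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]

# allow a few units of the float16 spacing near 1.0 (2**-10), see the lead-in for the reasoning
tolerance = 4 * 2 ** -10
max_err = max(abs(c[i][j] - printed_c[i][j]) for i in range(2) for j in range(2))
print(f'max abs difference: {max_err:.1e} (tolerance {tolerance:.1e})')
```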


## 6. Summary
In this tutorial, we have learned how to add a new operator to Hidet. We first define the computation task of the operator through compute primitives. We then define the operator class and either rely on rule-based scheduling or write our own schedule template and use template-based scheduling to implement the operator.
In the next section, we will learn about Hidet Script, a DSL which allows us to conveniently write our own efficient schedules.