# TensorIR Schedule Example

This notebook shows an example of using TensorIR schedule primitives
to schedule a simple workload so that it can run on a CUDA GPU.
To prepare, follow the same steps as in [mlsys_hw2.ipynb](mlsys_hw2.ipynb)
to install the PyPI package `mlc-ai`.
Please note that the schedule we study in this notebook
is not optimal (i.e., not the fastest possible on GPU) for the workload;
it is a reasonable example that demonstrates the usage of many
schedule primitives.

The workload we study in this notebook is "summation + broadcast",
where we want to sum up a 1D array and broadcast the
result to every position.
For example, given the input `[10, 20, 30]`,
the workload outputs `[60, 60, 60]`.
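
As a plain-Python point of reference (a minimal sketch, independent of TensorIR; the helper name `sum_broadcast_reference` is hypothetical and used only for illustration), the workload amounts to:

```python
def sum_broadcast_reference(a):
    """Sum a 1D sequence and repeat the total at every position."""
    total = sum(a)
    return [total] * len(a)


print(sum_broadcast_reference([10, 20, 30]))  # [60, 60, 60]
```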

Assuming the input array has length 256,
this workload can be written in TensorIR as follows:

In [1]:
import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def sum_broadcast(
    A: T.Buffer((256,), "float32"),
    B: T.Buffer((256,), "float32"),
) -> None:
    for i, k in T.grid(256, 256):
        with T.block("sum_broadcast"):
            vi = T.axis.spatial(256, i)
            vk = T.axis.reduce(256, k)
            with T.init():
                B[vi] = T.float32(0)
            B[vi] += A[vk]

sch = tir.Schedule(sum_broadcast)

As shown above, we first initialize the result array `B` to zero via
```python
with T.init():
    B[vi] = T.float32(0)
```
and then accumulate every element of `A` into each position of `B` through
```python
B[vi] += A[vk]
```
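
To make these semantics concrete, here is a plain-Python emulation of the loop nest. It is a sketch only: it assumes that `T.init()` takes effect on the first reduction iteration (`vk == 0`), which matches the `if (k == 0)` guard in the CUDA code generated later in this notebook, and the helper name `emulate_sum_broadcast` is hypothetical.

```python
def emulate_sum_broadcast(a):
    n = len(a)
    b = [None] * n  # B starts out uninitialized
    for vi in range(n):      # spatial axis: one iteration per output element
        for vk in range(n):  # reduction axis: walks over all of A
            if vk == 0:      # T.init(): runs once per vi, before accumulating
                b[vi] = 0.0
            b[vi] += a[vk]
    return b


print(emulate_sum_broadcast([10.0, 20.0, 30.0]))  # [60.0, 60.0, 60.0]
```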

To compute this workload on GPU,
we assign each GPU thread to compute one element in `B`.
So let's first bind loops to thread axes and block axes on GPU.

For demonstration purposes, we use two thread blocks to compute this workload,
where each thread block has 128 threads.
This means that we need to bind a loop of length 2 to `blockIdx.x` on the GPU,
and bind another loop of length 128 to `threadIdx.x`.

In a TensorIR schedule, to get the loops ready for thread binding,
we can first split the loop `i` into two loops, where the inner loop
has length 128 and the outer loop has length 2.
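
The index arithmetic behind this split can be checked in plain Python (a sketch of the idea, not TVM code; the variable names are illustrative): the original index is recovered as `i = i_outer * 128 + i_inner`, and the two new loops together visit every `i` exactly once.

```python
OUTER, INNER = 2, 128  # loop extents after the split

visited = []
for i_outer in range(OUTER):      # later bound to blockIdx.x
    for i_inner in range(INNER):  # later bound to threadIdx.x
        visited.append(i_outer * INNER + i_inner)

# The split loops cover the original range(256) in the same order.
assert visited == list(range(256))
```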

In [2]:
# Fetch the computation block of "sum_broadcast".
block = sch.get_block("sum_broadcast")
# Fetch the i, k loops of the computation.
i, k = sch.get_loops(block)
# Split i into two loops.
i_outer, i_inner = sch.split(i, factors=[2, 128])  # or `factors=[None, 128]`,
                                                   # in which case `None` will be automatically inferred to 2.

You can print the function after the split via `sch.show()`.
In the output, loop `i` has become two loops, `i_0` and `i_1`, with the desired lengths.

In [3]:
sch.show()

Now we can bind them to `blockIdx.x` and `threadIdx.x` respectively.

In [4]:
sch.bind(i_outer, "blockIdx.x")
sch.bind(i_inner, "threadIdx.x")
sch.show()

After thread binding, this function is already runnable on GPU.
We can quickly verify this and print the CUDA source code generated from the schedule.

In [5]:
def build_and_test(sch: tir.Schedule) -> None:
    import numpy as np

    # Build the scheduled function.
    f = tvm.build(sch.mod, target="cuda")

    # Create the NumPy array for testing.
    a_np = np.random.rand(256).astype("float32")
    b_np = np.broadcast_to(a_np.sum(keepdims=True), shape=(256,))

    # Run the function we scheduled and built.
    device = tvm.cuda()
    a_tvm = tvm.nd.array(a_np, device=device)
    b_tvm = tvm.nd.empty((256,), "float32", device=device)
    f(a_tvm, b_tvm)

    # Validate the result correctness.
    np.testing.assert_allclose(b_tvm.numpy(), b_np, atol=1e-5, rtol=1e-5)
    print("Test passed.")

    # Print out the CUDA source code of the function we scheduled.
    print(f"CUDA source code:\n{f.imported_modules[0].get_source()}")


# Run building and testing.
build_and_test(sch)

Test passed.
CUDA source code:

#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
     (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif

#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(128) sum_broadcast_kernel(float* __restrict__ A, float* __restrict__ B);
extern "C" __global__ void __launch_bounds__(128) sum_broadcast_kernel(float* __restrict__ A, float* __restrict__ B) {
  for (int k = 0; k < 256; ++k) {
    if (k == 0) {
      B[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = 0.000000e+00f;
    }
    B[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = (B[((((int

Though the schedule above can already run the sum + broadcast workload on GPU,
let's go a step further with some more optimization.
Each element in `A` is read 256 times in total,
since every position in `B` depends on every element in `A`.
This **memory reuse** pattern suggests that we can
leverage shared memory to reduce the time spent on global memory accesses:
we can first create a shared memory buffer,
and let all threads in a thread block collectively move data from global memory to shared memory.
The process of "all threads collectively moving data" is named "cooperative fetching",
which you will need to implement in Task 3.
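
A quick back-of-the-envelope count shows the potential saving. This is a sketch that assumes the two-block, 128-thread binding used above and that, with shared memory, each thread block loads `A` from global memory exactly once; it ignores hardware caches.

```python
N = 256                  # length of A and B
NUM_BLOCKS = 2           # thread blocks (blockIdx.x)
THREADS_PER_BLOCK = 128  # threads per block (threadIdx.x)

# Without shared memory: every thread reads all of A from global memory.
reads_direct = NUM_BLOCKS * THREADS_PER_BLOCK * N

# With shared memory: each block copies A from global memory once,
# and the per-thread summation then reads from shared memory instead.
reads_cached = NUM_BLOCKS * N

print(reads_direct, reads_cached)  # 65536 512
```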

In TensorIR, we can apply `cache_read`, `compute_at`, and `bind`, in that order, for this purpose.
We first create the shared memory cache stage with `cache_read`.

In [6]:
# `A` is the first buffer in `T.reads` of the block,
# so its `read_buffer_index` is 0.
A_shared = sch.cache_read(block, read_buffer_index=0, storage_scope="shared")
sch.show()

`cache_read` generates the cache stage in the outermost scope
(outside the loop `i_0` of the main block).
We would like to move it inside the loops `i_0` and `i_1`,
so that the shared memory read stage is done within the thread block
we scheduled before.
For this purpose, we use `compute_at`, which, as its name suggests,
computes a certain block at a certain location.
In our example, we want to do the shared memory read right inside `i_1`
(because we need to move it inside a thread block) and outside `k`
(since the summation would already have started if we moved it inside `k`).
We can do this by

In [7]:
sch.compute_at(A_shared, i_inner)
sch.show()

By looking at the function printed above, we find that
every thread (every possible value of `i_1`) still iterates over the entire loop `ax0`,
which means every thread performs a full copy of `A` into shared memory.
So the remaining job is the cooperative fetching.
In our example, because the data being read has 256 elements and we only have 128 threads per thread block,
each thread needs to copy two elements from global memory to shared memory.
To this end, we split the loop `ax0` into two loops,
where the outer loop has length 2 and the inner loop has length 128.
Then we bind the inner loop to `threadIdx.x`,
so that all threads cooperatively copy the data according to their thread ids.
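
The effect of this split-and-bind step on one thread block can be emulated in plain Python. This is a sketch only: threads run sequentially here, whereas on the GPU they run in parallel and are synchronized before the summation, and all names are illustrative.

```python
N = 256
THREADS_PER_BLOCK = 128

a_global = [float(x) for x in range(N)]
a_shared = [None] * N
copies_per_thread = [0] * THREADS_PER_BLOCK

for ax0_outer in range(N // THREADS_PER_BLOCK):  # serial loop of length 2
    for thread_idx in range(THREADS_PER_BLOCK):  # bound to threadIdx.x
        idx = ax0_outer * THREADS_PER_BLOCK + thread_idx
        a_shared[idx] = a_global[idx]
        copies_per_thread[thread_idx] += 1

# All of A ends up in shared memory, and each thread copied two elements.
assert a_shared == a_global
assert all(c == 2 for c in copies_per_thread)
```

Compared with the version before the split, where each of the 128 threads would copy all 256 elements, the block as a whole now performs 256 copies instead of 128 × 256.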

In [8]:
ax0 = sch.get_loops(A_shared)[-1]
_, ax0_inner = sch.split(ax0, factors=[None, 128])
sch.bind(ax0_inner, "threadIdx.x")
sch.show()

Now we are all set after cooperative fetching!
Let's run the tests again and print out the CUDA source code after scheduling.
In the CUDA source code, you can see that each thread copies two elements
from global memory into shared memory (one per iteration of the `ax0_0` loop).

In [9]:
build_and_test(sch)

Test passed.
CUDA source code:

#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
     (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif

#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(128) sum_broadcast_kernel(float* __restrict__ A, float* __restrict__ B);
extern "C" __global__ void __launch_bounds__(128) sum_broadcast_kernel(float* __restrict__ A, float* __restrict__ B) {
  __shared__ float A_shared[256];
  for (int ax0_0 = 0; ax0_0 < 2; ++ax0_0) {
    A_shared[((ax0_0 * 128) + ((int)threadIdx.x))] = A[((ax0_0 * 128) + ((int)threadIdx.x))];
  }
  __syncthreads();
  for 