<a href="https://colab.research.google.com/github/anshulsawant/llm-systems/blob/main/CUDA_simulator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
os.environ['NUMBA_ENABLE_CUDASIM'] = '1'
import numba
import numpy as np
from numba import cuda

In [3]:
@cuda.jit
def vec_add(A, B, n, out):
    '''
     Element-wise vector addition kernel: out[i] = A[i] + B[i].
     Each thread handles at most one element, selected by its global 1-D thread index.
     Args:
        A: first input vector
        B: second input vector
        n: number of elements to add
        out: output vector, written in place

     Returns:
        None (Fills in out)
    '''
    # Global index = offset of this block + offset of the thread inside it.
    i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
    # Threads whose index falls past the end of the data have nothing to do.
    if i >= n:
        return
    out[i] = A[i] + B[i]

In [4]:
# Host-side data for the vec_add demo:
# A counts 0..n-1, B is all ones, C is the zero-initialised output buffer.
n = 14
A = np.arange(n)
B = np.ones(A.shape, dtype=A.dtype)
C = np.zeros(A.shape, dtype=A.dtype)

In [5]:
# Launch configuration: one thread per element. The number of blocks is
# derived from n with a ceiling division so that every element stays covered
# if n changes; surplus threads in the last block are masked out by the
# kernel's bounds check.
blockdim = 4
griddim = (n + blockdim - 1) // blockdim
vec_add[griddim, blockdim](A, B, n, C)
C

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])

In [6]:
@cuda.jit(device=True)
def index_to_position(index, strides, num_dims):
    '''
     Map a multidimensional tensor index to a single position in storage.
     The position is the sum of index[d] * strides[d] over the first num_dims dimensions.
     Args:
        index: multidimensional index, one int per dimension
        strides: tensor strides, one per dimension
        num_dims: number of dimensions to combine, e.g. shape/strides of [2, 3, 4] has 3 dimensions

     Returns:
        position in storage, expressed in the same units as strides
    '''
    offset = 0
    for dim in range(num_dims):
        offset += index[dim] * strides[dim]
    return offset

@cuda.jit(device=True)
def to_index(ordinal, shape, out_index, num_dims):
    '''
     Turn an ordinal into a multidimensional index for the given shape, with the last
     dimension varying fastest. Enumerating ordinals 0 ... size - 1 of a tensor produces
     every index exactly once. It may not be the inverse of index_to_position.
     Every component is reduced modulo its extent, so ordinals >= size wrap around to
     valid indices instead of going out of range.
     Args:
        ordinal: ordinal position to convert
        shape: tensor shape
        out_index: output buffer that receives the index
        num_dims: number of dimensions in the tensor

     Returns:
        None (Fills in out_index)
    '''
    remaining = ordinal
    # Peel off one dimension at a time, from the last one to the first.
    for dim in range(num_dims - 1, -1, -1):
        extent = shape[dim]
        out_index[dim] = remaining % extent
        remaining = remaining // extent

@cuda.jit(device=True)
def broadcast_index(big_index, big_shape, shape, out_index, num_dims_big, num_dims):
    '''
     Convert big_index (an index into big_shape) into out_index (an index into the smaller
     shape) following broadcasting rules. The bigger tensor may have more dimensions than
     the smaller one; the two are aligned on their trailing dimensions, so the extra
     leading dimensions of big_index are dropped. Dimensions of size 1 (or less) in the
     smaller shape always map to index 0.

     Args:
        big_index: multidimensional index of bigger tensor
        big_shape: tensor shape of bigger tensor (not read by this function)
        shape: tensor shape of smaller tensor
        out_index: multidimensional index of smaller tensor (output)
        num_dims_big: number of dimensions in bigger tensor
        num_dims: number of dimensions in smaller tensor

     Returns:
        None (Fills in out_index)
    '''
    # Dimension d of the smaller tensor lines up with dimension d + lead of the bigger one.
    lead = num_dims_big - num_dims
    for d in range(num_dims):
        if shape[d] <= 1:
            out_index[d] = 0
        else:
            out_index[d] = big_index[d + lead]


In [7]:
@cuda.jit('void(float32[:], int32[:], int32[:], float32[:], int32[:], int32[:], float32[:], int32[:], int32[:])')
def simple_matmul(out, out_shape, out_strides, a, a_shape, a_strides, b, b_shape, b_strides):
  '''
  Naive matrix multiply: each thread computes one element of out = a @ b.

  a: M X H
  b: H x N
  out: M x N
  all shapes and strides are 2-tuples; out, a and b are indexed as flat buffers.
  Positions computed from the strides are divided by the float32 item size before
  indexing, i.e. the strides are treated as byte strides.

  Launch on a 1-D grid with at least M * N threads in total. Threads whose
  ordinal is >= M * N return without touching out.
  '''
  sz = np.dtype(np.float32).itemsize

  out_ordinal = cuda.blockDim.x * cuda.blockIdx.x + cuda.threadIdx.x
  # Bounds-check the raw ordinal. to_index reduces every component modulo its
  # extent, so an index-based check can never fire: surplus threads would wrap
  # around to a valid element and race with the thread that owns it.
  if out_ordinal >= out_shape[0] * out_shape[1]:
    return

  # NOTE(review): Python lists inside a kernel rely on the CUDA simulator - confirm
  # before running on a real device.
  out_index = [0] * 2
  to_index(out_ordinal, out_shape, out_index, 2)
  out_pos = int(index_to_position(out_index, out_strides, 2))

  # Accumulate in a float32 local and store once, instead of zeroing out[...] and
  # doing a read-modify-write on it for every k.
  acc = np.float32(0.0)
  for k in range(a_shape[1]):
    a_index = [out_index[0], k]
    b_index = [k, out_index[1]]
    a_pos = int(index_to_position(a_index, a_strides, 2))
    b_pos = int(index_to_position(b_index, b_strides, 2))
    acc += a[a_pos // sz] * b[b_pos // sz]
  out[out_pos // sz] = acc
  # No barrier here: the kernel shares no state between threads, and a
  # syncthreads() placed after a divergent early return is not safe.

In [110]:
@cuda.jit('void(float32[:], int32[:], int32[:], float32[:], int32[:], int32[:], float32[:], int32[:], int32[:])')
def tiled_batch_matmul(out, out_shape, out_strides, a, a_shape, a_strides, b, b_shape, b_strides):
  '''
  Batched matrix multiply out = a @ b using square tiles staged in shared arrays.

  a: B X M X H
  b: B X H x N
  out: B X M x N
  all shapes and strides are 3-tuples; out, a and b are indexed as flat buffers.
  Positions computed from the strides are divided by the float32 item size before
  indexing, i.e. the strides are treated as byte strides.

  Launch layout:
    block: tile_size x tile_size threads (blockDim.x must equal blockDim.y)
    grid:  ceil(M / tile_size) x ceil(N / tile_size) x B
  Thread (tx, ty) of block (bx, by, bz) produces out[bz, bx * tile_size + tx, by * tile_size + ty].
  '''
  sz = np.dtype(np.float32).itemsize
  num_dims = 3

  ## assert cuda.blockDim.x == cuda.blockDim.y

  tile_size = cuda.blockDim.x
  # NOTE(review): a shared array sized from blockDim at run time and Python lists
  # as indices rely on the CUDA simulator - confirm before running on a real device.
  _a = cuda.shared.array((tile_size, tile_size), dtype=np.float32)
  _b = cuda.shared.array((tile_size, tile_size), dtype=np.float32)

  # The batch index is the same for every thread of a block, so this early
  # return never splits a block across the barriers below.
  batch = cuda.blockIdx.z
  if batch >= out_shape[0]:
    return

  tx = cuda.threadIdx.x
  ty = cuda.threadIdx.y

  out_x = cuda.blockIdx.x * tile_size + tx
  out_y = cuda.blockIdx.y * tile_size + ty

  out_index = [batch, out_x, out_y]
  # Start from a float so the accumulator does not change type on the first add.
  out_value = 0.0
  # Walk the shared H dimension one tile at a time.
  for k in range(0, a_shape[2], tile_size):
    # Stage one tile of a: row out_x, columns k .. k + tile_size - 1 (zero padded).
    a_index = [batch, out_index[1], k + ty]
    if a_index[1] < a_shape[1] and a_index[2] < a_shape[2]:
      a_pos = int(index_to_position(a_index, a_strides, num_dims))
      _a[tx, ty] = a[a_pos//sz]
    else:
      _a[tx, ty] = 0.0
    # Stage one tile of b: rows k .. k + tile_size - 1, column out_y (zero padded).
    b_index = [batch, k + tx, out_index[2]]
    if b_index[2] < b_shape[2] and b_index[1] < b_shape[1]:
      b_pos = int(index_to_position(b_index, b_strides, num_dims))
      _b[tx, ty] = b[b_pos//sz]
    else:
      _b[tx, ty] = 0.0
    # Both tiles must be completely loaded before any thread reads them.
    cuda.syncthreads()
    for i in range(tile_size):
      if k + i < a_shape[2] and k + i < b_shape[1]:
        out_value += _a[tx, i] * _b[i, ty]
    # Nobody may overwrite the tiles until every thread has finished reading them.
    cuda.syncthreads()
  # Threads outside the output still helped load tiles above; they only skip the store.
  if out_index[1] >= out_shape[1] or out_index[2] >= out_shape[2]:
    return
  out_pos = int(index_to_position(out_index, out_strides, num_dims))
  out[out_pos//sz] = out_value

In [114]:
# Check tiled_batch_matmul against NumPy on a small batched problem whose
# M (3) is not a multiple of the tile size, so the padding paths are exercised.
batch_size = 3
M = 3
N = 2
H = 4
threads_per_block = 2
# Deterministic inputs make a mismatch easy to inspect by hand.
# For a randomised check use instead:
##   A = np.random.randn(batch_size, M, H).astype(np.float32)
##   B = np.random.randn(batch_size, H, N).astype(np.float32)
A = np.arange(batch_size * M * H).reshape((batch_size, M, H)).astype(np.float32)
B = np.arange(batch_size * H * N).reshape((batch_size, H, N)).astype(np.float32)
# One square tile of threads per block; enough blocks to cover M x N (ceiling
# division), and one grid layer per batch entry.
gridDim = ((M + threads_per_block - 1)//threads_per_block, (N + threads_per_block - 1)//threads_per_block, batch_size)
blockDim = (threads_per_block, threads_per_block)
C_np = A @ B
C = np.zeros_like(C_np)
# The kernel indexes flat buffers. reshape(-1) of these contiguous arrays is a
# view, so the kernel's writes land in C. The strides passed are NumPy byte
# strides, which the kernel divides by the float32 item size.
tiled_batch_matmul[gridDim, blockDim](C.reshape(-1), C.shape, C.strides, A.reshape(-1), A.shape, A.strides, B.reshape(-1), B.shape, B.strides)

assert np.allclose(C, C_np)