In [1]:
import numpy as np
import triton
import triton.language as tl
import torch

In [2]:
## Matrix addition - Note: ask an AI to walk you through the math behind computing block indices;
# it makes this much easier to understand.

@triton.jit
def add_two_matrices(x_ptr, y_ptr, op_ptr, n_rows, n_cols, BLOCK_SIZE_X:tl.constexpr, BLOCK_SIZE_Y:tl.constexpr):
    """Element-wise sum op = x + y of two same-shaped, row-major (n_rows, n_cols) matrices.

    Each program instance handles one BLOCK_SIZE_X x BLOCK_SIZE_Y tile: grid axis 0
    walks tiles down the rows, grid axis 1 walks tiles across the columns. The flat
    index row * n_cols + col assumes contiguous row-major storage for all three buffers.
    """
    # Global row / column indices covered by this tile.
    rows = tl.program_id(0) * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    cols = tl.program_id(1) * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)

    # Lift to 2-D so they broadcast into a full tile: rows on axis 0, columns on axis 1.
    rows_2d = rows[:, None]
    cols_2d = cols[None, :]

    # Edge tiles can hang past the matrix; those lanes are neither read nor written.
    in_bounds = (rows_2d < n_rows) & (cols_2d < n_cols)
    # Memory is addressed linearly, so fold (row, col) into a single flat offset.
    flat_idx = rows_2d * n_cols + cols_2d

    lhs = tl.load(x_ptr + flat_idx, mask=in_bounds)
    rhs = tl.load(y_ptr + flat_idx, mask=in_bounds)
    tl.store(op_ptr + flat_idx, lhs + rhs, mask=in_bounds)
      

In [3]:
def launch_matrix_add():
    """Add two 100x130 all-ones CUDA matrices with `add_two_matrices` and verify the result.

    Prints the distinct column sums (expected 200) and row sums (expected 260) of the
    output. It then asserts that the output equals `x + y` element for element.

    Raises:
        AssertionError: if the kernel output differs from the PyTorch reference.
    """
    x = torch.ones(100, 130, device="cuda", dtype=torch.float32)
    y = torch.ones(100, 130, device="cuda", dtype=torch.float32)
    z = torch.zeros_like(x)

    assert x.shape == y.shape
    n_rows, n_cols = x.shape

    # One program per tile; cdiv rounds up so partial tiles at the bottom/right edges are covered.
    grid = lambda meta: (
        triton.cdiv(n_rows, meta["BLOCK_SIZE_X"]),
        triton.cdiv(n_cols, meta["BLOCK_SIZE_Y"]),
    )

    add_two_matrices[grid](x, y, z, n_rows, n_cols, 32, 32)

    # Quick visual check: with all-ones inputs, every column sums to 2*100 and every row to 2*130.
    print(torch.sum(z, axis=0).unique())
    print(torch.sum(z, axis=1).unique())
    # Sums can agree even when individual elements are wrong (e.g. a 3 and a 1 instead of two 2s),
    # so also compare element-wise against the reference.
    assert torch.equal(z, x + y), "add_two_matrices output differs from x + y"

In [4]:
# Run the matrix-addition demo (allocates its tensors on the CUDA device).
launch_matrix_add()

tensor([200.], device='cuda:0')
tensor([260.], device='cuda:0')


In [5]:
## 2D Copy Kernel using tile indexing only

@triton.jit
def matrix_copy_kernel(ip_ptr, op_ptr, n_rows, n_cols, BLOCK_SIZE_X:tl.constexpr, BLOCK_SIZE_Y:tl.constexpr):
    """Copy one BLOCK_SIZE_X x BLOCK_SIZE_Y tile of a row-major (n_rows, n_cols) matrix
    from ip_ptr to op_ptr.

    Grid axis 0 indexes row tiles and grid axis 1 indexes column tiles.
    """
    tile_rows = tl.arange(0, BLOCK_SIZE_X) + tl.program_id(0) * BLOCK_SIZE_X
    tile_cols = tl.arange(0, BLOCK_SIZE_Y) + tl.program_id(1) * BLOCK_SIZE_Y

    # Flat row-major index of each element in the tile; neighbouring columns are adjacent in memory.
    idx = tile_rows[:, None] * n_cols + tile_cols[None, :]
    # Lanes past the last row or column are skipped on both the read and the write.
    valid = (tile_rows[:, None] < n_rows) & (tile_cols[None, :] < n_cols)

    tl.store(op_ptr + idx, tl.load(ip_ptr + idx, mask=valid), mask=valid)
    

In [6]:
def launch_copy_kernel():
    """Copy a random 157x123 CUDA matrix into a zero buffer with `matrix_copy_kernel`.

    Prints whether source and destination are equal before the copy (expected False)
    and after it (expected True).
    """
    src = torch.randn(157, 123, device='cuda', dtype=torch.float32)
    dst = torch.zeros_like(src)

    print(torch.equal(src, dst))

    rows, cols = src.shape

    def grid(meta):
        # Round up so the partial tiles at the bottom/right edges also get a program.
        return (triton.cdiv(rows, meta["BLOCK_SIZE_X"]),
                triton.cdiv(cols, meta["BLOCK_SIZE_Y"]))

    matrix_copy_kernel[grid](src, dst, rows, cols, 32, 32)

    print(torch.equal(src, dst))


In [7]:
# Run the 2-D copy demo (allocates its tensors on the CUDA device).
launch_copy_kernel()

False
True


In [8]:
## Row Scaling Kernel (Mixed 2D + 1D indexing) -> a[row,col] *= scale[row]
## Keep in mind: a is (M, N), scale is (M,), and the output is (M, N); i.e. this is elementwise
## scaling, not a dot product: every row is multiplied by its own entry of scale.


@triton.jit
def row_scale_kernel(mat2d_ptr, scale_ptr, op_ptr, n_rows, n_cols,
                    BLOCK_SIZE_ROW:tl.constexpr,
                    BLOCK_SIZE_COL:tl.constexpr):
    """op[r, c] = mat2d[r, c] * scale[r] for one tile of a row-major (n_rows, n_cols) matrix.

    scale supplies one factor per row. Grid axis 0 indexes row tiles and grid axis 1
    indexes column tiles.
    """
    rows = tl.program_id(0) * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    cols = tl.program_id(1) * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    row_ok = rows < n_rows
    tile_ok = row_ok[:, None] & (cols[None, :] < n_cols)
    tile_idx = rows[:, None] * n_cols + cols[None, :]

    tile = tl.load(mat2d_ptr + tile_idx, mask=tile_ok)
    # The scale vector is indexed by the same row offsets as the tile,
    # so the 1-D row mask bounds it as well.
    factors = tl.load(scale_ptr + rows, mask=row_ok)

    # Broadcast each row's factor across all columns of that row.
    tl.store(op_ptr + tile_idx, tile * factors[:, None], mask=tile_ok)
    

In [9]:
def launch_row_scale():
    """Scale each row of a random 151x137 CUDA matrix by its own random factor with
    `row_scale_kernel`, then print whether the result equals the PyTorch reference."""
    x = torch.randn(151, 137, device="cuda", dtype=torch.float32)
    scale = torch.randn(151, device="cuda", dtype=torch.float32)
    y = torch.empty_like(x)

    n_rows, n_cols = x.shape

    def grid(meta):
        return (triton.cdiv(n_rows, meta["BLOCK_SIZE_ROW"]),
                triton.cdiv(n_cols, meta["BLOCK_SIZE_COL"]))

    row_scale_kernel[grid](x, scale, y, n_rows, n_cols, BLOCK_SIZE_ROW=32, BLOCK_SIZE_COL=32)

    # Reference: broadcast scale down the column axis so each row gets its own factor.
    expected = x * scale.unsqueeze(1)
    # Exact comparison; torch.allclose is the tolerant alternative.
    print(torch.equal(expected, y))

In [10]:
# Run the row-scaling demo (allocates its tensors on the CUDA device).
launch_row_scale()

True


In [11]:
## Column scaling kernel -> a[row,col] *= scale[col]; multiply each column by its own entry of scale

@triton.jit
def column_scale_kernel(mat2d_ptr, scale_ptr, op_ptr, n_rows, n_cols, 
                       BLOCK_SIZE_ROW:tl.constexpr,
                       BLOCK_SIZE_COL:tl.constexpr):
    """op[r, c] = mat2d[r, c] * scale[c] for one tile of a row-major (n_rows, n_cols) matrix.

    scale supplies one factor per column. Grid axis 0 indexes row tiles and grid
    axis 1 indexes column tiles.
    """
    rows = tl.program_id(0) * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    cols = tl.program_id(1) * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    col_ok = cols < n_cols
    tile_ok = (rows[:, None] < n_rows) & col_ok[None, :]
    tile_idx = rows[:, None] * n_cols + cols[None, :]

    tile = tl.load(mat2d_ptr + tile_idx, mask=tile_ok)
    # The scale vector shares the tile's column offsets, so the 1-D column mask bounds it.
    factors = tl.load(scale_ptr + cols, mask=col_ok)

    # Broadcast each column's factor down every row of the tile.
    tl.store(op_ptr + tile_idx, tile * factors[None, :], mask=tile_ok)


In [12]:
def launch_col_scale():
    """Scale each column of a random 159x191 CUDA matrix by its own random factor with
    `column_scale_kernel`, then print whether the result equals the PyTorch reference."""
    x = torch.randn(159, 191, device="cuda", dtype=torch.float32)
    scalar = torch.randn(191, device="cuda", dtype=torch.float32)
    y = torch.empty_like(x)

    n_rows, n_cols = x.shape

    def grid(meta):
        return (triton.cdiv(n_rows, meta["BLOCK_SIZE_ROW"]),
                triton.cdiv(n_cols, meta["BLOCK_SIZE_COL"]))

    column_scale_kernel[grid](x, scalar, y, n_rows, n_cols,
                              BLOCK_SIZE_ROW=32, BLOCK_SIZE_COL=32)

    # Reference: broadcast scalar across the row axis so each column gets its own factor.
    expected = x * scalar.unsqueeze(0)
    print(torch.equal(expected, y))

In [13]:
# Run the column-scaling demo (allocates its tensors on the CUDA device).
launch_col_scale()

True


In [14]:
## Transpose of a matrix

@triton.jit
def transpose_kernel(ip_ptr, op_ptr, n_rows, n_cols, BLOCK_SIZE_ROW:tl.constexpr,
                    BLOCK_SIZE_COL:tl.constexpr):
    """Write the transpose of a row-major (n_rows, n_cols) input into a row-major
    (n_cols, n_rows) output, one BLOCK_SIZE_ROW x BLOCK_SIZE_COL input tile per program.

    Grid axis 0 indexes input row tiles and grid axis 1 indexes input column tiles.
    """
    r = tl.program_id(0) * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    c = tl.program_id(1) * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # Read the tile as (row, col) from the input, whose row stride is n_cols.
    src_ok = (r[:, None] < n_rows) & (c[None, :] < n_cols)
    tile = tl.load(ip_ptr + r[:, None] * n_cols + c[None, :], mask=src_ok)

    # Input element (r, c) lands at output (c, r), and the output's row stride is n_rows.
    # The tile is transposed to (col, row) so its layout matches these offsets.
    dst_ok = (c[:, None] < n_cols) & (r[None, :] < n_rows)
    tl.store(op_ptr + c[:, None] * n_rows + r[None, :], tl.trans(tile), mask=dst_ok)


In [15]:
def launch_transpose_kernel():
    """Transpose a random 157x231 CUDA matrix with `transpose_kernel` and print whether
    the result equals `X.T`."""
    X = torch.randn(157, 231, device="cuda", dtype=torch.float32)
    n_rows, n_cols = X.shape
    # The output has rows and columns swapped: (n_cols, n_rows).
    X_transpose = torch.empty((n_cols, n_rows), device="cuda", dtype=torch.float32)

    def grid(meta):
        # The grid tiles the *input*; each program writes its tile out transposed.
        return (triton.cdiv(n_rows, meta["BLOCK_SIZE_ROW"]),
                triton.cdiv(n_cols, meta["BLOCK_SIZE_COL"]))

    transpose_kernel[grid](X, X_transpose, n_rows, n_cols, 32, 32)
    print(torch.equal(X.T, X_transpose))
    
    

In [16]:
# Run the transpose demo (allocates its tensors on the CUDA device).
launch_transpose_kernel()

True


In [17]:
## Masked Tile Write - only copy elements that satisfy the condition (row_index + col_index) % 2 == 0
# Note: A and B both already hold values. Copy values from A into B only where the above condition
# is satisfied, and leave every other element of B unchanged.


@triton.jit
def masked_tile_write_kernel(ip1_ptr, ip2_ptr, n_rows, n_cols, BLOCK_SIZE_ROW:tl.constexpr,
                            BLOCK_SIZE_COL:tl.constexpr):
    """Copy ip1[r, c] into ip2[r, c] only where (r + c) is even (a checkerboard pattern).

    Every other element of ip2 is left untouched. Both buffers are row-major
    (n_rows, n_cols). Grid axis 0 indexes row tiles and grid axis 1 indexes column tiles.
    """
    # Global indices, already lifted to 2-D: rows is (BLOCK_ROW, 1), cols is (1, BLOCK_COL).
    rows = (tl.program_id(0) * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW))[:, None]
    cols = (tl.program_id(1) * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL))[None, :]
    idx = rows * n_cols + cols

    in_bounds = (rows < n_rows) & (cols < n_cols)
    # (BLOCK_ROW, 1) + (1, BLOCK_COL) broadcasts to the full tile of global index sums.
    even_cells = (rows + cols) % 2 == 0

    src = tl.load(ip1_ptr + idx, mask=in_bounds)
    # Only even cells that are inside the matrix are written.
    tl.store(ip2_ptr + idx, src, mask=in_bounds & even_cells)

    

In [18]:
def test_masked_tile_correctness(a,b, b_orij):
    for i in range(a.shape[0]):
        for j in range(b.shape[1]):
            if (i+j)%2 == 0:
                assert a[i,j] == b[i,j]
            else:
                assert b_orij[i,j] == b[i,j]
                
    return "Assertion passed" #It will break if tthe values are not similar

def launch_masked_tile_kernel():
    """Run `masked_tile_write_kernel` from an all-zeros 535x729 CUDA matrix into an
    all-ones one, then print the verdict of the checkerboard correctness check."""
    a = torch.zeros(535, 729, device="cuda", dtype=torch.float32)
    b = torch.ones(535, 729, device="cuda", dtype=torch.float32)
    # Snapshot b before the kernel overwrites half of it, so untouched cells can be checked.
    b_copy = b.clone().detach()

    n_rows, n_cols = a.shape

    def grid(meta):
        return (triton.cdiv(n_rows, meta["BLOCK_SIZE_ROW"]),
                triton.cdiv(n_cols, meta["BLOCK_SIZE_COL"]))

    masked_tile_write_kernel[grid](a, b, n_rows, n_cols, 32, 32)
    verdict = test_masked_tile_correctness(a, b, b_copy)
    print(verdict)

In [19]:
# Run the checkerboard masked-write demo (allocates its tensors on the CUDA device).
launch_masked_tile_kernel()


Assertion passed


In [20]:
## Sub-matrix extraction: extract a smaller window from a bigger matrix
# Note that the mental model here is that the input and the output have different sizes, so the
# grid is based on the output.

@triton.jit
def submatrix_extraction_kernel(A_ptr, B_ptr, nrows_A, ncols_A, r0, c0, op_rows,
                               op_cols, BLOCK_SIZE_ROW:tl.constexpr, BLOCK_SIZE_COL:tl.constexpr):
    """Copy the (op_rows, op_cols) window of A whose top-left corner is (r0, c0) into B.

    A is row-major (nrows_A, ncols_A) and B is row-major (op_rows, op_cols). The grid is
    sized over B, so each program fills one BLOCK_SIZE_ROW x BLOCK_SIZE_COL tile of B.
    If the window overhangs A (or the origin is negative), the affected B cells are
    written as 0 instead of reading outside A.
    """
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    # Output (B) coordinates covered by this tile.
    op_row_offsets = pid_row * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    op_col_offsets = pid_col * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)
    op_offsets = op_row_offsets[:, None] * op_cols + op_col_offsets[None, :]
    op_mask = (op_row_offsets[:, None] < op_rows) & (op_col_offsets[None, :] < op_cols)

    # Matching input (A) coordinates: shift by the window origin. A's row stride is ncols_A.
    ip_row_offsets = r0 + op_row_offsets
    ip_col_offsets = c0 + op_col_offsets
    ip_offsets = ip_row_offsets[:, None] * ncols_A + ip_col_offsets[None, :]

    # Two masks are needed: op_mask keeps the tile inside B, and ip_mask keeps the read inside A.
    # Without ip_mask, a window extending past A's edge would read out of bounds.
    ip_mask = ((ip_row_offsets[:, None] >= 0) & (ip_row_offsets[:, None] < nrows_A)
               & (ip_col_offsets[None, :] >= 0) & (ip_col_offsets[None, :] < ncols_A))

    A = tl.load(A_ptr + ip_offsets, mask=op_mask & ip_mask, other=0.0)
    tl.store(B_ptr + op_offsets, A, mask=op_mask)
    

## Note (checked with GPT): using op_mask for tl.load is fine as long as the requested window lies
# inside A. It is still worth being defensive and bounding the input indices as well:
#ip_mask = (ip_row_offsets[:,None]<nrows_A) & (ip_col_offsets[None,:]<ncols_A)
#final_mask = ip_mask & op_mask 


In [21]:
def launch_submatrix_kernel():
    """Extract the 21x31 window starting at (151, 323) of a random 235x569 CUDA matrix
    with `submatrix_extraction_kernel`, then print whether it equals the PyTorch slice."""
    A = torch.randn(235, 569, device="cuda", dtype=torch.float32)
    nrows_A, ncols_A = A.shape
    op_rows, op_cols = 21, 31
    r0, c0 = 151, 323

    B = torch.zeros(op_rows, op_cols, device="cuda", dtype=torch.float32)

    def grid(meta):
        # Size the grid over the output window, not over A.
        return (triton.cdiv(op_rows, meta["BLOCK_SIZE_ROW"]),
                triton.cdiv(op_cols, meta["BLOCK_SIZE_COL"]))

    submatrix_extraction_kernel[grid](A, B, nrows_A, ncols_A, r0, c0, op_rows, op_cols, 32, 32)
    reference = A[r0:r0 + op_rows, c0:c0 + op_cols]
    print(torch.equal(reference, B))

In [22]:
# Run the sub-matrix extraction demo (allocates its tensors on the CUDA device).
launch_submatrix_kernel()

True
