# Matmul forward and backward

Source: [Triton matmul](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py)


We can make a few improvements to our matmul implementation from last time:

- Group rows for better performance
- Add backward pass
- Fuse in activation function

## Grouping rows

We can group blocks of rows together to ensure that a given block of columns is only loaded once for several blocks of rows (instead of loading a block of columns once per block of rows).

Here's a diagram:

![ordering](images/grouped_vs_row_major_ordering.png)

This can improve performance, since memory access is a bottleneck.

## Fusing activation function

We can fuse in an activation function by calling it from our matmul kernel.  We first define the function:

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.testing import assert_close

import triton
import triton.language as tl

@triton.jit
def relu(x, shape):
    """Element-wise ReLU: keep entries of `x` that are > 0, replace the rest with 0.0.

    `shape` is accepted for the caller's convenience but is not used here.
    """
    is_positive = x > 0
    return tl.where(is_positive, x, 0.0)

Then, we can define our launch function for the matmul kernel.  We move to a 1-d grid so we have more control over the order of the x/y pids.

In [2]:
def matmul_fwd(X, Y, activation="none"):
    """Launch the Triton matmul kernel to compute `X @ Y` with an optional fused activation.

    Args:
        X: 2-D tensor of shape (x_rows, x_cols).
        Y: 2-D tensor of shape (y_rows, y_cols); `y_rows` must equal `x_cols`.
        activation: Name of the activation fused into the kernel.  The kernel
            applies ReLU for "relu"; any other value leaves the product as-is.

    Returns:
        A float16 tensor of shape (x_rows, y_cols) on the same device as `X`.

    Raises:
        ValueError: If either input is not 2-D or the inner dimensions differ.
    """
    if X.dim() != 2 or Y.dim() != 2:
        raise ValueError(
            f"matmul_fwd expects 2-D inputs, got {X.dim()}-D and {Y.dim()}-D"
        )

    x_rows, x_cols = X.shape
    y_rows, y_cols = Y.shape
    if x_cols != y_rows:
        raise ValueError(
            f"Inner dimensions must match, got {tuple(X.shape)} and {tuple(Y.shape)}"
        )

    # The kernel computes addresses as row * num_cols + col, i.e. it assumes a
    # dense row-major layout.  Views such as `A.T` are not laid out that way,
    # so materialize them first (this is a no-op for contiguous tensors).
    X = X.contiguous()
    Y = Y.contiguous()

    # Every in-bounds element is written by the kernel, so no zero fill is needed.
    # Allocate next to the input instead of assuming a specific device.
    output = torch.empty((x_rows, y_cols), device=X.device, dtype=torch.float16)  # Output matrix
    if x_rows == 0 or y_cols == 0:
        # Nothing to compute; avoid launching an empty grid.
        return output

    BLOCK_SIZE_X = 128
    BLOCK_SIZE_Y = 128
    BLOCK_SIZE_K = 32
    GROUP_SIZE_X = 8

    # Create a 1-D grid to iterate across blocks
    # We do 1-D so we can group rows more efficiently
    grid = lambda meta: (
        triton.cdiv(x_rows, meta["BLOCK_SIZE_X"]) * triton.cdiv(y_cols, meta["BLOCK_SIZE_Y"]),
    )

    matmul_fwd_kernel[grid](
        X,
        Y,
        output,
        x_rows,
        x_cols,
        y_rows,
        y_cols,
        BLOCK_SIZE_X=BLOCK_SIZE_X,
        BLOCK_SIZE_Y=BLOCK_SIZE_Y,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_X=GROUP_SIZE_X,
        ACTIVATION=activation,
    )

    return output

In [3]:
@triton.jit
def matmul_fwd_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    x_rows,
    x_cols,
    y_rows,
    y_cols,
    BLOCK_SIZE_X: tl.constexpr, # Row count per block
    BLOCK_SIZE_Y: tl.constexpr, # column count per block
    BLOCK_SIZE_K: tl.constexpr, # Inner dim to iterate over (count per iteration)
    GROUP_SIZE_X: tl.constexpr, # Size of each X group
    ACTIVATION: tl.constexpr, # Which activation function to use
):
    """Compute one (BLOCK_SIZE_X, BLOCK_SIZE_Y) tile of `output = x @ y`.

    Addresses are computed as `row * num_cols + col`, so x, y and output are
    expected to be dense row-major buffers of shapes (x_rows, x_cols),
    (y_rows, y_cols) and (x_rows, y_cols).  The result is cast to float16 and,
    when ACTIVATION == "relu", passed through ReLU before being stored.

    Programs are launched on a 1-D grid and mapped to tiles in groups of
    GROUP_SIZE_X row blocks, so that consecutive programs share a column block.
    """
    # The total pid count is number of x blocks times number of y blocks (total blocks to process)
    # We get the pid of the current program
    pid = tl.program_id(axis=0)

    # Number of blocks across x
    num_pid_x = tl.cdiv(x_rows, BLOCK_SIZE_X)

    # Number of blocks across y
    num_pid_y = tl.cdiv(y_cols, BLOCK_SIZE_Y)

    # Total number of programs in each group
    # Number of row groups times pid per column (we run once per row group/column block pair)
    num_pid_in_group = GROUP_SIZE_X * num_pid_y

    # Id of the group this program is in
    # Divide current pid by number of pid in group
    # Floor division gives the group id
    group_id = pid // num_pid_in_group

    # Row-id of the first program in the group
    first_pid_x = group_id * GROUP_SIZE_X

    # This is the size of the current group
    # If `num_pid_x` isn't evenly divisible by `GROUP_SIZE_X`, the last group is smaller than the others
    group_size_x = min(num_pid_x - first_pid_x, GROUP_SIZE_X)

    # Find the pid within the current group
    in_group_pid = pid % num_pid_in_group

    # This is the x pid of the block to be retrieved - we're going down the rows block by block before incrementing the y column
    # So we multiply several blocks of rows by the same column block
    # Using the in-group pid keeps the walk starting at the group's first row block,
    # even in a smaller trailing group
    x_pid = first_pid_x + (in_group_pid % group_size_x)

    # We only increment the y block when we've gone "down" the rows in x
    # Column id of the program in the launch grid
    y_pid = in_group_pid // group_size_x

    # Row indices (into x / output) and column indices (into y / output) handled by this program
    row_ids = x_pid * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    col_ids = y_pid * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)

    # Tail blocks can extend past the matrix edge; these masks keep both the
    # loads and the final store inside the buffers
    row_mask = row_ids < x_rows
    col_mask = col_ids < y_cols

    # Get the k offsets, for the matrix inner dimension
    k_offsets = tl.arange(0, BLOCK_SIZE_K)

    # Define our x pointers, which will be from column 0 to k within each row
    # Row indices are multiplied by x_cols to get each row's start position
    x_ptrs = x_ptr + (row_ids[:, None] * x_cols + k_offsets[None, :])

    # Define our y pointers, which will be from 0 to k within each column
    # We multiply the k offsets by y_cols to get the row start positions
    y_ptrs = y_ptr + (k_offsets[:, None] * y_cols + col_ids[None, :])

    # The accumulator stores the results as we iterate across k
    # Store in float32 for better numerical precision
    accumulator = tl.zeros((BLOCK_SIZE_X, BLOCK_SIZE_Y), dtype=tl.float32)
    # Iterate across k, increment by BLOCK_SIZE_K
    for k in range(0, tl.cdiv(x_cols, BLOCK_SIZE_K)):
        # Load the x subset to multiply
        # We mask to avoid loading anything beyond the end of each row, or any row past x_rows
        # [None, :] / [:, None] add 1-length dimensions so the masks broadcast across x_ptrs
        a_mask = row_mask[:, None] & (k_offsets[None, :] < x_cols - k * BLOCK_SIZE_K)
        a = tl.load(x_ptrs, mask=a_mask, other=0.0)

        # Same for y: stay within the remaining rows and within y_cols
        b_mask = (k_offsets[:, None] < y_rows - k * BLOCK_SIZE_K) & col_mask[None, :]
        b = tl.load(y_ptrs, mask=b_mask, other=0.0)

        # Multiply a and b, then add to the accumulator
        accumulator += tl.dot(a, b)

        # Increment the x pointers to go across the rows
        x_ptrs += BLOCK_SIZE_K

        # Increment the y pointers to go down the columns - we need to multiply by y_cols because we're moving down the columns (across the rows)
        y_ptrs += BLOCK_SIZE_K * y_cols

    output = accumulator.to(tl.float16)

    # Add in the activation function
    # Depending on the function, you can do this before/after the fp16 cast
    if ACTIVATION == "relu":
        output = relu(output, (BLOCK_SIZE_X, BLOCK_SIZE_Y))

    # Store the data, ensuring we don't overflow the rows/columns
    output_ptrs = output_ptr + (row_ids[:, None] * y_cols + col_ids[None, :])
    output_mask = row_mask[:, None] & col_mask[None, :]
    tl.store(output_ptrs, output, mask=output_mask)

In [4]:
# Fixed seed so the comparison below is reproducible
torch.manual_seed(0)

# Inputs are shifted by -0.5 so the product has both signs and ReLU has something to clip
x = torch.rand(8192, 8192, dtype=torch.float16, device="cuda") - 0.5
y = torch.rand(8192, 8192, dtype=torch.float16, device="cuda") - 0.5

# Reference result from torch, then the fused Triton implementation
act = nn.ReLU()
output_torch = act(torch.matmul(x, y))
output_triton = matmul_fwd(x, y, activation="relu")

print(
    f'The maximum difference between torch and triton is '
    f'{(output_torch - output_triton).abs().max()}'
)

# Show the top-left corner of each result for a visual comparison
print(output_torch[:5, :5])
print(output_triton[:5, :5])

The maximum difference between torch and triton is 0.0
tensor([[ 0.0000,  0.0000,  0.0000,  5.2227,  0.0000],
        [ 2.6484,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 5.5938,  0.0000,  0.0000,  2.3438,  9.6562],
        [ 0.0000,  3.9062,  0.0000,  0.0000,  0.0000],
        [14.8203,  8.7969,  0.0000,  0.2194,  0.0000]], device='cuda:0',
       dtype=torch.float16)
tensor([[ 0.0000,  0.0000,  0.0000,  5.2227,  0.0000],
        [ 2.6484,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 5.5938,  0.0000,  0.0000,  2.3438,  9.6562],
        [ 0.0000,  3.9062,  0.0000,  0.0000,  0.0000],
        [14.8203,  8.7969,  0.0000,  0.2194,  0.0000]], device='cuda:0',
       dtype=torch.float16)


## Adding the backward pass

For `Z = X @ Y`, let `dZ` be the gradient of the loss wrt `Z`.  Then the gradient wrt `X` is `dZ @ Y.T`, and the gradient wrt `Y` is `X.T @ dZ`.  We can call the matmul kernel twice to calculate these. We could do additional optimization by only loading `dZ` once per group for both computations.

First, we'll do the same operation in torch to benchmark:

In [5]:
# Fresh inputs that track gradients, for the torch reference run
x = torch.rand(8192, 8192, dtype=torch.float16, device="cuda") - 0.5
y = torch.rand(8192, 8192, dtype=torch.float16, device="cuda") - 0.5

x.requires_grad_(True)
y.requires_grad_(True)

# matmul -> ReLU -> softmax over the last dimension
output_torch = F.softmax(act(torch.matmul(x, y)), dim=-1)

In [6]:
# Build a target with a single 1 per row, at a randomly chosen column
target = torch.zeros_like(x)
inds = (
    torch.arange(target.size(0)),
    torch.randint(0, x.size(1), (target.size(0),)),
)
target[inds] = 1

# Backpropagate through the torch reference graph to populate x.grad / y.grad
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(output_torch, target)
loss.backward()

In [7]:
# Snapshot the reference gradients so they survive the later in-place zeroing
torch_xgrad, torch_ygrad = (grad.detach().clone() for grad in (x.grad, y.grad))

Then, we can define an autograd function that wraps our matmul kernel for both the forward and backward pass:

In [8]:
class MatMul(torch.autograd.Function):
    """Autograd wrapper around `matmul_fwd` with a fused ReLU.

    Forward computes `z = relu(x @ y)`.  Backward masks the incoming gradient
    by the ReLU derivative and then reuses `matmul_fwd` for the two products
    `dz @ y.T` and `x.T @ dz`.
    """

    @staticmethod
    def forward(ctx, x, y):
        """Return `relu(x @ y)` and stash what the backward pass needs."""
        activation = "relu"
        z = matmul_fwd(x, y, activation)

        # Cache x and y for the gradient matmuls, and z for the relu mask
        ctx.save_for_backward(x, y, z)
        ctx.activation = activation
        return z

    @staticmethod
    def backward(ctx, dz):
        """Return the gradients `(dx, dy)` given `dz`, the gradient wrt the output."""
        x, y, z = ctx.saved_tensors # grab x, y and z from the cache

        # Apply relu backwards: the gradient only flows where the output was positive
        # Would be more efficient to fuse this in
        if ctx.activation == "relu":
            dz = torch.where(z > 0, dz, 0.0)

        # The matmul kernel indexes its inputs as dense row-major buffers.
        # `.T` only returns a strided view, so the transposes (and dz, which
        # autograd does not guarantee to be contiguous) are materialized first;
        # otherwise the kernel would read the untransposed data.
        dz = dz.contiguous()
        y_t = y.T.contiguous()
        x_t = x.T.contiguous()

        # Would be more efficient to do a single matmul call here
        dx = matmul_fwd(dz, y_t, activation="none")
        dy = matmul_fwd(x_t, dz, activation="none")

        return dx, dy

matmul = MatMul.apply

In [9]:
# Clear the gradients left over from the torch reference run
x.grad.data.zero_()
y.grad.data.zero_()

# Same pipeline as the reference, but with the Triton matmul + fused relu
output_triton = F.softmax(matmul(x, y), dim=-1)

triton_loss = loss_fn(output_triton, target)
triton_loss.backward()

# Compare against the saved torch gradients using an absolute tolerance only
# NOTE(review): atol=1e-2 may be loose relative to the gradient magnitudes here - confirm
assert_close(x.grad, torch_xgrad, rtol=0, atol=1e-2)
assert_close(y.grad, torch_ygrad, rtol=0, atol=1e-2)