# Softmax forward and backward pass

This is an (inefficient) softmax implementation designed to help you understand the forward and backward passes in Triton. It works by fitting an entire row into a single block in both the forward and backward kernels.

It's loosely based on the [layernorm](https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html) and [softmax](https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html) Triton tutorials.

First, we define the forward kernel. Recall that softmax is $\sigma(z_i) = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}}$, where $K$ is the number of elements in the row.

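1. Load a row of data, using a mask to ignore the lanes that fall past the end of the row.
2. Subtract the row maximum from every element for numerical stability. This keeps $e^{z_i}$ from overflowing and doesn't change the result, because the common factor $e^{-\max_j z_j}$ cancels between numerator and denominator.
3. Apply the softmax equation: raise e to the power of each element, then divide by the row sum.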
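A plain-Python reference for these steps is handy for checking results by hand. This is a sketch of our own (`softmax_naive` and `softmax_ref` are illustrative names, not Triton or PyTorch APIs). It shows that subtracting the maximum leaves the probabilities unchanged, while the direct formula overflows once the inputs get large:

```python
import math

def softmax_naive(z):
    # Direct translation of the formula; math.exp raises OverflowError above ~709
    exps = [math.exp(v) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_ref(z):
    # Subtract the row max first, so the largest exponent is exp(0) = 1
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

z = [1.0, 2.0, 3.0]
# Shifting every input by the same constant does not change the output
print(all(math.isclose(a, b) for a, b in zip(softmax_naive(z), softmax_ref(z))))  # True

big = [1000.0, 1001.0, 1002.0]
try:
    softmax_naive(big)
except OverflowError:
    print("naive softmax overflowed")
print(softmax_ref(big))  # same probabilities as softmax_ref(z)
```

Python's `math.exp` raises `OverflowError` on overflow; in float32 on a GPU, `exp` instead returns `inf` for inputs above roughly 88.7, and the division then produces NaN.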

```python
import torch
from torch.testing import assert_close
import torch.nn.functional as F

import triton
import triton.language as tl


@triton.jit
def softmax_fwd_kernel(
    x_ptr,
    output_ptr,
    x_rows,
    x_cols,
    BLOCK_SIZE:tl.constexpr
):
    pid = tl.program_id(axis=0) # Get the current row id
    
    # Assumes x is contiguous and row-major, so row `pid` starts at pid * x_cols
    row_start = pid * x_cols
    row_offset = row_start + tl.arange(0, BLOCK_SIZE)
    
    # Lanes past the end of the row (padding up to BLOCK_SIZE) are masked off
    mask = row_offset < row_start + x_cols
    
    # Masked lanes load -inf, so they never win the max and exp(-inf) = 0
    row = tl.load(x_ptr + row_offset, mask=mask, other=-float('inf'))
    
    row = row - tl.max(row, axis=0) # Subtract max for stability
    num = tl.exp(row) # Raise e to the power of each element
    denom = tl.sum(num, axis=0) # Sum the elements in the row
    output = num / denom # Softmax output
    
    tl.store(output_ptr + row_offset, output, mask=mask)
```
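To see what the mask and the `-inf` fill do, here is a plain-Python emulation of a single program instance of `softmax_fwd_kernel`. It's a sketch under simplifying assumptions: Python lists stand in for Triton blocks and pointers, and `run_program` is a hypothetical helper, not a Triton API.

```python
import math

def run_program(pid, x_flat, out_flat, x_cols, block_size):
    # Emulates one program instance: it owns row `pid` of a row-major matrix
    row_start = pid * x_cols
    offsets = [row_start + i for i in range(block_size)]   # tl.arange(0, BLOCK_SIZE)
    mask = [off < row_start + x_cols for off in offsets]   # lanes past the row are off
    # Masked lanes become -inf (and are never read), so exp() maps them to 0
    row = [x_flat[off] if m else -math.inf for off, m in zip(offsets, mask)]
    row_max = max(row)
    num = [math.exp(v - row_max) for v in row]
    denom = sum(num)
    for off, m, n in zip(offsets, mask, num):
        if m:  # masked store: padding lanes are never written
            out_flat[off] = n / denom

x_rows, x_cols, block_size = 2, 3, 4  # 3 columns padded up to a block of 4
x_flat = [1.0, 2.0, 3.0, 0.0, 0.0, 0.0]
out_flat = [None] * len(x_flat)
for pid in range(x_rows):  # stands in for the launch grid: one program per row
    run_program(pid, x_flat, out_flat, x_cols, block_size)
print(out_flat)
```

Note that the last program's fourth lane points one element past the end of `x_flat`; the mask is what keeps it from being read or written, just as in the real kernel.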

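Next, we define the Python function that launches the forward kernel. Each program instance handles one row, so its block must cover every element in that row. Triton requires block sizes (the range passed to `tl.arange`) to be powers of 2, so we round `x_cols` up with `triton.next_power_of_2` and rely on the mask to ignore the extra lanes.

We launch a 1-D grid with one program per row.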
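The rounding can be sketched in plain Python. `next_pow2` below is a hypothetical stand-in for what `triton.next_power_of_2` is used for here, written with `int.bit_length`; it is not Triton's source. The last value printed is the number of masked lanes each program carries. Since a whole row must fit in one block, very wide rows also mean very large blocks, which can become a limit for this one-block-per-row design.

```python
def next_pow2(n):
    # Smallest power of 2 that is >= n, for n >= 1
    return 1 << (n - 1).bit_length()

for x_cols in [6, 8, 100, 1000]:
    block = next_pow2(x_cols)
    # e.g. x_cols=6 -> BLOCK_SIZE=8 with 2 masked lanes per row
    print(f"x_cols={x_cols}  BLOCK_SIZE={block}  masked lanes={block - x_cols}")
```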

```python
def softmax_fwd(x):
    x = x.contiguous() # the kernel assumes contiguous, row-major rows
    output = torch.empty_like(x)
    x_rows, x_cols = x.shape
    
    # tl.arange needs a power-of-2 range, so round the row length up
    BLOCK_SIZE = triton.next_power_of_2(x_cols)
    # 1-d launch grid: one program per row
    grid = lambda meta: (x_rows,)
    
    softmax_fwd_kernel[grid](x, output, x_rows, x_cols, BLOCK_SIZE=BLOCK_SIZE)
    
    return output
```

We can test our softmax kernel against the torch implementation using `assert_close`.

```python
torch.manual_seed(0)
x = torch.rand((10, 6), device='cuda')

output_triton = softmax_fwd(x)
```

```python
x.requires_grad = True
output_torch = F.softmax(x, dim=-1)

assert_close(output_triton, output_torch)
```

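We set `requires_grad` on `x` above so that PyTorch records the operations applied to it and we can run a backward pass. We then define a fake one-hot target, `y`, and use [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) to compute the gradient with respect to `x`. Note that `CrossEntropyLoss` expects unnormalized logits and applies log-softmax internally, so feeding it softmax probabilities, as we do here, isn't something you would do when training a model; it simply serves as a differentiable loss that sends a gradient back through our softmax.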

```python
# Set up a fake one-hot target; class indices are drawn from 0-2
y = torch.zeros_like(x)
inds = (torch.arange(0, y.shape[0]), torch.randint(0, 3, (y.shape[0],)))
y[inds] = 1

# Define loss and run backward pass
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(output_torch, y)
loss.backward()

# Save gradient tensor for later
torch_xgrad = x.grad.detach().clone()
```
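It's worth spelling out which gradient the softmax backward pass receives here. For probability targets such as one-hot rows and the default `reduction='mean'`, `CrossEntropyLoss` computes $-\frac{1}{N}\sum_n \sum_c y_{n,c} \log \sigma(p_n)_c$ over a batch of $N$ rows, so the gradient it sends back into its input $p$ (here `output_torch`) is $(\sigma(p_n) - y_n)/N$. A plain-Python sketch of that formula, with illustrative function names of our own:

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def cross_entropy_grad(p_rows, y_rows):
    # Gradient of the batch-mean cross-entropy w.r.t. its inputs: (softmax(p) - y) / N,
    # valid when each target row sums to 1 (e.g. one-hot)
    n = len(p_rows)
    return [[(s - t) / n for s, t in zip(softmax(p), y)] for p, y in zip(p_rows, y_rows)]

# p plays the role of output_torch: rows that are already softmax probabilities
p = [[0.2, 0.3, 0.5], [0.1, 0.1, 0.8]]
y = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
dy = cross_entropy_grad(p, y)
for row in dy:
    print(row)  # negative only in the target column; each row sums to ~0
```

This `dy` is what the backward kernel will read through `dy_ptr` when the Triton softmax is used in place of `F.softmax`; the kernel then pushes it back through the softmax to get the gradient with respect to `x`.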

Now we can write the softmax backward kernel and compare it against PyTorch. This implementation is deliberately inefficient and meant for learning purposes only! We break the softmax down into a computational graph and then backpropagate through that graph, step by step.

This is the forward computational graph for softmax:

![comp graph](images/comp_graph.png)

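1. Rerun the forward pass to regenerate the intermediate values the backward pass needs (the kernel receives only `x` and the incoming gradient, not the softmax output).
2. Compute the gradient one node at a time, walking the graph in reverse (see the comments).
3. Store the gradient.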
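Before reading the kernel, it may help to trace the same three steps in plain Python for a single row. This sketch is our own illustration (`softmax_bwd_row` is not part of Triton); it reuses the kernel's variable names and checks the result against the standard closed form $dx_i = s_i\,(dy_i - \sum_j dy_j\, s_j)$, where $s$ is the softmax output.

```python
import math

def softmax_bwd_row(x, dy):
    # 1. Recompute the forward pass, keeping every intermediate value
    x_max = max(x)
    x_normed = [v - x_max for v in x]
    num = [math.exp(v) for v in x_normed]
    denom = sum(num)
    inverted = 1.0 / denom  # output_i = num_i * inverted

    # 2. Walk the graph backward, one node at a time
    num_grad = [g * inverted for g in dy]                  # via output = num * inverted
    inverted_grad = sum(g * n for g, n in zip(dy, num))    # inverted feeds every output
    denom_grad = -inverted_grad / (denom * denom)          # d(1/d)/dd = -1/d^2
    num_grad = [g + denom_grad for g in num_grad]          # denom = sum(num)
    normed_grad = [g * n for g, n in zip(num_grad, num)]   # d(exp(u))/du = exp(u)
    max_grad = -sum(normed_grad)                           # x_normed = x - max(x)
    dx = [g + (max_grad if v == x_max else 0.0) for g, v in zip(normed_grad, x)]

    # 3. "Store" the gradient (here: return it), plus max_grad for inspection
    return dx, max_grad

x = [0.5, -1.0, 2.0, 0.0]
dy = [0.1, -0.2, 0.3, 0.05]
dx, max_grad = softmax_bwd_row(x, dy)

# Compare with the closed form dx_i = s_i * (dy_i - sum_j dy_j * s_j)
m = max(x)
e = [math.exp(v - m) for v in x]
total = sum(e)
s = [v / total for v in e]
dot = sum(g * si for g, si in zip(dy, s))
expected = [si * (g - dot) for g, si in zip(dy, s)]
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(dx, expected)))  # True
print(abs(max_grad) < 1e-12)  # True: the max branch contributes nothing
```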

```python
@triton.jit
def softmax_bwd_kernel(
    x_ptr,
    dy_ptr,
    dx_ptr,
    x_rows,
    x_cols,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0) # Get the current row id
    
    row_start = pid * x_cols # Get start of row that was softmaxed
    row_offset = row_start + tl.arange(0, BLOCK_SIZE)
    
    mask = row_offset < row_start + x_cols
    
    # Get x row and incoming gradient row
    x_row = tl.load(x_ptr + row_offset, mask=mask, other=-float('inf'))
    dy_row = tl.load(dy_ptr + row_offset, mask=mask, other=0.0)
    
    x_normed = x_row - tl.max(x_row, axis=0) # Subtract max for stability
    
    num = tl.exp(x_normed) # Raise e to the power of each element
    denom = tl.sum(num, axis=0) # Sum the elements in the row
    
    # Split the division into a reciprocal and a multiply so each step has a simple derivative
    inverted = 1 / denom # Invert denominator
    output = num * inverted # Softmax output (unused below; kept to mirror the graph)
    
    # output = num * inverted, so the gradient flows to both factors
    num_grad = dy_row * inverted
    inverted_grad = dy_row * num
    inverted_grad = tl.sum(inverted_grad, axis=0) # inverted is shared by every output element
    
    # inverted = 1 / denom, so d(inverted)/d(denom) = -1 / denom^2
    denom_grad = (-1 / (denom * denom)) * inverted_grad
    
    # denom = sum(num), so d(denom)/d(num_i) = 1: add denom_grad to every element of num_grad
    num_grad += tl.full(num_grad.shape, 1.0, dtype=tl.float32) * denom_grad
    
    # num = exp(x_normed), so d(num)/d(x_normed) = exp(x_normed)
    normed_grad = num_grad * tl.exp(x_normed)
    
    # x_normed = x - max(x): the gradient flows to x unchanged, and to the
    # max element(s) with a negative sign, summed over the row
    x_grad = normed_grad
    max_grad = -normed_grad
    max_grad = tl.sum(max_grad, axis=0)
    is_max = x_row == tl.max(x_row, axis=0)
    x_grad_2 = max_grad * tl.where(is_max,
                                   tl.full(x_grad.shape, 1.0, dtype=tl.float32),
                                   tl.zeros(x_grad.shape, dtype=tl.float32))

    tl.store(dx_ptr + row_offset, x_grad + x_grad_2, mask=mask)
```
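Composing the local gradients from the kernel gives a compact closed form. Write $s_i = \sigma(z_i)$ for the softmax output, $g_i$ for the incoming gradient (`dy_row`), $u_i = z_i - \max_k z_k$ (`x_normed`), $n_i = e^{u_i}$ (`num`) and $d = \sum_j n_j$ (`denom`):

$$
\begin{aligned}
\frac{\partial L}{\partial n_i} &= \frac{g_i}{d} - \frac{1}{d^2}\sum_j g_j\, n_j \\
\frac{\partial L}{\partial u_i} &= n_i\,\frac{\partial L}{\partial n_i} = s_i\Big(g_i - \sum_j g_j\, s_j\Big) \\
\sum_i \frac{\partial L}{\partial u_i} &= \sum_i s_i\, g_i - \Big(\sum_j g_j\, s_j\Big)\sum_i s_i = 0
\end{aligned}
$$

The last line uses $\sum_i s_i = 1$ and says that `max_grad` is zero up to rounding, which matches the fact that softmax is unchanged by shifting its inputs. So $\partial L/\partial z_i = s_i\big(g_i - \sum_j g_j s_j\big)$. Because this depends only on the output $s$, an implementation could save the softmax output in the forward pass instead of recomputing it from `x`; PyTorch's built-in softmax backward, for instance, is computed from the saved output.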

Next, we wrap both kernels in a `torch.autograd.Function`, which lets us define a custom forward and backward pass that PyTorch's autograd will call:

```python
class Softmax(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        x = x.contiguous() # the kernels assume contiguous, row-major rows
        output = torch.empty_like(x)
        x_rows, x_cols = x.shape

        BLOCK_SIZE = triton.next_power_of_2(x_cols)
        grid = lambda meta: (x_rows,)

        softmax_fwd_kernel[grid](x, output, x_rows, x_cols, BLOCK_SIZE=BLOCK_SIZE)
        
        ctx.save_for_backward(x) # cache x for use in backward pass
        return output

    @staticmethod
    def backward(ctx, dy):
        x = ctx.saved_tensors[0] # grab x from the cache
        dy = dy.contiguous() # the incoming gradient is not guaranteed to be contiguous
        dx = torch.empty_like(x)
        x_rows, x_cols = x.shape

        BLOCK_SIZE = triton.next_power_of_2(x_cols)
        grid = lambda meta: (x_rows,)

        softmax_bwd_kernel[grid](x, dy, dx, x_rows, x_cols, BLOCK_SIZE=BLOCK_SIZE)

        return dx


softmax = Softmax.apply
```

We can test it by checking the forward pass:

```python
x.grad.zero_() # reset the gradient accumulated by the PyTorch backward pass

output_triton2 = softmax(x)

assert_close(output_triton, output_triton2)
```

And the backward pass:

```python
triton_loss = loss_fn(output_triton2, y)
triton_loss.backward()
assert_close(x.grad, torch_xgrad)
x.grad # display the gradient computed by the Triton backward kernel
```

```
tensor([[ 0.0160,  0.0163, -0.0846,  0.0176,  0.0176,  0.0171],
        [-0.0835,  0.0178,  0.0160,  0.0182,  0.0158,  0.0157],
        [ 0.0162,  0.0180, -0.0842,  0.0158,  0.0177,  0.0163],
        [-0.0838,  0.0166,  0.0176,  0.0159,  0.0174,  0.0163],
        [ 0.0158, -0.0831,  0.0158,  0.0163,  0.0182,  0.0170],
        [-0.0840,  0.0172,  0.0174,  0.0181,  0.0156,  0.0157],
        [ 0.0169, -0.0823,  0.0160,  0.0157,  0.0171,  0.0166],
        [ 0.0161, -0.0825,  0.0186,  0.0158,  0.0163,  0.0158],
        [ 0.0171, -0.0833,  0.0155,  0.0176,  0.0166,  0.0165],
        [-0.0842,  0.0161,  0.0167,  0.0178,  0.0178,  0.0159]],
       device='cuda:0')
```
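When there is no reference implementation to compare against, a central finite-difference check is a common way to validate a hand-written backward pass; PyTorch packages the same idea as `torch.autograd.gradcheck`. Below is a plain-Python sketch for one row, using the scalar loss $L = \sum_i g_i\,\sigma(z)_i$ so that $\partial L/\partial \sigma(z)_i = g_i$ plays the role of `dy`. The helper names are our own, and the analytic side uses the closed form rather than the Triton kernel:

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def softmax_bwd(z, g):
    # Analytic gradient: dz_i = s_i * (g_i - sum_j g_j * s_j)
    s = softmax(z)
    dot = sum(gi * si for gi, si in zip(g, s))
    return [si * (gi - dot) for gi, si in zip(g, s)]

def numeric_grad(z, g, eps=1e-6):
    # Central differences of the scalar loss L(z) = sum_i g_i * softmax(z)_i
    def loss(v):
        return sum(gi * si for gi, si in zip(g, softmax(v)))

    grad = []
    for i in range(len(z)):
        z_plus = list(z)
        z_plus[i] += eps
        z_minus = list(z)
        z_minus[i] -= eps
        grad.append((loss(z_plus) - loss(z_minus)) / (2 * eps))
    return grad

z = [0.3, -1.2, 2.0, 0.7, 0.0, 1.1]
g = [0.5, -0.1, 0.2, 0.0, -0.4, 0.3]
max_err = max(abs(a - n) for a, n in zip(softmax_bwd(z, g), numeric_grad(z, g)))
print(max_err)  # should be tiny in float64
```

Python floats are double precision, which keeps rounding error well below the finite-difference error; in float32 you would need a larger `eps` and a looser tolerance, which is one reason `gradcheck` is usually run on double-precision inputs.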
