In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import triton
import triton.language as tl
import matplotlib.pyplot as plt

In [None]:
@triton.jit
def _row_logits_tile(
    X_ptr,
    W_ptr,
    W_row_stride,
    W_col_stride,
    v_idx,
    D: tl.constexpr,
    V: tl.constexpr,
    D_BLOCK: tl.constexpr,
    V_BLOCK: tl.constexpr,
):
    """Compute one (1, V_BLOCK) tile of logits: x_row @ W[:, v_idx:v_idx + V_BLOCK].

    X_ptr must already point at the start of the current row of X (unit column
    stride assumed). The dot product accumulates in float32. Columns past the
    end of the vocabulary (v_idx + j >= V) are set to -inf, so they add nothing
    to a max / log-sum-exp and exp() maps them to exactly 0.
    """
    d_tile_offs = tl.arange(0, D_BLOCK)
    v_offs = v_idx + tl.arange(0, V_BLOCK)

    # NOTE(review): order=(1, 0) declares W row-major (W_col_stride == 1), which
    # is how the test cell allocates it; addressing itself is driven by `strides`.
    W_block_ptr = tl.make_block_ptr(
        base=W_ptr,
        shape=(D, V),
        strides=(W_row_stride, W_col_stride),
        offsets=(0, v_idx),
        block_shape=(D_BLOCK, V_BLOCK),
        order=(1, 0),
    )

    acc = tl.zeros((1, V_BLOCK), dtype=tl.float32)
    for d_idx in tl.range(0, D, D_BLOCK):
        tile_offs = d_idx + d_tile_offs
        X_tile = tl.load(X_ptr + tile_offs, mask=tile_offs < D, other=0.0)
        W_block = tl.load(W_block_ptr, boundary_check=(0, 1), padding_option="zero")
        acc += tl.dot(X_tile[None, :], W_block)
        W_block_ptr = W_block_ptr.advance((D_BLOCK, 0))

    # boundary_check zero-pads W columns >= V, which would otherwise yield
    # spurious logits of 0 and corrupt the max / log-sum-exp.
    return tl.where((v_offs < V)[None, :], acc, -float("inf"))


@triton.jit
def linear_cross_entropy_fwd_kernel(
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    W_col_stride,
    Y_ptr,
    dX_ptr,
    dX_row_stride,
    dW_ptr,
    dW_row_stride,
    dW_col_stride,
    LSE_ptr, # TODO: we don't need to store the LSE
    Loss_ptr,
    D: tl.constexpr,
    V: tl.constexpr,
    D_BLOCK: tl.constexpr,
    V_BLOCK: tl.constexpr,
    ignore_index: tl.constexpr,
):
    """Fused linear layer + cross-entropy: per-row loss AND gradients.

    Launched with one program per row of X (grid = (N,)). Despite the name it
    runs both the forward and the backward pass for its row:

        logits  = X[row] @ W                          # (V,)
        loss    = logsumexp(logits) - logits[Y[row]]
        G       = softmax(logits) - onehot(Y[row])
        dX[row] += G @ W^T
        dW      += X[row]^T G

    Gradients are for the *summed* per-row loss and are accumulated with
    atomic adds, so dX and dW must be zero-initialised by the caller.
    Rows whose label equals ``ignore_index`` get loss 0, LSE 0 and no
    gradient contribution.

    Args:
        X_ptr, X_row_stride: activations; X[row, d] at X_ptr + row*X_row_stride + d
            (unit column stride assumed).
        W_ptr, W_row_stride, W_col_stride: (D, V) weight matrix.
        Y_ptr: one integer label per row.
        dX_ptr, dX_row_stride: gradient w.r.t. X (unit column stride assumed).
        dW_ptr, dW_row_stride, dW_col_stride: gradient w.r.t. W, shape (D, V).
        LSE_ptr, Loss_ptr: per-row log-sum-exp and loss outputs.
        D, V: hidden size and vocabulary size; need not be multiples of the blocks.
        D_BLOCK, V_BLOCK: tile sizes along D and V.
        ignore_index: label value whose rows are skipped.
    """
    # --- Setup ---
    row_id = tl.program_id(axis=0).to(tl.int64)  # int64 so row offsets cannot overflow
    d_tile_offs = tl.arange(0, D_BLOCK)
    v_tile_offs = tl.arange(0, V_BLOCK)

    # --- Pointer Logic: move every per-row pointer to this program's row ---
    X_ptr += row_id * X_row_stride
    dX_ptr += row_id * dX_row_stride
    Y_ptr += row_id
    LSE_ptr += row_id
    Loss_ptr += row_id

    Y = tl.load(pointer=Y_ptr)

    # if Y should be ignored, zero the loss and return early (no gradient)
    if Y == ignore_index:
        tl.store(Loss_ptr, 0.0)
        tl.store(LSE_ptr, 0.0)
        return

    # --- Target logit: X[row] . W[:, Y] ---
    Y_logit_acc = 0.0
    for d_idx in tl.range(0, D, D_BLOCK):
        tile_offs = d_idx + d_tile_offs
        tile_mask = tile_offs < D
        X_tile = tl.load(X_ptr + tile_offs, mask=tile_mask, other=0.0)
        W_target_tile = tl.load(
            W_ptr + tile_offs * W_row_stride + Y * W_col_stride,
            mask=tile_mask,
            other=0.0,
        )
        Y_logit_acc += tl.sum(X_tile.to(tl.float32) * W_target_tile.to(tl.float32))

    # --- Forward Pass: online (streaming) log-sum-exp over vocab tiles ---
    m = -float("inf")  # running max
    d = 0.0  # running sum of exp(logit - m)
    for v_idx in tl.range(0, V, V_BLOCK):
        logits = _row_logits_tile(
            X_ptr, W_ptr, W_row_stride, W_col_stride, v_idx, D, V, D_BLOCK, V_BLOCK
        )
        new_m = tl.maximum(m, tl.max(logits))
        # rescale the previous sum to the new max before adding this tile
        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(logits - new_m))
        m = new_m

    lse = m + tl.log(d)
    loss = lse - Y_logit_acc

    tl.store(pointer=Loss_ptr, value=loss)
    tl.store(pointer=LSE_ptr, value=lse)

    # --- Backward Pass ---
    for v_idx in tl.range(0, V, V_BLOCK):
        v_offs = v_idx + v_tile_offs
        v_mask = v_offs < V

        # 1: G = softmax(logits) - onehot(Y) for this vocab tile, shape (1, V_BLOCK).
        # Logits are recomputed with the same helper as the forward pass, so they
        # are consistent with `lse`; padded columns are -inf and exp() gives 0.
        logits = _row_logits_tile(
            X_ptr, W_ptr, W_row_stride, W_col_stride, v_idx, D, V, D_BLOCK, V_BLOCK
        )
        P_tile = tl.exp(logits - lse)
        P_tile -= tl.where(v_offs == Y, 1.0, 0.0)  # subtract 1 at the label column

        # 2: Compute the gradients:
        # dX = G . W^T
        # dW = X^T . G
        # NOTE(review): order=(0, 1) assumes W is row-major, so this logical
        # transpose of it is column-major.
        W_T_block_ptr = tl.make_block_ptr(
            base=W_ptr,
            shape=(V, D),
            strides=(W_col_stride, W_row_stride),
            offsets=(v_idx, 0),
            block_shape=(V_BLOCK, D_BLOCK),
            order=(0, 1),
        )

        for d_idx in tl.range(0, D, D_BLOCK):
            tile_offs = d_idx + d_tile_offs
            tile_mask = tile_offs < D

            # 2.1: Accumulate dX[row]; mask the D tail so the last tile cannot
            # write past the end of the row.
            W_T_block = tl.load(
                W_T_block_ptr, boundary_check=(0, 1), padding_option="zero"
            )
            dX_partial = tl.dot(P_tile, W_T_block).reshape(D_BLOCK)
            tl.atomic_add(
                pointer=dX_ptr + tile_offs,
                val=dX_partial,
                mask=tile_mask,
                sem="relaxed",  # the order of adds across threads does not matter
            )
            W_T_block_ptr = W_T_block_ptr.advance((0, D_BLOCK))

            # 2.2: Accumulate dW tile as the outer product X_tile^T G. Every
            # row's program updates the same dW entries, hence the atomics.
            X_tile = tl.load(X_ptr + tile_offs, mask=tile_mask, other=0.0)
            dW_tile_ptr = dW_ptr + (
                tile_offs[:, None] * dW_row_stride + v_offs[None, :] * dW_col_stride
            )
            dW_mask = tile_mask[:, None] & v_mask[None, :]
            dW_partial = X_tile[:, None].to(tl.float32) * P_tile
            tl.atomic_add(
                pointer=dW_tile_ptr, val=dW_partial, mask=dW_mask, sem="relaxed"
            )

In [None]:
# Compare the fused Triton kernel against PyTorch autograd on random data.
torch.manual_seed(0)  # reproducible inputs across Restart & Run All

device = "cuda:0"

N, D, V = 64, 512, 1024
D_BLOCK = 64
V_BLOCK = 64
IGNORE_INDEX = -100  # same value as F.cross_entropy's default ignore_index

X = torch.randn((N, D), device=device, requires_grad=True)
W = torch.randn((D, V), device=device, requires_grad=True)
Y = torch.randint(0, V, size=(N,), device=device)
Y[0] = IGNORE_INDEX  # exercise the kernel's ignore_index early-return path

# --- Reference: PyTorch autograd on the summed per-row loss ---
logits_py = X @ W
loss_py = F.cross_entropy(logits_py, Y, reduction="none", ignore_index=IGNORE_INDEX)
loss_py.sum().backward()

expected_dX = X.grad
expected_dW = W.grad

# --- Triton: gradient buffers must start at zero (the kernel uses atomic adds) ---
dX_triton = torch.zeros((N, D), device=device, dtype=torch.float32)
dW_triton = torch.zeros((D, V), device=device, dtype=torch.float32)
LSE_triton = torch.zeros((N,), device=device, dtype=torch.float32)
Loss_triton = torch.zeros((N,), device=device, dtype=torch.float32)

linear_cross_entropy_fwd_kernel[(N,)](  # one program per row of X
    X_ptr=X,
    X_row_stride=X.stride(0),
    W_ptr=W,
    W_row_stride=W.stride(0),
    W_col_stride=W.stride(1),
    Y_ptr=Y,
    dX_ptr=dX_triton,
    dX_row_stride=dX_triton.stride(0),
    dW_ptr=dW_triton,
    dW_row_stride=dW_triton.stride(0),
    dW_col_stride=dW_triton.stride(1),
    LSE_ptr=LSE_triton,
    Loss_ptr=Loss_triton,
    D=D,
    V=V,
    D_BLOCK=D_BLOCK,
    V_BLOCK=V_BLOCK,
    ignore_index=IGNORE_INDEX,
)

torch.testing.assert_close(Loss_triton, loss_py, atol=1e-7, rtol=1e-5)
print("Loss Match!")

torch.testing.assert_close(dX_triton, expected_dX, atol=1e-4, rtol=1e-5)
print("dX Match!")

torch.testing.assert_close(dW_triton, expected_dW, atol=1e-4, rtol=1e-5)
print("dW Match!")

Loss Match!
dX Match!
dW Match!
