In [51]:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
import tabulate

In [52]:
# Prefer the GPU when one is visible; the Triton kernels in this notebook are launched on CUDA tensors.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [53]:
# Forward pass of per-row cross-entropy. One program handles one batch row and writes
#   loss[row] = logsumexp(logits[row, :]) - logits[row, targets[row]]
# All reductions are done in fp32 regardless of the logits dtype.
# NOTE: no ignore_index handling -- the target logit is loaded without a mask, so every
# target must lie in [0, num_classes).
@triton.jit
def cross_entropy_forward_kernel(
    logits_ptr,  # [B, C]
    targets_ptr,  # [B] (int)
    loss_ptr,     # [B]
    stride_batch,
    stride_class,
    num_classes,
    BLOCK_SIZE: tl.constexpr,  # must cover num_classes (caller passes next_power_of_2(C))
):
    row = tl.program_id(0)  # index over batch
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < num_classes

    # pointer to start of row
    row_ptr = logits_ptr + row * stride_batch

    # Load logits; lanes past num_classes read -inf so exp() maps them to 0.
    # Upcast so max/exp/sum are not carried out in fp16/bf16.
    logits = tl.load(row_ptr + col_offsets * stride_class, mask=mask, other=-float('inf'))
    logits = logits.to(tl.float32)

    # numerical stability trick: subtract the row max before exponentiating
    max_logit = tl.max(logits, axis=0)
    logits = logits - max_logit

    exp_logits = tl.exp(logits)
    sum_exp = tl.sum(exp_logits, axis=0)
    log_sum_exp = tl.log(sum_exp)

    target_idx = tl.load(targets_ptr + row)
    true_logit = tl.load(row_ptr + target_idx * stride_class).to(tl.float32)

    # log_sum_exp is relative to max_logit, so shift the target logit the same way.
    loss = log_sum_exp - (true_logit - max_logit)
    # tl.store converts the fp32 result to loss_ptr's element type.
    tl.store(loss_ptr + row, loss)

In [54]:
# Backward pass of per-row cross-entropy. One program handles one batch row and writes
#   grad_logits[row, :] = grad_loss[row] * (softmax(logits[row, :]) - one_hot(targets[row]))
# The softmax is recomputed from the logits in fp32; nothing from the forward kernel is reused.
@triton.jit
def cross_entropy_backward_kernel(
    logits_ptr,
    targets_ptr,
    grad_loss_ptr,
    grad_logits_ptr,
    stride_batch,
    stride_class,
    num_classes,
    BLOCK_SIZE: tl.constexpr,  # must cover num_classes (caller passes next_power_of_2(C))
):
    row = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < num_classes

    row_ptr = logits_ptr + row * stride_batch

    # Upcast so the softmax is computed in fp32 even for fp16/bf16 logits.
    logits = tl.load(row_ptr + col_offsets * stride_class, mask=mask, other=-float('inf'))
    logits = logits.to(tl.float32)
    max_logit = tl.max(logits, axis=0)
    logits = logits - max_logit

    exp_logits = tl.exp(logits)
    sum_exp = tl.sum(exp_logits, axis=0)
    softmax = exp_logits / sum_exp

    target_idx = tl.load(targets_ptr + row)
    grad_loss = tl.load(grad_loss_ptr + row)
    grad = softmax
    grad = tl.where(col_offsets == target_idx, grad - 1.0, grad)
    grad = grad * grad_loss  # chain rule

    # The gradient buffer is addressed with the logits' strides; presumably it has the
    # same layout (the caller allocates it with empty_like on the saved logits).
    # tl.store converts the fp32 values to the buffer's element type.
    out_ptr = grad_logits_ptr + row * stride_batch
    tl.store(out_ptr + col_offsets * stride_class, grad, mask=mask)

In [55]:
class CrossEntropyTriton(torch.autograd.Function):
    """Autograd wrapper around the (recompute-softmax) Triton cross-entropy kernels.

    ``forward(logits, targets)`` expects 2-D ``logits`` of shape [B, C] and a
    [B] integer ``targets`` tensor. It returns the un-reduced per-row loss of
    shape [B] in ``logits.dtype``. ``backward`` returns the gradient with
    respect to ``logits`` and ``None`` for ``targets``.

    NOTE(review): the kernels are looked up by global name at call time; a later
    cell redefines ``cross_entropy_*_kernel`` with an extra ``logsumexp_ptr``
    argument, so this class only works while the first kernel definitions are current.
    """

    @staticmethod
    def forward(ctx, logits, targets):
        B, C = logits.shape
        # .contiguous() is a no-op for tensors that are already contiguous.
        logits_ = logits.contiguous()
        targets_ = targets.contiguous()

        loss = torch.empty(B, device=logits.device, dtype=logits.dtype)
        # One program per row; the whole row is processed as a single block.
        BLOCK_SIZE = triton.next_power_of_2(C)

        cross_entropy_forward_kernel[(B,)](
            logits_,
            targets_,
            loss,
            logits_.stride(0),
            logits_.stride(1),
            C,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        # Backward recomputes the softmax from the logits, so the per-row loss is
        # never needed there; only keep the inputs alive for the graph.
        ctx.save_for_backward(logits_, targets_)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        logits, targets = ctx.saved_tensors
        B, C = logits.shape
        grad_logits = torch.empty_like(logits)
        BLOCK_SIZE = triton.next_power_of_2(C)

        cross_entropy_backward_kernel[(B,)](
            logits,
            targets,
            # grad_output may be an expanded (stride-0) view, e.g. from .mean().
            grad_output.contiguous(),
            grad_logits,
            logits.stride(0),
            logits.stride(1),
            C,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        return grad_logits, None  # no grad wrt targets


In [56]:
class TritonCrossEntropyLoss(torch.nn.Module):
    """Batch-mean cross-entropy whose per-row losses come from ``CrossEntropyTriton``."""

    def forward(self, logits, targets):
        per_row_losses = CrossEntropyTriton.apply(logits, targets)
        return per_row_losses.mean()

In [62]:
# Sanity check: compare the Triton loss and gradient with PyTorch on a small random batch.
torch.manual_seed(0)  # reproducible inputs across Restart & Run All
logits = torch.randn(4, 10, device=device, requires_grad=True)
targets = torch.randint(0, 10, (4,), device=device)

loss_fn = TritonCrossEntropyLoss()
loss = loss_fn(logits, targets)
loss.backward()

# Compare to PyTorch
logits_ref = logits.detach().clone().requires_grad_()
loss_ref = torch.nn.functional.cross_entropy(logits_ref, targets)
loss_ref.backward()

loss_diff = abs(loss.item() - loss_ref.item())
grad_diff = (logits.grad - logits_ref.grad).abs().max().item()

print("Loss diff:", loss_diff)
print("Grad diff (max):", grad_diff)

# Fail loudly instead of relying on a reader to eyeball the printed diffs.
torch.testing.assert_close(loss, loss_ref)
torch.testing.assert_close(logits.grad, logits_ref.grad)

Loss diff: 0.0
Grad diff (max): 7.450580596923828e-09


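# More efficient version

This version stores the per-row log-sum-exp in the forward pass and reuses it in the backward pass, so the backward kernel does not recompute the row max and sum. It also accumulates in float32. Targets equal to `-100` (PyTorch's default `ignore_index`, typically used for padding) are ignored: their loss and gradient are zero.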

In [47]:
# Forward pass of per-row cross-entropy that also saves the log-sum-exp for backward.
# One program handles one batch row and writes
#   logsumexp[row] = logsumexp(logits[row, :])            (computed in fp32)
#   loss[row]      = logsumexp[row] - logits[row, targets[row]]
# Rows whose target is -100 (ignore index) get loss 0.
@triton.jit
def cross_entropy_forward_kernel(
    logits_ptr,  # [B, C]
    targets_ptr,  # [B] (int)
    loss_ptr,     # [B]
    logsumexp_ptr, # [B] - for saving
    stride_batch,
    stride_class,
    num_classes,
    BLOCK_SIZE: tl.constexpr,  # must cover num_classes (caller passes next_power_of_2(C))
):
    row = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < num_classes

    row_ptr = logits_ptr + row * stride_batch
    logsumexp_out_ptr = logsumexp_ptr + row
    loss_out_ptr = loss_ptr + row

    # Lanes past num_classes read -inf so exp() maps them to 0; work in fp32.
    logits = tl.load(row_ptr + col_offsets * stride_class, mask=mask, other=-float('inf'))
    logits = logits.to(tl.float32)

    # Shift by the row max for stability. If the whole row is -inf, shift by 0 instead
    # so the subtraction below is not (-inf) - (-inf).
    max_logit = tl.max(logits, axis=0)
    max_logit_safe = tl.where(max_logit == -float('inf'), 0.0, max_logit)
    shifted_logits = logits - max_logit_safe
    exp_logits = tl.exp(shifted_logits)
    sum_exp = tl.sum(exp_logits, axis=0)
    logsumexp_val = max_logit_safe + tl.log(sum_exp)
    tl.store(logsumexp_out_ptr, logsumexp_val)

    target_idx = tl.load(targets_ptr + row)
    is_valid = target_idx != -100
    # -100 is not a column index: loading at it would read 100 elements before this
    # row (outside the tensor for row 0). Read column 0 instead; the value is discarded
    # below for ignored rows.
    safe_target_idx = tl.where(is_valid, target_idx, 0)
    true_logit = tl.load(row_ptr + safe_target_idx * stride_class)
    true_logit = true_logit.to(tl.float32)

    loss = logsumexp_val - true_logit
    tl.store(loss_out_ptr, tl.where(is_valid, loss, 0.0))

# Backward pass that reuses the log-sum-exp saved by the forward kernel.
# One program handles one batch row and writes
#   grad_logits[row, :] = g * (exp(logits[row, :] - logsumexp[row]) - one_hot(targets[row]))
# where g = grad_loss[row], or 0 when targets[row] == -100 (ignored row).
@triton.jit
def cross_entropy_backward_kernel(
    logits_ptr,
    targets_ptr,
    grad_loss_ptr,
    grad_logits_ptr,
    logsumexp_ptr,
    stride_batch,
    stride_class,
    num_classes,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    in_row = cols < num_classes

    # Per-row scalars: saved log-sum-exp, class label and upstream gradient.
    lse = tl.load(logsumexp_ptr + pid)
    label = tl.load(targets_ptr + pid)
    upstream = tl.load(grad_loss_ptr + pid)

    x = tl.load(
        logits_ptr + pid * stride_batch + cols * stride_class,
        mask=in_row,
        other=-float('inf'),
    ).to(tl.float32)

    # softmax(x) = exp(x - logsumexp(x)); for a finite lse, padded lanes give exp(-inf) = 0.
    probs = tl.exp(x - lse)
    # d(loss)/d(x) = softmax - one_hot(label)
    dlogits = tl.where(cols == label, probs - 1.0, probs)
    # Ignored rows (label == -100) receive a zero gradient.
    scale = tl.where(label != -100, upstream, 0.0)

    tl.store(
        grad_logits_ptr + pid * stride_batch + cols * stride_class,
        dlogits * scale,
        mask=in_row,
    )

class CrossEntropyTriton(torch.autograd.Function):
    """Autograd wrapper around the log-sum-exp-saving Triton cross-entropy kernels.

    ``forward(logits, targets)`` expects 2-D ``logits`` of shape [B, C] and a
    [B] integer ``targets`` tensor. It returns the un-reduced per-row loss of
    shape [B] in ``logits.dtype``; rows whose target is -100 get loss 0. The
    per-row log-sum-exp from the forward kernel is saved so backward does not
    recompute the row max/sum. ``backward`` returns ``None`` for ``targets``.
    """

    @staticmethod
    def forward(ctx, logits, targets):
        B, C = logits.shape
        # .contiguous() is a no-op for tensors that are already contiguous.
        logits_ = logits.contiguous()
        targets_ = targets.contiguous()

        loss = torch.empty(B, device=logits.device, dtype=logits.dtype)
        # The kernel computes the log-sum-exp in fp32, and backward exponentiates
        # (logits - logsumexp). Storing it in a half-precision buffer would round it
        # and skew every recomputed softmax probability, so always keep it fp32.
        logsumexp = torch.empty(B, device=logits.device, dtype=torch.float32)
        # One program per row; the whole row is processed as a single block.
        BLOCK_SIZE = triton.next_power_of_2(C)

        cross_entropy_forward_kernel[(B,)](
            logits_,
            targets_,
            loss,
            logsumexp,
            logits_.stride(0),
            logits_.stride(1),
            C,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        ctx.save_for_backward(logits_, targets_, logsumexp)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        logits, targets, logsumexp = ctx.saved_tensors
        B, C = logits.shape
        grad_logits = torch.empty_like(logits)
        BLOCK_SIZE = triton.next_power_of_2(C)

        cross_entropy_backward_kernel[(B,)](
            logits,
            targets,
            # grad_output may be an expanded (stride-0) view, e.g. from .mean().
            grad_output.contiguous(),
            grad_logits,
            logsumexp,
            logits.stride(0),
            logits.stride(1),
            C,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        return grad_logits, None  # no grad wrt targets

class TritonCrossEntropyLoss(torch.nn.Module):
    """Mean cross-entropy over non-ignored rows, backed by ``CrossEntropyTriton``.

    Rows whose target is -100 contribute zero loss and are excluded from the
    denominator, matching
    ``torch.nn.functional.cross_entropy(logits, targets, ignore_index=-100)``
    (``reduction="mean"``). When every target is ignored the result is 0
    rather than PyTorch's nan.
    """

    def forward(self, logits, targets):
        per_row_losses = CrossEntropyTriton.apply(logits, targets)
        # Divide by the number of real targets, not the batch size, so padding
        # does not dilute the loss; clamp avoids 0/0 for an all-ignored batch.
        num_valid = (targets != -100).sum().clamp(min=1)
        return per_row_losses.sum() / num_valid

# Sanity check: compare the log-sum-exp-saving version against PyTorch on a small random batch.
torch.manual_seed(0)  # reproducible inputs across Restart & Run All
logits = torch.randn(4, 10, device=device, requires_grad=True)
targets = torch.randint(0, 10, (4,), device=device)

loss_fn = TritonCrossEntropyLoss()
loss = loss_fn(logits, targets)
loss.backward()

# Compare to PyTorch
logits_ref = logits.detach().clone().requires_grad_()
loss_ref = torch.nn.functional.cross_entropy(logits_ref, targets)
loss_ref.backward()

loss_diff = abs(loss.item() - loss_ref.item())
grad_diff = (logits.grad - logits_ref.grad).abs().max().item()

print("Loss diff:", loss_diff)
print("Grad diff (max):", grad_diff)

# Fail loudly instead of relying on a reader to eyeball the printed diffs.
torch.testing.assert_close(loss, loss_ref)
torch.testing.assert_close(logits.grad, logits_ref.grad)

Loss diff: 0.0
Grad diff (max): 1.4901161193847656e-08
