In [1]:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
import tabulate

In [2]:
# Use the CUDA device when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
@triton.jit
def relu_forward_kernel(
    x_ptr, y_ptr, n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise ReLU forward pass over a flat buffer.

    Each program instance handles one BLOCK_SIZE-wide slice of the buffer:
    it loads the values at ``x_ptr``, keeps those that are > 0, replaces the
    rest with 0.0, and stores the result at the same offsets from ``y_ptr``.

    Args:
        x_ptr: pointer to the input values.
        y_ptr: pointer to the output buffer.
        n_elements: number of valid elements; offsets at or beyond it are masked.
        BLOCK_SIZE: compile-time number of elements per program instance.
    """
    block_id = tl.program_id(0)
    block_start = block_id * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the end of the buffer; mask those lanes off.
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)
    out = tl.where(vals > 0, vals, 0.0)
    tl.store(y_ptr + idx, out, mask=in_bounds)

In [4]:
@triton.jit
def relu_backward_kernel(
    x_ptr, grad_out_ptr, grad_in_ptr, n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise ReLU backward pass over flat buffers.

    Each program instance handles one BLOCK_SIZE-wide slice: wherever the
    value loaded from ``x_ptr`` is > 0 the matching value from
    ``grad_out_ptr`` is copied to ``grad_in_ptr``; everywhere else 0.0 is
    written.

    Args:
        x_ptr: pointer to the values that were fed to the forward pass.
        grad_out_ptr: pointer to the incoming (upstream) gradient.
        grad_in_ptr: pointer to the buffer receiving the resulting gradient.
        n_elements: number of valid elements; offsets at or beyond it are masked.
        BLOCK_SIZE: compile-time number of elements per program instance.
    """
    block_id = tl.program_id(0)
    block_start = block_id * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the end of the buffers; mask those lanes off.
    in_bounds = idx < n_elements

    inputs = tl.load(x_ptr + idx, mask=in_bounds)
    upstream = tl.load(grad_out_ptr + idx, mask=in_bounds)
    # The gradient flows through only where the forward input was positive.
    downstream = tl.where(inputs > 0, upstream, 0.0)
    tl.store(grad_in_ptr + idx, downstream, mask=in_bounds)


In [12]:
from torch.autograd import Function

class ReLU_Triton(Function):
    """Autograd function that evaluates ReLU and its gradient with Triton kernels.

    Use ``ReLU_Triton.apply(x)``. The result and the gradient returned to
    autograd both have the same shape as ``x``.
    """

    # Elements processed per Triton program instance; shared by the forward
    # and backward launches so the two cannot drift apart.
    BLOCK_SIZE = 1024

    @staticmethod
    def forward(ctx, x):
        """Compute ReLU of ``x``.

        Args:
            ctx: autograd context; the flattened input is saved on it for backward.
            x: input tensor of any shape (made contiguous before the launch).

        Returns:
            A tensor with the same shape as ``x`` holding ``x`` where ``x > 0``
            and 0 elsewhere.
        """
        # The kernels index a flat buffer, so work on a contiguous 1-D view.
        x_flat = x.contiguous().view(-1)
        y = torch.empty_like(x_flat)
        n_elements = x_flat.numel()

        # Nothing to compute for an empty tensor, so skip the launch entirely.
        if n_elements > 0:
            grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
            relu_forward_kernel[grid](
                x_flat, y, n_elements, BLOCK_SIZE=ReLU_Triton.BLOCK_SIZE
            )

        ctx.save_for_backward(x_flat)
        return y.view_as(x)

    @staticmethod
    def backward(ctx, grad_out):
        """Compute the gradient of the loss with respect to the forward input.

        Args:
            ctx: autograd context carrying the flattened forward input.
            grad_out: gradient with respect to the forward output.

        Returns:
            A tensor shaped like ``grad_out`` (and therefore like the forward
            input) that equals ``grad_out`` where the input was > 0 and 0
            elsewhere.
        """
        (x_flat,) = ctx.saved_tensors
        grad_out_flat = grad_out.contiguous().view(-1)
        grad_in = torch.empty_like(x_flat)
        n_elements = x_flat.numel()

        # Nothing to compute for an empty tensor, so skip the launch entirely.
        if n_elements > 0:
            grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
            relu_backward_kernel[grid](
                x_flat, grad_out_flat, grad_in, n_elements,
                BLOCK_SIZE=ReLU_Triton.BLOCK_SIZE,
            )

        # The returned gradient must match the shape of the forward input.
        # forward() returns y.view_as(x), so grad_out carries that shape;
        # returning the flat buffer would only be valid for 1-D inputs.
        return grad_in.view_as(grad_out)


class TritonReLU(torch.nn.Module):
    """``nn.Module`` wrapper that routes its input through ``ReLU_Triton``."""

    def forward(self, x):
        """Return ``ReLU_Triton.apply(x)``."""
        activated = ReLU_Triton.apply(x)
        return activated


In [15]:
# Smoke test: run the Triton-backed ReLU forward and backward on a tiny input.
relu = TritonReLU()
x = torch.randn(2, device='cuda', requires_grad=True)
y = relu(x)
loss = y.sum()
loss.backward()

# Verify against PyTorch's reference ReLU rather than relying on reading the prints.
torch.testing.assert_close(y.detach(), torch.relu(x.detach()))
# d(sum(relu(x)))/dx is 1 where x > 0 and 0 elsewhere.
torch.testing.assert_close(x.grad, (x.detach() > 0).to(x.dtype))

# Show the values for visual inspection as well.
print("x:", x)
print("y = relu(x):", y)
print("x.grad (should be 1 where x > 0):", x.grad)


x: tensor([-0.6888,  2.0654], device='cuda:0', requires_grad=True)
y = relu(x): tensor([0.0000, 2.0654], device='cuda:0', grad_fn=<ReLU_TritonBackward>)
x.grad (should be 1 where x > 0): tensor([0., 1.], device='cuda:0')
