Let's do a tiny example of using a custom triton kernel in a PyTorch model's forward and backward.

Steps:
1. Define tiny model in pure pytorch
2. Write custom autograd function that is used in the fwd / bwd
3. Write triton kernels for fwd / bwd, and call them from the custom autograd function

In [None]:
import torch
import torch.nn as nn

from copy import copy

torch.set_printoptions(linewidth=200, precision=0, sci_mode=False)

**1. Define tiny model in pure pytorch**

In [None]:
class CustomLinear(nn.Module):
    """Bias-free linear layer that maps a vector of size `cin` to one of size `cout`.

    The single parameter `weights` has shape (cout, cin) and is initialised to
    all ones, so the output is easy to verify by hand.
    """

    def __init__(self, cin, cout):
        super().__init__()
        initial = torch.ones(cout, cin)
        self.weights = nn.Parameter(initial)

    def forward(self, x):
        """Return the matrix-vector product `weights @ x`."""
        return torch.matmul(self.weights, x)

In [None]:
# Problem sizes (use powers of 2, as they're easier for triton).
m = 4  # out_size
n = 2  # in_size

In [None]:
# Build the layer with in_size=n, out_size=m and move its parameters to the GPU.
lin = CustomLinear(n,m).to('cuda')
print(lin)
# All-ones input of length n; created without requires_grad, so autograd will not track it.
x = torch.ones(n, device='cuda')
print('x:', x)
# Forward pass through the module.
y = lin(x)
print('y:', y)

CustomLinear()
x: tensor([1., 1.], device='cuda:0')
y: tensor([2., 2., 2., 2.], device='cuda:0', grad_fn=<MvBackward0>)


In [None]:
y.retain_grad() # retain grad for non-leaf variable, to use it below as input for kernel
# Reduce to a scalar so backward() can be called without an explicit gradient argument.
loss = y.sum()
loss.backward()

In [None]:
# Snapshot dLoss/dy so it can be fed to the hand-written backward later on.
# `clone()` allocates fresh storage; the generic `copy.copy` is only a shallow
# copy and may alias the storage of `y.grad`.
y_grad = y.grad.clone()
y_grad

tensor([1., 1., 1., 1.], device='cuda:0')

In [None]:
# x was created without requires_grad=True, so no gradient is stored for it (x.grad is None).
print('dx:', x.grad)
# Gradient of the loss w.r.t. the layer's weight matrix.
print('dw:',lin.weights.grad)

dx: None
dw: tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]], device='cuda:0')


**2. Write custom autograd function that is used in the fwd / bwd**

Now, we'll create a custom `torch.autograd.Function` which manually computes the gradient for our custom operation. This function will be used by the autograd engine, when our operation is encountered.

In [None]:
class LinearFn(torch.autograd.Function):
    """Matrix-vector product `w @ x` with a hand-written gradient.

    forward:  out = w @ x
    backward: given the upstream gradient `d`, returns the gradients for
              (w, x) in the same order as forward's inputs.
    """

    @staticmethod
    def forward(ctx, w, x):
        print('fwd of LinearFn called')
        # Both inputs are needed again to compute the gradients in backward.
        ctx.save_for_backward(w, x)
        # here, we'll later use a function that runs on gpu
        return torch.matmul(w, x)

    @staticmethod
    def backward(ctx, d):
        print('bwd of LinearFn called')
        weight, inp = ctx.saved_tensors
        # here, we'll later use a function that runs on gpu:
        # gradient w.r.t. w is the outer product of d and x, written as (m,1) @ (1,n)
        grad_w = torch.matmul(d.t().unsqueeze(1), inp.unsqueeze(0))
        # gradient w.r.t. x is d @ w
        grad_x = torch.matmul(d, weight)
        return grad_w, grad_x

In [None]:
class CustomLinear(nn.Module):
    """Bias-free linear layer whose forward is routed through `LinearFn`.

    `weights` has shape (cout, cin) and starts out as all ones.
    """

    def __init__(self, cin, cout):
        super().__init__()
        print('fyi: This module uses a manual autograd function')
        initial = torch.ones(cout, cin)
        self.weights = nn.Parameter(initial)

    def forward(self, x):
        """Apply the custom autograd function to (weights, x)."""
        out = LinearFn.apply(self.weights, x)
        return out

In [None]:
# Re-create the layer (now the version whose forward goes through LinearFn) on the GPU.
lin = CustomLinear(n,m).to('cuda')
print(lin)
# Fresh all-ones input; created without requires_grad, so autograd will not track it.
x = torch.ones(n, device='cuda')
print('x:', x)
# Forward pass; the 'fwd of LinearFn called' message is printed from inside LinearFn.forward.
y = lin(x)
print('y:', y)

fyi: This module uses a manual autograd function
CustomLinear()
x: tensor([1., 1.], device='cuda:0')
fwd of LinearFn called
y: tensor([2., 2., 2., 2.], device='cuda:0', grad_fn=<LinearFnBackward>)


In [None]:
# Scalar loss, then backprop; autograd invokes LinearFn.backward for our op.
loss = y.sum()
loss.backward()

bwd of LinearFn called


In [None]:
# x was created without requires_grad=True, so no gradient is stored for it (x.grad is None).
print('dx:', x.grad)
# Weight gradient as produced by the manual backward.
print('dw:',lin.weights.grad)

dx: None
dw: tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]], device='cuda:0')


**3. Write triton kernels for fwd / bwd, and call them from the custom autograd function**

Now, we'll use gpu-backed functions in the manual gradient computation.

In [None]:
import os
#os.environ['TRITON_INTERPRET'] = '1'

import triton
import triton.language as tl

from triton_util import *

In [None]:
@triton.jit
def fwd_kernel(w_ptr, x_ptr, out_ptr, m, n: tl.constexpr, bs: tl.constexpr):
    # Forward kernel: out = w @ x, with w of shape (m,n), x of shape (n), out of shape (m).
    # Each program (indexed by pid) handles `bs` rows of w and writes the matching `bs` entries of out.
    # NOTE(review): the get_* helpers come from triton_util and are not visible here; the
    # comments below describe how they are used, not how they are implemented.
    pid = tl.program_id(0)

    offs_m = get_1d_offset(bs, pid)       # split m axis into chunks of size `bs`, take chunk no 'pid'
    offs_n = get_1d_offset(n, 0)          # entire n axis

    # stride_0=n, i.e. w is addressed as if consecutive rows are n elements apart
    offs_w = get_2d_offset(offs_m, offs_n, stride_0=n)

    # masks are passed to every load/store below, bounded by the real extents m and n
    mask_out    = get_1d_mask(offs_m, m)
    mask_x      = get_1d_mask(offs_n, n)
    mask_w = get_2d_mask(offs_m, offs_n, m, n)

    x = tl.load(x_ptr + offs_n, mask_x) # shape (n)
    w = tl.load(w_ptr + offs_w, mask_w) # this program's rows of w: shape (bs,n)

    # note: we can't use tl.dot as it requires all dims to be >= 16, so let's do manual matmul
    out = tl.sum(tl.sum(w[:,:,None] * x[None, :, None],1), 1) # shape (bs,n),(n) -> (bs,n,1),(1,n,1) -> (bs,1) -> (bs)

    tl.store(out_ptr+offs_m, out, mask_out)

In [None]:
def fwd_gpu(weight, x):
    """Compute `weight @ x` by launching the Triton `fwd_kernel`.

    Args:
        weight: 2-D tensor of shape (m, n).
        x: 1-D tensor of shape (n).

    Returns:
        1-D tensor of shape (m), allocated on the same device and with the
        same dtype as `weight`.
    """
    #shapes: (m,n) @ (n) -> (m)
    m, n = weight.shape
    # Allocate the output next to the inputs instead of on a hard-coded 'cuda'
    # device with the default dtype, so weights on a non-default device (or
    # with a non-default dtype) get a matching output.
    out = torch.zeros(m, device=weight.device, dtype=weight.dtype)
    block_size = 32  # number of rows of `weight` handled by one kernel program
    grid = (cdiv(m, block_size),)  # enough programs to cover all m rows
    assert_tensors_gpu_ready(weight, x, out)
    fwd_kernel[grid](weight, x, out, m, n, block_size)
    return out

Check that the gpu-backed fwd produces the same result:

In [None]:
# Call the Triton-backed forward directly on the raw weight tensor (note: this rebinds y).
y = fwd_gpu(lin.weights.data, x)
y

tensor([2., 2., 2., 2.], device='cuda:0')

In [None]:
@triton.jit
def bwd_kernel(d_ptr, w_ptr, x_ptr, dw_ptr, dx_ptr, m: tl.constexpr, n, bs: tl.constexpr):
    # Backward kernel for out = w @ x: given the upstream gradient d, computes
    #   dx = d @ w   and   dw = outer(d, x).
    # shapes: d = (m), w = (m,n), x = (n)  -- no batch dimension is used in this kernel
    # Each program (indexed by pid) handles `bs` columns along the n axis and writes the
    # matching slice of dx and the matching columns of dw.
    # NOTE(review): the get_* helpers come from triton_util and are not visible here.
    pid = tl.program_id(0)

    offs_m = get_1d_offset(m, 0)    # entire m axis
    offs_n = get_1d_offset(bs, pid) # split n axis into chunks of size `bs`, take chunk no 'pid'

    # stride_0=n, i.e. w and dw are addressed as if consecutive rows are n elements apart
    offs_w = get_2d_offset(offs_m, offs_n, stride_0=n)

    # masks are passed to every load/store below, bounded by the real extents m and n
    mask_d = get_1d_mask(offs_m, m)
    mask_x = get_1d_mask(offs_n, n)
    mask_w = get_2d_mask(offs_m, offs_n, m, n)

    d = tl.load(d_ptr + offs_m, mask_d) # shape (m)
    x = tl.load(x_ptr + offs_n, mask_x) # this program's slice of x: shape (bs)
    w = tl.load(w_ptr + offs_w, mask_w) # this program's columns of w: shape (m,bs)

    # note: we can't use tl.dot as it requires all dims to be >= 16, so let's do manual matmul
    dx = tl.sum(tl.sum(d[None,:,None] * w[None, :, :], 1), 0) # shape (m),(m,bs) -> (1,m,1),(1,m,bs) -> (1,bs) -> (bs)
    dw = d[:,None] * x[None, :]                               # shape (m),(bs) -> (m,1),(1,bs) -> (m,bs)

    tl.store(dx_ptr+offs_n, dx, mask_x)
    tl.store(dw_ptr+offs_w, dw, mask_w)

In [None]:
def bwd_gpu(d, weight, x):
    """Compute the gradients of `weight @ x` by launching the Triton `bwd_kernel`.

    Args:
        d: upstream gradient w.r.t. the output, 1-D tensor of shape (m).
        weight: 2-D tensor of shape (m, n).
        x: 1-D tensor of shape (n).

    Returns:
        Tuple `(dweight, dx)` with the same shape, dtype and device as
        `weight` and `x` respectively.
    """
    d = d.contiguous() # autograd can return non-contiguous grads
    m, n = weight.shape
    # zeros_like already inherits dtype and device from its argument; overriding
    # device='cuda' would move the buffers to the default CUDA device even when
    # the inputs live elsewhere.
    dx      = torch.zeros_like(x)
    dweight = torch.zeros_like(weight)
    block_size = 32  # number of columns (n axis) handled by one kernel program
    grid = (cdiv(n, block_size),)  # enough programs to cover all n columns
    assert_tensors_gpu_ready(d, weight, x, dweight, dx)
    bwd_kernel[grid](d, weight, x, dweight, dx, m, n, block_size)
    return dweight, dx

Check that the gpu-backed bwd produces the same result:

In [None]:
# Feed the upstream gradient saved earlier (y_grad) into the Triton-backed backward.
# Unlike autograd above, this computes dx explicitly, so it is not None here.
dw, dx = bwd_gpu(d=y_grad, weight=lin.weights.data, x=x)
dx, dw

(tensor([4., 4.], device='cuda:0'),
 tensor([[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]], device='cuda:0'))

Now use them in a custom autograd function:

In [None]:
class LinearFn(torch.autograd.Function):
    """Matrix-vector product `w @ x` whose forward and backward run the gpu-backed functions."""

    @staticmethod
    def forward(ctx, w, x):
        print('LinearFn.forward with gpu called')
        # Keep both inputs around; backward needs them to compute the gradients.
        ctx.save_for_backward(w, x)
        out = fwd_gpu(w, x) # using gpu-backed fwd
        return out

    @staticmethod
    def backward(ctx, d):
        print('LinearFn.backward with gpu called')
        weight, inp = ctx.saved_tensors
        grad_w, grad_x = bwd_gpu(d, weight, inp) # using gpu-backed bwd
        # Gradients are returned in the same order as forward's inputs (w, x).
        return grad_w, grad_x

In [None]:
class CustomLinear(nn.Module):
    """Bias-free linear layer whose forward is routed through the gpu-backed `LinearFn`.

    `weights` has shape (cout, cin) and starts out as all ones.
    """

    def __init__(self, cin, cout):
        super().__init__()
        print('fyi: This module uses a manual gpu-backed autograd function')
        self.register_parameter('weights', nn.Parameter(torch.ones(cout, cin)))

    def forward(self, x):
        """Apply the custom autograd function to (weights, x)."""
        result = LinearFn.apply(self.weights, x)
        return result

In [None]:
# Re-create the layer (now the version backed by the gpu autograd function) on the GPU.
lin = CustomLinear(n,m).to('cuda')
print(lin)
# Fresh all-ones input; created without requires_grad, so autograd will not track it.
x = torch.ones(n, device='cuda')
print(x)
# Forward pass; the 'LinearFn.forward with gpu called' message is printed from inside LinearFn.forward.
y = lin(x)
print(y)

fyi: This module uses a manual gpu-backed autograd function
CustomLinear()
tensor([1., 1.], device='cuda:0')
LinearFn.forward with gpu called
tensor([2., 2., 2., 2.], device='cuda:0', grad_fn=<LinearFnBackward>)


In [None]:
# Scalar loss + backprop in one go; autograd invokes LinearFn.backward for our op.
y.sum().backward()

LinearFn.backward with gpu called


In [None]:
# x.grad stays None because x was created without requires_grad=True; only the weights accumulate a gradient.
x.grad, lin.weights.grad

(None,
 tensor([[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]], device='cuda:0'))