Let's walk through a tiny example of using a custom CUDA kernel in the forward and backward pass of a PyTorch model:

1. Write a tiny PyTorch model
2. Write its forward and backward in C++, and use them from PyTorch
3. Write its forward and backward in CUDA, and use them from PyTorch

In [None]:
import torch
import torch.nn as nn

## 1. Pure PyTorch

In [None]:
class CustomLinear(nn.Module):
    def __init__(self, cin, cout):
        super().__init__()
        self.weights = nn.Parameter(torch.ones(cout, cin))

    def forward(self, x):
        return self.weights @ x

In [None]:
lin = CustomLinear(2,3)
lin

CustomLinear()

In [None]:
x = torch.ones(2)
x

tensor([1., 1.])

In [None]:
y = lin(x)
y

tensor([2., 2., 2.], grad_fn=<MvBackward0>)

In [None]:
y.sum().backward()

In [None]:
x.grad

In [None]:
lin.weights.grad

tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
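`x.grad` prints nothing above because `x` was created without `requires_grad=True`, so autograd only tracks the parameter. A minimal sketch of how to get the input gradient as well, reusing the same model definition:

```python
import torch
import torch.nn as nn

class CustomLinear(nn.Module):
    def __init__(self, cin, cout):
        super().__init__()
        self.weights = nn.Parameter(torch.ones(cout, cin))

    def forward(self, x):
        return self.weights @ x

lin = CustomLinear(2, 3)
x = torch.ones(2, requires_grad=True)  # ask autograd to track the input too
lin(x).sum().backward()

print(x.grad)            # d(sum(W @ x))/dx: the column sums of W
print(lin.weights.grad)  # d(sum(W @ x))/dW: x repeated for every row of W
```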

## 2. Fwd and Bwd in C++

In [None]:
from torch.utils.cpp_extension import load

In [None]:
custom_lin_cpp = load(
    name='custom_lin_cpp', 
    sources=['custom_linear.cpp'],
    build_directory='tmp',
    verbose=True
)
cpp_fwd = custom_lin_cpp.forward
cpp_bwd = custom_lin_cpp.backward

Emitting ninja build file tmp/build.ninja...
Building extension module custom_lin_cpp...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module custom_lin_cpp...


In [None]:
cpp_fwd

<function custom_lin_cpp.PyCapsule.forward>

In [None]:
cpp_bwd

<function custom_lin_cpp.PyCapsule.backward>

In [None]:
cpp_fwd(lin.weights, x.unsqueeze(0)) # cpp_fwd expects x to be a matrix

tensor([[2., 2., 2.]], grad_fn=<MmBackward0>)

In [None]:
d = torch.ones(1,3) # upstream gradient dL/dy, same shape as the output of cpp_fwd
cpp_bwd(d, lin.weights, x.unsqueeze(0))

[tensor([[1., 1.],
         [1., 1.],
         [1., 1.]]),
 tensor([[3., 3.]], grad_fn=<MmBackward0>)]
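The contents of `custom_linear.cpp` are not shown here. Judging from the shapes and values above, the two functions compute roughly the following. This is a pure-PyTorch sketch of that assumed behaviour (`ref_fwd` and `ref_bwd` are hypothetical names, not part of the extension), which can serve as a reference when testing the compiled version:

```python
import torch

def ref_fwd(weights, x):
    # x: (batch, cin), weights: (cout, cin) -> y: (batch, cout)
    return x @ weights.t()

def ref_bwd(d, weights, x):
    # d: upstream gradient dL/dy, shape (batch, cout)
    d_weights = d.t() @ x   # (cout, cin)
    d_x = d @ weights       # (batch, cin)
    return [d_weights, d_x]

weights = torch.ones(3, 2)
x = torch.ones(1, 2)
d = torch.ones(1, 3)

y = ref_fwd(weights, x)
d_weights, d_x = ref_bwd(d, weights, x)
print(y, d_weights, d_x)
```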

In [None]:
class CustomLinearCpp(nn.Module):
    def __init__(self, cin, cout):
        super().__init__()
        self.weights = nn.Parameter(torch.ones(cout, cin))

    def forward(self, x):
        if len(x.shape)==1: x = x.unsqueeze(0) # cpp_fwd expects x to be a matrix
        return cpp_fwd(self.weights, x)  # use this module's own weights, not the global `lin`

In [None]:
lin = CustomLinearCpp(2,3)
x = torch.ones(2)

In [None]:
y = lin(x)
y

tensor([[2., 2., 2.]], grad_fn=<MmBackward0>)

In [None]:
y.sum().backward()

Hmm, I expected a runtime error here, because I assumed autograd would not know a backward function for the C++ forward.

In [None]:
x.grad, lin.weights.grad

(None,
 tensor([[1., 1.],
         [1., 1.],
         [1., 1.]]))

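Interestingly, the gradients were computed anyway. It seems `torch::mm` called from C++ is recorded by autograd just like its Python counterpart (note the `grad_fn=<MmBackward0>` on the output), so `cpp_bwd` was never used. `x.grad` is `None` only because `x` does not require gradients.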
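The same effect can be reproduced without any extension. In this sketch, `plain_fwd` is a hypothetical stand-in that, as the C++ forward presumably does, just calls a built-in matrix multiply. Autograd records that call and can differentiate through it without being given a backward function:

```python
import torch

def plain_fwd(weights, x):
    # Stand-in for the C++ forward: an ordinary differentiable torch op.
    return torch.mm(x, weights.t())

weights = torch.ones(3, 2, requires_grad=True)
x = torch.ones(1, 2)

y = plain_fwd(weights, x)
print(y.grad_fn)       # a backward node was recorded for the mm call

y.sum().backward()
print(weights.grad)    # gradients flow without any hand-written backward

with torch.no_grad():  # with tracking disabled, no graph is built
    y_untracked = plain_fwd(weights, x)
print(y_untracked.grad_fn)
```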

In [None]:
class CustomLinFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weights, x):
        ctx.save_for_backward(weights, x)
        return cpp_fwd(weights, x)

    @staticmethod
    def backward(ctx, d):
        # cpp_bwd returns [d_weights, d_x] (see its output above). backward must
        # return one gradient per forward input, in the order (weights, x).
        d_weights, d_x = cpp_bwd(d, *ctx.saved_tensors)
        return d_weights, d_x

In [None]:
class CustomLinearCpp(nn.Module):
    def __init__(self, cin, cout):
        super().__init__()
        self.weights = nn.Parameter(torch.ones(cout, cin))

    def forward(self, x):
        if len(x.shape)==1: x = x.unsqueeze(0) # cpp_fwd expects x to be a matrix
        return CustomLinFunction.apply(self.weights, x)

In [None]:
lin = CustomLinearCpp(2,3)
x = torch.ones(2)

In [None]:
y = lin(x)
y

In [None]:
y.sum().backward()

In [None]:
x.grad, lin.weights.grad

Note: I verified that `lin.weights.grad` changes when the C++ backward is changed (e.g. doubled), so the custom backward is really the one being used.
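A more systematic check than editing the backward by hand is `torch.autograd.gradcheck`, which compares a custom backward against numerical finite differences. The sketch below uses pure-PyTorch stand-ins (`fwd` and `bwd`, both hypothetical) so that it runs without the compiled extension. With the extension loaded, `cpp_fwd` and `cpp_bwd` would take their place, assuming they accept double-precision tensors:

```python
import torch

def fwd(weights, x):          # stand-in for cpp_fwd
    return x @ weights.t()

def bwd(d, weights, x):       # stand-in for cpp_bwd, returns [d_weights, d_x]
    return [d.t() @ x, d @ weights]

class LinFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weights, x):
        ctx.save_for_backward(weights, x)
        return fwd(weights, x)

    @staticmethod
    def backward(ctx, d):
        d_weights, d_x = bwd(d, *ctx.saved_tensors)
        return d_weights, d_x

# gradcheck needs double precision and requires_grad=True on the inputs
weights = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
x = torch.randn(4, 2, dtype=torch.double, requires_grad=True)

# Returns True on success and raises an error if the gradients disagree.
ok = torch.autograd.gradcheck(LinFunction.apply, (weights, x))
print(ok)
```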

## 3. Fwd and Bwd in CUDA

In [None]:
custom_lin_cuda = load(
    name='custom_lin_cuda', 
    sources=['custom_linear_cuda.cpp', 'custom_linear.cu'],
    build_directory='tmp',
    verbose=True
)
cuda_fwd = custom_lin_cuda.forward
cuda_bwd = custom_lin_cuda.backward

Now we can use `cuda_fwd` / `cuda_bwd` the same way as `cpp_fwd` / `cpp_bwd` above.
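For example, a small factory can build the `autograd.Function` from any forward/backward pair, so the same wrapper serves both extensions. This is only a sketch. `make_lin_function` is a hypothetical helper, and the pure-PyTorch stand-ins exist only to make the snippet runnable without the compiled kernels. It assumes the CUDA functions share the C++ signatures and expect tensors that already live on the GPU:

```python
import torch

def make_lin_function(fwd, bwd):
    """Wrap a (fwd, bwd) pair, e.g. (cuda_fwd, cuda_bwd), for autograd."""
    class LinFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, weights, x):
            ctx.save_for_backward(weights, x)
            return fwd(weights, x)

        @staticmethod
        def backward(ctx, d):
            # contiguous(): hand-written kernels often assume dense memory
            d_weights, d_x = bwd(d.contiguous(), *ctx.saved_tensors)
            return d_weights, d_x
    return LinFunction

# Stand-ins; replace with cuda_fwd / cuda_bwd once the extension is built.
def fwd(weights, x):
    return x @ weights.t()

def bwd(d, weights, x):
    return [d.t() @ x, d @ weights]

device = "cuda" if torch.cuda.is_available() else "cpu"
weights = torch.ones(3, 2, device=device, requires_grad=True)
x = torch.ones(1, 2, device=device)

y = make_lin_function(fwd, bwd).apply(weights, x)
y.sum().backward()
print(y, weights.grad)
```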