In [None]:
import torch
import numpy

For a gentle introduction, see the [PyTorch extension](https://pytorch.org/docs/stable/notes/extending.html) tutorial.

The source for `torch.autograd.Function` is available [here](https://github.com/pytorch/pytorch/blob/master/torch/autograd/function.py).
These are the two static methods that we have to override:

```python
@staticmethod
def forward(ctx, *args, **kwargs):
    """Performs the operation.
    This function is to be overridden by all subclasses.
    It must accept a context ctx as the first argument, followed by any
    number of arguments (tensors or other types).
    The context can be used to store tensors that can be then retrieved
    during the backward pass.
    """
    raise NotImplementedError

@staticmethod
def backward(ctx, *grad_outputs):
    """Defines a formula for differentiating the operation.
    This function is to be overridden by all subclasses.
    It must accept a context :attr:`ctx` as the first argument, followed by
    as many outputs did :func:`forward` return, and it should return as many
    tensors, as there were inputs to :func:`forward`. Each argument is the
    gradient w.r.t the given output, and each returned value should be the
    gradient w.r.t. the corresponding input.
    The context can be used to retrieve tensors saved during the forward
    pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
    of booleans representing whether each input needs gradient. E.g.,
    :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
    first input to :func:`forward` needs gradient computated w.r.t. the
    output.
    """
    raise NotImplementedError
```    

In [None]:
# Custom addition module
class MyAdd(torch.autograd.Function):
    """Addition ``x1 + x2`` with a hand-written backward pass.

    Each operand may be a tensor or a plain Python number. Tensor operands
    of different (broadcastable) shapes are supported: the gradient handed
    back for each operand is reduced to that operand's own shape.
    """

    @staticmethod
    def forward(ctx, x1, x2):
        """Compute ``x1 + x2``.

        Args:
            ctx: autograd context used to pass information to ``backward``.
            x1: first operand (tensor or Python number).
            x2: second operand (tensor or Python number).

        Returns:
            The sum ``x1 + x2``.
        """
        # d(x1 + x2)/dx1 = d(x1 + x2)/dx2 = 1, so backward never needs the
        # input *values*. Only the shapes are recorded (needed to undo
        # broadcasting); nothing is stored with ctx.save_for_backward, so
        # the inputs are not kept alive and non-tensor operands are fine.
        ctx.input_shapes = tuple(
            operand.shape if torch.is_tensor(operand) else None
            for operand in (x1, x2)
        )
        return x1 + x2

    @staticmethod
    def backward(ctx, grad_output):
        """Route ``grad_output`` unchanged to both operands.

        Args:
            ctx: the context populated in ``forward``.
            grad_output: gradient of the loss w.r.t. the output of ``forward``.

        Returns:
            A 2-tuple ``(grad_x1, grad_x2)``, in the order of the inputs to
            ``forward`` (excluding ``ctx``). An entry is ``None`` when the
            corresponding operand is not a tensor or does not need a gradient.
        """
        grads = []
        for needs_grad, shape in zip(ctx.needs_input_grad, ctx.input_shapes):
            if not needs_grad or shape is None:
                # Python numbers and tensors that do not require grad.
                grads.append(None)
            elif shape == grad_output.shape:
                grads.append(grad_output)
            else:
                # The operand was broadcast in forward: sum the gradient
                # over the broadcast dimensions to recover its shape.
                grads.append(grad_output.sum_to_size(shape))
        return tuple(grads)

In [None]:
# Let's try out the addition module
myadd = MyAdd.apply  # aliasing the apply method
x1 = torch.ones(3, requires_grad=True)
x2 = torch.ones(3, requires_grad=True)
y = myadd(x1, x2)
# z is the mean over 3 elements, so every entry of both
# gradients below is expected to be 1/3.
z = torch.mean(y)
z.backward()
print(x1.grad, x2.grad, sep="\n")

In [None]:
# Custom split module
class MySplit(torch.autograd.Function):
    """Fan one tensor out into two independent copies.

    The backward pass sums the gradients arriving from the two copies,
    since both copies are derived from the same input.
    """

    @staticmethod
    def forward(ctx, x):
        """Return two separate clones of ``x``.

        Args:
            ctx: autograd context (unused: backward needs no saved state).
            x: tensor to duplicate.

        Returns:
            A 2-tuple ``(x1, x2)`` of clones of ``x``.
        """
        # Nothing is saved on ctx: the gradient of a copy does not depend
        # on the value that was copied.
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad_x1, grad_x2):
        """Combine the gradients of both outputs into the input gradient.

        Args:
            ctx: autograd context (unused).
            grad_x1: gradient of the loss w.r.t. the first output.
            grad_x2: gradient of the loss w.r.t. the second output.

        Returns:
            ``grad_x1 + grad_x2``, the gradient w.r.t. ``x``.
        """
        # To inspect the incoming gradients while experimenting, register a
        # hook on the outputs (e.g. ``x1.register_hook(print)``) rather than
        # printing from inside backward.
        return grad_x1 + grad_x2

In [None]:
# Let's try out the split module
split = MySplit.apply
x = torch.ones(4, requires_grad=True)
x1, x2 = split(x)
y = torch.add(x1, x2)
# z averages 4 elements and x reaches each of them through two
# branches, so every entry of x.grad is expected to be 2 * 1/4 = 0.5.
z = torch.mean(y)
z.backward()
print(x.grad)

In [None]:
# Custom argmax module
class MyArgMax(torch.autograd.Function):
    """One-hot encoding of the global argmax of a tensor.

    The forward pass locates the maximum with NumPy (an example of using
    non-torch code inside a ``Function``). The backward pass lets the
    incoming gradient through only at the position of the maximum.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a tensor shaped like ``x`` with a single 1 at its maximum.

        Args:
            ctx: autograd context; the one-hot mask is saved on it.
            x: input tensor of any shape (must be non-empty, otherwise
                NumPy's ``argmax`` raises ``ValueError``).

        Returns:
            A tensor with the shape, dtype and device of ``x`` that is 0
            everywhere except for a 1 at the (first) maximum of ``x``.
        """
        # example where we explicitly use non-torch code
        # .cpu() is a no-op for CPU tensors and makes .numpy() usable for
        # tensors that live on another device.
        flat_index = int(x.detach().cpu().numpy().argmax())
        # NumPy's argmax() without an axis indexes the *flattened* array,
        # so the 1 has to be written through a flat view of the mask.
        # A contiguous layout guarantees that view(-1) is valid.
        argmax_onehot = torch.zeros_like(x, memory_format=torch.contiguous_format)
        argmax_onehot.view(-1)[flat_index] = 1
        ctx.save_for_backward(argmax_onehot)
        return argmax_onehot

    @staticmethod
    def backward(ctx, grad_output):
        """Mask the incoming gradient with the one-hot argmax.

        Args:
            ctx: the context populated in ``forward``.
            grad_output: gradient of the loss w.r.t. the one-hot output.

        Returns:
            ``grad_output`` with every entry zeroed except the one at the
            argmax position.
        """
        (argmax_onehot,) = ctx.saved_tensors
        return grad_output * argmax_onehot

In [None]:
# Let's try out the argmax module
# Seed the global RNG so that the random input (and therefore the printed
# output) is identical on every Restart & Run All.
torch.manual_seed(0)
x = torch.randn((5), requires_grad=True)
print(x)
myargmax = MyArgMax.apply
y = myargmax(x)
# z sums the one-hot output, so x.grad is expected to be 1 at the
# position of the largest entry of x and 0 elsewhere.
z = y.sum()
z.backward()
print(x.grad)