# Extending PyTorch

In this note we’ll cover ways of extending torch.nn and torch.autograd, and of writing custom C extensions that use our C libraries.

## Extending torch.autograd
Adding operations to autograd requires implementing a new Function subclass for each operation. Recall that Functions are what autograd uses to compute results and gradients, and to encode the operation history. Every new function requires you to implement two methods:

- forward() - the code that performs the operation. It can take as many arguments as you want, and some of them can be optional if you give them default values. Any kind of Python object is accepted here. Variable arguments are converted to Tensors before the call, and their use is registered in the graph. Note that this logic does not traverse lists, dicts, or any other data structures; it only considers Variables that are direct arguments to the call. You can return either a single Tensor, or a tuple of Tensors if there are multiple outputs. Also, please refer to the docs of Function for descriptions of useful methods that can be called only from forward().
- backward() - the gradient formula. It is given as many Variable arguments as there were outputs, each representing the gradient w.r.t. that output. It should return as many Variables as there were inputs, each containing the gradient w.r.t. its corresponding input. If an input did not require gradient (see needs_input_grad), or was a non-Variable object, you can return None for it. Also, if forward() has optional arguments, you can return more gradients than there were inputs, as long as the extra ones are all None.

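
To make the rule about non-Variable inputs concrete, here is a minimal sketch of a function that multiplies a tensor by a plain Python number. It assumes the static-method Function API of recent PyTorch releases (forward() and backward() receive a ctx object, and the function is called through apply()), which differs from the Variable-based wording above. The class name MulConstant is illustrative only.

```python
import torch
from torch.autograd import Function


class MulConstant(Function):
    """Illustrative example: multiply a tensor by a plain Python number."""

    @staticmethod
    def forward(ctx, tensor, constant):
        # `constant` is not a tensor, so it is stored directly on ctx
        # instead of going through ctx.save_for_backward().
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward() input: a gradient for `tensor`,
        # and None for the non-tensor `constant`.
        return grad_output * ctx.constant, None


x = torch.ones(3, requires_grad=True)
y = MulConstant.apply(x, 2.5)
y.sum().backward()  # fills x.grad with d(sum(y))/dx
```

After the backward pass every element of x.grad equals the constant, since each output element depends linearly on exactly one input element.
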
Below you can find code for a Linear function from torch.nn, with additional comments:
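
A minimal sketch of such a Linear function follows. It is written against the static-method Function API of recent PyTorch releases (an assumption, since the description above uses the older Variable terminology), and it is not necessarily the code torch.nn uses internally. The name LinearFunction and the tensor shapes in the usage lines are chosen for illustration.

```python
import torch
from torch.autograd import Function


class LinearFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Keep the tensors that backward() needs to compute gradients.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            # unsqueeze(0) adds a leading batch dimension so that the
            # bias can be expanded to the shape of the output.
            output = output + bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Only compute the gradients that are actually required.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        # One gradient (or None) per forward() input.
        return grad_input, grad_weight, grad_bias


torch.manual_seed(0)
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(2, dtype=torch.double, requires_grad=True)
out = LinearFunction.apply(x, w, b)
out.sum().backward()
```

The checks on needs_input_grad are optional; they only skip work for inputs that do not need a gradient. Double-precision inputs are used so that the result can be compared against numerical gradients with torch.autograd.gradcheck.
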


```python
import torch

# unsqueeze() inserts a singleton dimension at the given position
input = torch.Tensor(2, 4, 3)  # uninitialized tensor; only its shape matters here
print("input size: ", input.size())
print("unsqueezed(0) input: ", input.unsqueeze(0).size())
print("unsqueezed(1) input: ", input.unsqueeze(1).size())
```

```text
input size:  torch.Size([2, 4, 3])
unsqueezed(0) input:  torch.Size([1, 2, 4, 3])
unsqueezed(1) input:  torch.Size([2, 1, 4, 3])
```
