# Extending PyTorch
### In this note we'll cover ways of extending torch.nn, torch.autograd, and writing custom C extensions utilizing our C libraries.

## Extending torch.autograd 

### Adding operations to *autograd* requires implementing a new *Function* subclass for each operation. Recall that *Function*s are what *autograd* uses to compute the results and gradients, and encode the operation history. Every new function requires you to implement 2 methods:
* forward() - the code that performs the operation. It can take as many arguments as you want; all kinds of Python objects are accepted here.
* backward() - gradient formula.

In [23]:
import torch

# Inherit from Function
class LinearFunction(torch.autograd.Function):
    """Custom autograd function computing ``input.mm(weight.t())`` plus an optional bias.

    Use it through ``LinearFunction.apply(input, weight, bias)``.
    ``input`` and ``weight`` must be 2-D because they are combined with
    ``Tensor.mm``; ``weight`` is transposed before the product, so both must
    share the same size in their second dimension.
    """

    # Note that both forward and backward are @staticmethods
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute the linear transform and stash the operands for backward.

        Args:
            ctx: autograd context used to pass tensors to ``backward``.
            input: 2-D tensor, one row per sample.
            weight: 2-D tensor, multiplied as ``weight.t()``.
            bias: optional tensor broadcast over the rows of the output
                (``None`` means no bias is added).

        Returns:
            ``input.mm(weight.t())`` with ``bias`` added to every row when
            it is given.
        """
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. ``input``, ``weight`` and ``bias``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient of the loss w.r.t. the forward output.

        Returns:
            A 3-tuple ``(grad_input, grad_weight, grad_bias)``; an entry is
            ``None`` when the corresponding input does not need a gradient
            (or, for the bias, when no bias was supplied).
        """
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            # sum(0) already removes the reduced row dimension. An extra
            # squeeze(0) would collapse a single-element bias gradient to a
            # 0-dim tensor, which no longer matches the bias shape.
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
    
# Custom Functions are invoked through ``apply`` rather than instantiated.
linear = LinearFunction.apply

# Random 10x4 input that tracks gradients.
_input = torch.randn(10, 4, dtype=torch.float, requires_grad=True)
print(_input.shape)
# Random 3x4 weight; it is transposed inside forward before the product.
weight = torch.randn(3, 4)
print(weight.shape)

# No bias is passed, so forward falls back to its default of None.
output = linear(_input, weight)
print(output.shape)

torch.Size([10, 4])
torch.Size([3, 4])
torch.Size([10, 3])


## Extending torch.autograd.Function
### saved_tensors 在backward()中会用到
### needs_input_grad 表示输入的Tensor是否需要计算梯度
### num_inputs 表示forward()中传入参数的数量
### num_outputs 表示forward()中返回值的数量

In [25]:
class MulConstant(torch.autograd.Function):
    """Autograd function that multiplies a tensor by a constant.

    Use it through ``MulConstant.apply(tensor, constant)``.
    """

    @staticmethod
    def forward(ctx, tensor, constant):
        """Return ``tensor * constant`` and keep ``constant`` on ``ctx``.

        The constant is stored as a plain attribute of the context so that
        ``backward`` can reuse it to scale the incoming gradient.
        """
        scaled = tensor * constant
        ctx.constant = constant
        return scaled

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per forward argument.

        The gradient for ``tensor`` is ``grad_output`` scaled by the stored
        constant; no gradient is produced for ``constant`` itself.
        """
        grad_tensor = grad_output * ctx.constant
        grad_constant = None
        return grad_tensor, grad_constant
    
# Bind the Function's ``apply`` entry point to a short callable name.
mul_constant = MulConstant.apply

# One-element input that tracks gradients, and the factor to scale it by.
_input = torch.tensor([1.0], requires_grad=True)
constant = torch.tensor([2.0])

# Run the custom op and show both the result and whether it tracks gradients.
y = mul_constant(_input, constant)
print(y, y.requires_grad)

tensor([2.], grad_fn=<MulConstantBackward>) True
