In this notebook, I play around with PyTorch's autograd engine to get a sense of how complicated it would be to get gradient functions out of it (instead of gradients evaluated at a single point).

**Edit:** I decided to manually derive the gradient functions and verify them against the autograd engine. See `verify_derivative.ipynb`.

In [None]:
import torch
import torch.nn as nn

from fastcore.foundation import L
from fastcore.basics import strcat

In [None]:
class CustomModule(nn.Module):
    """Low-rank affine map ``x -> b @ a @ x + c``.

    Parameters are drawn from a standard normal distribution:
    ``a`` has shape (r, n), ``b`` has shape (m, r) and ``c`` has shape (m,).
    """

    def __init__(self, n,r,m):
        super().__init__()
        # Register in the order a, b, c: this fixes both the order in which
        # random numbers are consumed and the parameter iteration order.
        self.register_parameter('a', nn.Parameter(torch.randn(r, n)))
        self.register_parameter('b', nn.Parameter(torch.randn(m, r)))
        self.register_parameter('c', nn.Parameter(torch.randn(m)))

    def forward(self, x):
        """Return ``(b @ a) @ x + c`` for an input ``x`` whose leading dim is n."""
        # (m,r) @ (r,n) -> (m,n); multiplied first, exactly like `b@a@x`.
        combined = torch.matmul(self.b, self.a)
        # (m,n) @ (n) + (m)
        return torch.matmul(combined, x) + self.c

In [None]:
def shapes(*ts):
    "Render the shape of every tensor in `ts` as a list, joined by '; '."
    dims = [list(t.shape) for t in ts]
    return strcat(dims, "; ")

In [None]:
# Build a small module (n=2, r=1, m=3) and a random input vector of length 2.
module = CustomModule(n=2, r=1, m=3)
x = torch.randn(2)

# Show the shapes of the three parameters and of the input.
shapes(module.a, module.b, module.c, x)

'[1, 2]; [3, 1]; [3]; [2]'

In [None]:
# Run the forward pass and display the output's shape next to its value.
y = module(x)
shapes(y), y

('[3]', tensor([5.1279, 0.0990, 4.6005], grad_fn=<AddBackward0>))

In [None]:
# Fresh module and input; reduce the output to a scalar `l` so that
# `l.grad_fn` is the root of the backward graph explored below.
module = CustomModule(n=2, r=1, m=3)
x = torch.randn(2)
y = module(x)
l = y.sum()

In [None]:
# Print the backward graph as an indented tree (depth-first, parents before children)
def print_graph(node, indent=0):
    """Print `node` and everything reachable via `next_functions`.

    Each level is indented by 4 more spaces than its parent, starting at
    `indent` spaces. `None` entries in `next_functions` are skipped.
    """
    pending = [(node, indent)]
    while pending:
        current, depth = pending.pop()
        print(' ' * depth + str(current))
        children = [child for child, _ in current.next_functions if child is not None]
        # Push in reverse so the first child is popped (and printed) first.
        pending.extend((child, depth + 4) for child in reversed(children))

# Start from the scalar output's grad_fn, i.e. the root node of the backward graph
print_graph(l.grad_fn)

<SumBackward0 object>
    <AddBackward0 object>
        <MvBackward0 object>
            <MmBackward0 object>
                <AccumulateGrad object>
                <AccumulateGrad object>
        <AccumulateGrad object>


In [None]:
l.grad_fn??

[0;31mSignature:[0m   [0ml[0m[0;34m.[0m[0mgrad_fn[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m        SumBackward0
[0;31mString form:[0m <SumBackward0 object>
[0;31mDocstring:[0m   <no docstring>

In [None]:
# Private saved state on the SumBackward0 node; presumably the sizes of the
# tensor that was summed (y has 3 elements) — TODO confirm.
l.grad_fn._saved_self_sym_sizes

(3,)

In [None]:
# Probe a handful of attributes of the root grad_fn node and print, for each,
# whether it is a plain value or a method, together with its value/result.
for a in [
    'metadata',
    'name',
    'next_functions',
    'register_hook',
    'register_prehook',
    'requires_grad'
]:
    obj = getattr(l.grad_fn, a)
    if callable(obj):
        try:
            # Try a zero-argument call; this works for simple getters such as name().
            result = obj()
            ty = 'mth - called'
        except TypeError:
            # The call raised TypeError (e.g. arguments are required):
            # show the bound method itself instead.
            result = obj
            ty = 'mth'
    else:
        result = obj
        ty = 'var'
    print(f'{a:<17} ({ty:<12}) : {result}')


metadata          (var         ) : {}
name              (mth - called) : SumBackward0
next_functions    (var         ) : ((<AddBackward0 object>, 0),)
register_hook     (mth         ) : <built-in method register_hook of SumBackward0 object>
register_prehook  (mth         ) : <built-in method register_prehook of SumBackward0 object>
requires_grad     (var         ) : True


----

In [None]:
import requests
import yaml

# derivatives.yaml pinned to a specific PyTorch commit so the rules do not drift.
url = 'https://raw.githubusercontent.com/pytorch/pytorch/a55d63659ad0b9a14cbf5b495464994a9180c988/tools/autograd/derivatives.yaml'
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
# Fail loudly on 4xx/5xx instead of feeding an error page to the YAML parser.
response.raise_for_status()
# safe_load only builds plain Python objects; wrap the result in a fastcore `L`.
data = L(yaml.safe_load(response.text))

In [None]:
# Class names of the backward nodes whose derivative rules we want to look up
# (presumably collected from a backward graph elsewhere — TODO confirm).
ops = L(
    'AccumulateGrad',
    'AddBackward0',
    'DivBackward0',
    'MmBackward0',
    'MulBackward0',
    'SqueezeBackward4',
    'SumBackward0',
    'TBackward0',
    'UnsqueezeBackward0'
)
ops

(#9) ['AccumulateGrad','AddBackward0','DivBackward0','MmBackward0','MulBackward0','SqueezeBackward4','SumBackward0','TBackward0','UnsqueezeBackward0']

In [None]:
def short_name(full_name):
    """Return the lower-cased op name that precedes the ``backward`` suffix.

    E.g. ``'AddBackward0' -> 'add'``. A name that does not contain
    ``backward`` (such as ``'AccumulateGrad'``) is returned lower-cased and
    otherwise unchanged.
    """
    # partition() yields the whole string as its first item when the
    # separator is absent, so no try/except fallback is needed.
    return full_name.lower().partition('backward')[0]

In [None]:
# Replace each backward-node class name by its lower-cased op name.
ops = ops.map(short_name)
ops

(#9) ['accumulategrad','add','div','mm','mul','squeeze','sum','t','unsqueeze']

In [None]:
def show_names(txt):
    "Print every rule name in `data` that starts with `txt`, one per line."
    matching = data.attrgot('name').filter(lambda name: name.startswith(txt))
    for name in matching:
        print(name)

In [None]:
# For every short op name, list all derivative rules whose name starts with it.
# NOTE: plain prefix matching is broad, e.g. 'add' also matches addbmm, addcdiv, ...
for op in ops:
    print(f'Derivative rules for {op}:')
    rules = data.attrgot('name').filter(lambda o: o.startswith(op))
    if rules:
        for r in rules: print(f'\t- {r}')
    else:
        # No rule name starts with this prefix (e.g. accumulategrad).
        print('\t-none')

Derivative rules for accumulategrad:
	-none
Derivative rules for add:
	- add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
	- add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
	- addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
	- addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
	- addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
	- addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
	- addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
	- addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
Derivative rules for div:
	- div.Tensor(Tensor self, Tensor other) -> Tensor
	- div.Scalar(Tensor self, Scalar other) -> Tensor
	- div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
	- div.Scalar_mode(Tensor self, Scalar other, 

In [None]:
# Exact rule names (including the overload suffix) picked from the listing above.
relevant_op_names = L('add.Tensor div.Tensor mm mul.Tensor mul.Scalar squeeze.dim sum t unsqueeze'.split(' '))
# Display the list just defined; the original showed `op_names`, which is never defined.
relevant_op_names

(#9) ['add.Tensor','div.Tensor','mm','mul.Tensor','mul.Scalar','squeeze.dim','sum','t','unsqueeze']

In [None]:
from pprint import pprint

# Pretty-print the single derivative rule for each selected op.
for op in relevant_op_names:
    print(f'--- Derivative rules for {op}:')
    # Appending '(' makes the match exact: 'add.Tensor(' cannot match a longer op name.
    rules = data.filter(lambda o: o['name'].startswith(op+'('))
    if len(rules)==0:
        # Nothing to show: report it and move on rather than tripping the check below.
        print('none')
        print()
        continue
    # Sanity check: an exact op name should identify one rule only.
    assert len(rules)==1, f'expected exactly one rule for {op}, found {len(rules)}'
    pprint(rules[0], width=200)
    print()

--- Derivative rules for add.Tensor:
{'name': 'add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor',
 'other': 'handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj()))',
 'result': 'self_t + maybe_multiply(other_t, alpha)',
 'self': 'handle_r_to_c(self.scalar_type(), grad)'}

--- Derivative rules for div.Tensor:
{'name': 'div.Tensor(Tensor self, Tensor other) -> Tensor',
 'other': 'div_tensor_other_backward(grad, self, other)',
 'result': '(self_t - other_t * result) / other_p',
 'self': 'div_tensor_self_backward(grad, other, self.scalar_type())'}

--- Derivative rules for mm:
{'mat2': 'mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1)',
 'name': 'mm(Tensor self, Tensor mat2) -> Tensor',
 'result': 'at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t)',
 'self': 'mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), 1)'}

--- Derivative rules for mul.Tensor:
{'name': 'mul.Tensor(Tensor self, Tensor 