In [14]:
import torch
from torch.nn import Linear, Sequential
from torch.autograd import forward_ad
from torch.nn.utils import parameters_to_vector
from torch.nn.utils._stateless import functional_call

class MLP(torch.nn.Module):
    def __init__(self, d, k):
        super().__init__()
        layers = []
        layers.append(torch.nn.Linear(d, 3, bias = False))
        layers.append(torch.nn.ReLU())
        layers.append(torch.nn.Linear(3, k, bias = False))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, X):
        return self.net(X)

In [40]:
torch.manual_seed(111)

# Forward mode
@torch.enable_grad()
def _compute_gradient(model, named_params, inputs, targets, criterion):
    preds = functional_call(model, named_params, inputs)  # model(inputs)
    loss = criterion(preds, targets)
    return torch.autograd.grad(loss, named_params.values())


def my_forward_ad(model, inputs, targets, criterion, multiplier):
    with forward_ad.dual_level():
        named_dual_params = {name: forward_ad.make_dual(p, multiplier[i])
                             for i, (name, p) in enumerate(model.named_parameters())}
        gradient = _compute_gradient(
            model, named_dual_params, inputs, targets, criterion)
    return gradient
    
def criterion(true, output):   # Usual criteria give NotImplementedErrors (see below)
    return torch.abs(true - outputs).mean()
d = 5 # dimension of inputs
k = 2 # dimension of outputs
N = 4 # number of data
model = MLP(d, k)
x = torch.arange(5) + 0.1
true = torch.arange(2) + 0.3

multiplier = [torch.randn_like(v) for v in model.parameters()] # `v` in HVP. The factor of 1e5 is to amplif
grad2 = my_forward_ad(model, x, true, criterion, multiplier)
grad2

(tensor([[ 2.1394e-02,  2.3533e-01,  4.4927e-01,  6.6320e-01,  8.7714e-01],
         [ 6.7076e-04,  7.3783e-03,  1.4086e-02,  2.0793e-02,  2.7501e-02],
         [-1.0195e-02, -1.1214e-01, -2.1409e-01, -3.1604e-01, -4.1798e-01]]),
 tensor([[0.4125, 0.2169, 0.4556],
         [0.4125, 0.2169, 0.4556]]))

In [41]:
torch.manual_seed(111)
# Simple pytorch implementation calculating dL/dw, w parameters
# using backward propagation

d = 5 # dimension of inputs
k = 2 # dimension of outputs
N = 4 # number of data
model = MLP(d, k)
x = torch.arange(5) + 0.1
true = torch.arange(2) + 0.3
outputs = model(x)

loss = torch.abs(true - outputs).mean()
loss.backward()
for param in model.parameters():
    print(param.grad)

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0152, 0.1672, 0.3192, 0.4712, 0.6232],
        [0.0021, 0.0227, 0.0433, 0.0639, 0.0845]])
tensor([[-0.0000, -0.5192, -0.0785],
        [-0.0000, -0.5192, -0.0785]])
