In [1]:
import torch
from torch import nn

In [2]:
# Disable autograd globally for this notebook; cells that need a graph opt
# back in explicitly with `torch.set_grad_enabled(True)`. The trailing `;`
# suppresses the repr of the returned context-manager object.
torch.set_grad_enabled(False);

In [3]:
class Linear(nn.Module):
    """Toy model ``y = K * x**2`` with a single learnable coefficient ``K``.

    Despite the name, the model is linear in the parameter ``K`` but quadratic
    in the input ``x``. Hence ``dy/dx = 2 * K * x`` and
    ``d(dy/dx)/dK = 2 * x``.

    Args:
        K: Initial value of the coefficient (Python number or tensor). Integer
            or boolean values are promoted to the default floating-point dtype,
            because only floating-point/complex tensors can require gradients.
            The value is copied, so the Parameter never aliases the caller's
            tensor.
    """

    def __init__(self, K):
        super().__init__()
        K = torch.as_tensor(K)
        if not (K.is_floating_point() or K.is_complex()):
            # e.g. Linear(1): an int64 Parameter could not require grad.
            K = K.to(torch.get_default_dtype())
        # Copy and drop any autograd history, matching `torch.tensor(K)`'s
        # always-copy semantics, so the Parameter owns its own storage.
        self.K = nn.Parameter(K.detach().clone())

    def forward(self, x):
        """Return ``K * x**2`` elementwise (``x`` broadcasts against ``K``)."""
        return self.K * x ** 2

In [4]:
def gradient(net, x):
    """Derivative of ``net`` with respect to its input, evaluated at ``x``.

    Computes ``d(net(x).sum()) / dx``, which has the same shape as ``x``. When
    ``net`` acts elementwise (each output entry depends only on the matching
    input entry), this is exactly the per-element derivative, i.e. a batched
    gradient. For a net that mixes entries it is the Jacobian summed over the
    outputs, not the full Jacobian.

    ``x`` is detached first, so nothing flows back into ``x`` or its history.
    If grad mode is enabled at call time, the result keeps its graph
    (``create_graph=True``) and can itself be differentiated, e.g. with respect
    to the net's parameters. Otherwise it is a plain tensor.

    Args:
        net: Callable mapping a tensor to a tensor (e.g. an ``nn.Module``).
        x: Floating-point input tensor.

    Returns:
        Tensor shaped like ``x``. If ``net(x)`` does not depend on ``x``, the
        result is all zeros rather than an autograd error.
    """
    # Adapted from:
    # https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5
    # Only build a differentiable graph of the gradient when the caller wants
    # one (grad mode on), enabling higher-order derivatives in that case.
    create_graph = torch.is_grad_enabled()
    # Differentiate w.r.t. a fresh leaf so x itself never accumulates .grad.
    x = x.detach().requires_grad_(True)
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    with torch.set_grad_enabled(True):
        y = net(x)
        if not y.requires_grad:
            # Output carries no autograd history, so it cannot depend on x.
            return torch.zeros_like(x)
        dnet_dx, = torch.autograd.grad(
            y.sum(), x, create_graph=create_graph, allow_unused=True)
    if dnet_dx is None:
        # net(x) depends on other tensors (e.g. parameters) but not on x.
        dnet_dx = torch.zeros_like(x)
    return dnet_dx

In [5]:
K = 0.5
net = Linear(K)
x = torch.tensor(3.0)
dnet_dx = gradient(net, x)

print(dnet_dx)
# For net(x) = K x^2 the analytic derivative is 2 K x.
expected = 2 * K * x
torch.testing.assert_close(dnet_dx, expected, atol=1e-9, rtol=0.0)

# `gradient` differentiates a detached copy of x via torch.autograd.grad, which
# returns gradients instead of accumulating them into .grad, so neither the
# input nor the parameter ends up with a .grad.
assert x.grad is None
assert net.K.grad is None

tensor(3.)


## Tests with batched inputs

In [6]:
# Because `gradient` differentiates `net(x).sum()`, an elementwise net yields
# the per-element derivative: a 1-D input gives a 1-D result of the same
# shape (not a Jacobian matrix).
xs = torch.tensor([1.0, 3.0])
dnet_dxs = gradient(net, xs)
print(dnet_dxs)
torch.testing.assert_close(dnet_dxs, 2 * K * xs, atol=1e-9, rtol=0.0)

tensor([1., 3.])


In [7]:
# The batched result must match calling `gradient` on each element separately.
xs = torch.tensor([1.0, 3.0])
dnet_dxs = gradient(net, xs)
print(dnet_dxs)
dnet_dxs_per_elem = torch.stack([gradient(net, x_i) for x_i in xs])
torch.testing.assert_close(dnet_dxs, dnet_dxs_per_elem)

tensor([1., 3.])


In [8]:
# Multi-dimensional batches work too: the result has the input's shape and
# holds the elementwise derivative 2 K x.
xs = torch.tensor([[1.0, 3.0], [2.0, 4.0]])
dnet_dxs = gradient(net, xs)
print(dnet_dxs)
assert dnet_dxs.shape == xs.shape
torch.testing.assert_close(dnet_dxs, 2 * K * xs, atol=1e-9, rtol=0.0)

tensor([[1., 3.],
        [2., 4.]])


In [9]:
# With grad mode enabled, `gradient` keeps its graph (create_graph=True), so
# the input-gradient can itself be backpropagated into the net's parameters.
torch.random.manual_seed(0)
xs = torch.rand((100, 10))
if net.K.grad is not None:
    net.K.grad.zero_()
with torch.set_grad_enabled(True):
    dnet_dxs = gradient(net, xs)
    dnet_dxs.sum().backward()

print(net.K.grad)
# d(net)/dx = 2 K x
# d(sum(d(net)/dx))/dK = sum(2 x)
expected = (2 * xs).sum()
# Default float32 tolerances: the old atol=1e-8 on a value of ~1e3 was below
# float32 resolution and required a bit-identical reduction order.
torch.testing.assert_close(net.K.grad, expected)

tensor(1001.8497)
