In [1]:
import torch
from torch import nn
from torch.autograd.functional import jacobian

In [2]:
# Turn autograd off for the whole notebook; cells that need derivatives
# re-enable it locally inside a grad-enabled `with` block.
torch.set_grad_enabled(mode=False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f0464503970>

In [3]:
class Linear(nn.Module):
    """
    Toy model f(x) = K * x**2 with a single learnable coefficient K.

    Note: despite the name, the map is quadratic in ``x`` (it is linear only
    in the parameter ``K``), so df/dx = 2 * K * x.

    Arguments:
        K: Initial coefficient, as a Python number or a tensor. Integer (or
            bool) values are promoted to the default floating-point dtype,
            because only floating-point/complex tensors can require gradients.
    """

    def __init__(self, K):
        super().__init__()
        K = torch.as_tensor(K)
        if not (K.is_floating_point() or K.is_complex()):
            # nn.Parameter(requires_grad=True) rejects integer tensors.
            K = K.to(torch.get_default_dtype())
        # Take an explicit copy so the parameter never aliases caller-owned
        # storage (same copy semantics as torch.tensor, without its warning
        # when K is already a tensor).
        self.K = nn.Parameter(K.detach().clone())

    def forward(self, x):
        """Return K * x**2, applied elementwise to ``x``."""
        return self.K * x ** 2

In [4]:
net = Linear(0.5)
x = torch.tensor(3.0)

# Grad mode is disabled globally, so opt in explicitly for the Jacobian.
with torch.set_grad_enabled(True):
    dnet_dx = jacobian(net, x)
print(dnet_dx)

# The forward pass is K * x**2, so the derivative must equal 2 * K * x.
assert torch.allclose(dnet_dx, 2 * net.K * x)

# Computing the Jacobian must not leave gradients accumulated in `.grad`
# on either the input or the parameter.
assert x.grad is None
assert net.K.grad is None

# Note (eric.cousineau): this looks fine -- `jacobian` does not appear to
# waste computation / storage on extraneous gradients.

tensor(3.)


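## But we need a batched Jacobian

Plain `jacobian` treats a whole batch as one big input, so it returns the full cross-sample Jacobian instead of one Jacobian per sample, as the next cell shows.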

In [5]:
# Plain `jacobian` is not batched the way we want: it treats the whole
# batch as a single input, so for an elementwise function we get a full
# (diagonal) matrix instead of one derivative per sample.
xs = torch.tensor([1.0, 3.0])
with torch.enable_grad():
    dnet_dxs = jacobian(net, xs)
print(dnet_dxs)

tensor([[1., 0.],
        [0., 3.]])


In [6]:
def batch_jacobian(func, x, **jacobian_kwargs):
    """
    Computes the Jacobian of ``func`` separately for each sample in a batch.

    Uses the "sum over the batch" trick: if each output row ``func(x)[n]``
    depends only on the matching input row ``x[n]``, then differentiating
    ``func(x).sum(dim=0)`` with respect to ``x`` yields every per-sample
    Jacobian in a single ``jacobian`` call. The batch axis of that result
    sits between the output and input axes and is then moved to the front.

    Arguments:
        func: Function to compute the Jacobian of. Must map a batch
            Tensor[N, *in_shape] to Tensor[N, *out_shape] and treat samples
            independently; otherwise the result is not a per-sample Jacobian.
        x: Batch of inputs, Tensor[N, *in_shape].
        **jacobian_kwargs: Forwarded to
            ``torch.autograd.functional.jacobian`` (e.g. ``create_graph``,
            ``vectorize``).

    Returns:
        Tensor[N, *out_shape, *in_shape], where entry ``[n, i..., j...]`` is
        d func(x)[n, i...] / d x[n, j...].

    Raises:
        ValueError: If ``x`` is 0-d (has no batch dimension).

    Adapted from:
    https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5
    """
    if x.ndim == 0:
        raise ValueError(
            "batch_jacobian expects x with a leading batch dimension, "
            "got a 0-d tensor"
        )

    def func_sum(batch):
        # Summing over the batch keeps per-sample derivatives separate,
        # because sample n's output depends only on sample n's input.
        return func(batch).sum(dim=0)

    J = jacobian(func_sum, x, **jacobian_kwargs)
    # J has shape (*out_shape, N, *in_shape), so the batch axis is at index
    # len(out_shape) == J.ndim - x.ndim. Keying on the output rank (not on
    # x.ndim alone) handles scalar and multi-dimensional per-sample outputs.
    batch_dim = J.ndim - x.ndim
    return torch.movedim(J, batch_dim, 0)

In [7]:
# Per-sample derivatives for a 1-D batch of scalar inputs: one value per
# sample instead of a diagonal matrix.
xs = torch.tensor([1.0, 3.0])
with torch.enable_grad():
    dnet_dxs = batch_jacobian(net, xs)
print(dnet_dxs)

tensor([1., 3.])


In [8]:
# 2-D batch: each sample is a length-2 vector, so each per-sample
# Jacobian is a 2x2 matrix.
xs = torch.tensor([[1.0, 3.0], [2.0, 4.0]])
with torch.enable_grad():
    dnet_dxs = batch_jacobian(net, xs)
print(dnet_dxs)

tensor([[[1., 0.],
         [0., 3.]],

        [[2., 0.],
         [0., 4.]]])
