In [1]:
import torch, numpy
n_walkers = 20
n_particles = 2
n_dim = 3
torch.set_default_tensor_type(torch.DoubleTensor)

Input tensor:

In [136]:
x = torch.rand((n_walkers, n_particles, n_dim), requires_grad=True, dtype=torch.float64)

In [3]:
class DeepSets(torch.nn.Module):
    
    def __init__(self):
        torch.nn.Module.__init__(self)
        
        self.individual_net = torch.nn.ModuleList([
            torch.nn.Linear(n_dim, 16, bias=False),
            torch.nn.Linear(16, 32, bias=False)
        ]
        )
        
        self.aggregate_net = torch.nn.ModuleList([
            torch.nn.Linear(32, 16, bias=False),
            torch.nn.Linear(16, 1, bias=False)
        ]
        )
        
    def forward(self, inputs):
        
        #split the inputs along the particle dimension (1):
        particles = torch.chunk(inputs, n_particles, axis=1)
        
        particles = [torch.reshape(p, (-1,n_dim)) for p in particles]
        
        individuals = []
        for p in particles:
            this_i = p
            for l in self.individual_net:
                this_i = torch.tanh(l(this_i))
            individuals.append(this_i)
#         individuals = [self.individual_net(p) for p in particles]

        concatd = torch.stack(individuals, dim=-1)
        
        
        # Sum over the latent space:
        summed = torch.sum(concatd, dim=-1)
        
        output = summed
        for l in self.aggregate_net:
            output = l(output)
#         output = self.(summed)
        
        return output.reshape((-1))
        

In [4]:
d = DeepSets()

In [5]:
print(x.requires_grad)
o = d(x)

True


In [6]:
o.shape

torch.Size([20])

In [7]:
o.grad_fn

<ReshapeAliasBackward0 at 0x7f8054234080>

In [8]:
for param_tensor in d.parameters():
    jac_i = torch.autograd.functional.jacobian(d, x)

In [9]:
print(jac_i.shape)

torch.Size([20, 20, 2, 3])


In [10]:
o = d(x)
params = list(d.parameters())

In [11]:
print(o)

tensor([-0.0758, -0.0851, -0.0414, -0.0907, -0.0402, -0.0218, -0.0074, -0.0455,
         0.0101, -0.0610, -0.0351, -0.0358, -0.0311, -0.0416, -0.0185, -0.0245,
        -0.0054, -0.0137,  0.0060, -0.0059], grad_fn=<ReshapeAliasBackward0>)


In [12]:
outputs = torch.tensor_split(torch.flatten(torch.eye(n_walkers)), n_walkers, axis=0)
print(outputs)
jac = torch.autograd.grad(o, params, retain_graph=True, grad_outputs=outputs)

(tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), tensor([0., 0., 0.,

In [13]:
jac[1]

tensor([[-8.9194e-04,  7.4605e-04, -3.8275e-03,  4.8174e-03, -2.7491e-04,
         -3.0629e-03,  1.8783e-03, -1.6623e-04,  4.2345e-03,  5.2004e-03,
          2.9563e-03, -2.3737e-03, -1.1785e-03, -8.6339e-04,  5.1584e-03,
         -1.3632e-06],
        [ 1.3119e-02, -1.0909e-02,  5.6119e-02, -7.0482e-02,  3.9984e-03,
          4.4917e-02, -2.7500e-02,  2.3416e-03, -6.1993e-02, -7.6153e-02,
         -4.3278e-02,  3.4787e-02,  1.7267e-02,  1.2622e-02, -7.5508e-02,
         -3.0392e-05],
        [-9.8766e-03,  8.3813e-03, -4.2718e-02,  5.4045e-02, -3.1284e-03,
         -3.4166e-02,  2.1036e-02, -2.0336e-03,  4.7435e-02,  5.8220e-02,
          3.3120e-02, -2.6523e-02, -1.3173e-02, -9.7051e-03,  5.7804e-02,
         -1.0915e-04],
        [-7.2110e-03,  5.9622e-03, -3.0751e-02,  3.8541e-02, -2.1738e-03,
         -2.4617e-02,  1.5048e-02, -1.2323e-03,  3.3920e-02,  4.1677e-02,
          2.3678e-02, -1.9053e-02, -9.4559e-03, -6.8966e-03,  4.1309e-02,
          4.3362e-05],
        [ 5.1995e-03

In [14]:
jac[-1]

tensor([[ 0.0733,  0.2283,  0.0228, -0.2164, -0.0393, -0.0296, -0.1130,  0.0097,
          0.0356, -0.1698, -0.0225, -0.0267, -0.2118,  0.0152, -0.0715,  0.0792]])

In [15]:
numpy.prod(outputs[0].shape)

20

In [16]:
# Here's a sort of trick: scale out the batch size by n_ouputs:

In [17]:
params = d.parameters()

In [18]:
shapes = (p.size() for p in params)
n_params = numpy.sum([numpy.prod(s) for s in shapes])


In [19]:
# Here's the stupid way:

In [300]:
def jacobian_reverse(inputs, n_params, model):
    '''
    This computes a backward pass once per input in a way that builds up the full
    jacobian matrix (ninputs x nparameters) one input at a time.
    
    For deep networks / many parameters, and for just a few inputs, this is more efficient.
    '''
    n_walkers = inputs.size()[0]
    jac_output = torch.zeros((n_walkers, n_params)) 
    o = model(inputs)

    for i_walker in range(n_walkers):
        grad_outputs = torch.zeros_like(o)
        grad_outputs[i_walker] = 1.0
        single_jac = torch.autograd.grad(o, model.parameters(), retain_graph=True, grad_outputs=grad_outputs)
        flattened_line = torch.cat([s.flatten() for s in single_jac])
        jac_output[i_walker, :] = flattened_line
    return jac_output

In [323]:
j_bkwd = jacobian_reverse(x, n_params, d)

In [341]:
def jacobian_forward(inputs, n_params, model):
    '''
    https://j-towns.github.io/2017/06/12/A-new-trick.html
    
    Based on the above trick, this uses two backward passes to arrive at a forward-mode gradient.
    The jacobian then is constructed one parameter at a time (instead of one input at a time).
    For smaller networks, or when n_params << n_inputs, this should be more efficient.
    
    '''
    network_output = model(inputs)
    n_walkers = inputs.size()[0]
    jac_output = torch.zeros((n_walkers, n_params)) 
    
    # Loop over layers in the model.
    # Keep track of the column of the jacobian:
    running_column_index = 0
    for i_layer, layer in enumerate(model.parameters()):
        
        # How many parameters in this layer?
        n_params_local = numpy.prod(layer.shape)
        
        for i_weight in range(n_params_local):
            v = torch.ones_like(network_output, requires_grad=True)
            u = torch.zeros_like(layer).flatten()
            # Need to set the i^th index to 1.0, but it's trick to do so without reshaping:
            u[i_weight] = 1.0
            u = u.reshape(layer.shape)
            # First backward pass:
            vjp = torch.autograd.grad(network_output, layer, grad_outputs=v, create_graph=True)[0]
            # Second backward pass:
            output = torch.autograd.grad(vjp, v, grad_outputs=u)[0]
            jac_output[:,running_column_index] = output
            running_column_index += 1
    return jac_output

In [342]:
j_fwd = jacobian_forward(x, n_params, d)

In [343]:
torch.max(torch.abs(j_fwd - j_bkwd))

RuntimeError: The size of tensor a (1000) must match the size of tensor b (20) at non-singleton dimension 0

In [344]:
# Which is more efficient?
n_walkers = 20
n_particles = 2
n_dim = 3

In [345]:
x = torch.rand((n_walkers, n_particles, n_dim), requires_grad=True, dtype=torch.float64)

In [346]:
print(x.shape)
print(n_params)

torch.Size([20, 2, 3])
1088


In [347]:
%timeit j_fwd = jacobian_forward(x, n_params, d)

272 ms ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [348]:
%timeit j_bkwd = jacobian_reverse(x, n_params, d)

4.25 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


For small numbers of walkers, reverse is better.  What about comparable sizes?

In [349]:
n_walkers = 1000
n_particles = 2
n_dim = 3
x = torch.rand((n_walkers, n_particles, n_dim), requires_grad=True, dtype=torch.float64)

In [350]:
%timeit j_fwd = jacobian_forward(x, n_params, d)

835 ms ± 27.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [351]:
%timeit j_bkwd = jacobian_reverse(x, n_params, d)

437 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Here, it took about twice as long for the forward mode as backward mode - makes sense, its 2x passes per parameter and there are similar sized parameters as inputs.

In [352]:
n_walkers = 10000
n_particles = 2
n_dim = 3
x = torch.rand((n_walkers, n_particles, n_dim), requires_grad=True, dtype=torch.float64)

In [355]:
%timeit j_fwd = jacobian_forward(x, n_params, d)

9.73 s ± 198 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [356]:
%timeit j_bkwd = jacobian_reverse(x, n_params, d)

45.8 s ± 1.12 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


At very high number of walkers, which is more efficient for memory usage, the forward mode jacobian is better!  In particular, there is only a little overhead introduced 