# Pytorch 2nd order Derivative computations

In [274]:
import torch
n_walkers = 10
n_particles = 2
n_dim = 3
torch.set_default_tensor_type(torch.DoubleTensor)

Define an input tensor:

In [275]:
x = torch.rand((n_walkers, n_particles, n_dim), requires_grad=True, dtype=torch.float64)

In [276]:
x.size()

torch.Size([10, 2, 3])

In [277]:
print(x)

tensor([[[0.0466, 0.8930, 0.4716],
         [0.1755, 0.4794, 0.7658]],

        [[0.4399, 0.7239, 0.3771],
         [0.6025, 0.7637, 0.4979]],

        [[0.4121, 0.7012, 0.0066],
         [0.4676, 0.7512, 0.0072]],

        [[0.7359, 0.6154, 0.6321],
         [0.5837, 0.6246, 0.3641]],

        [[0.4467, 0.0756, 0.1534],
         [0.5131, 0.7008, 0.3360]],

        [[0.4468, 0.6092, 0.0456],
         [0.8910, 0.8574, 0.2396]],

        [[0.5234, 0.1006, 0.2207],
         [0.8233, 0.5003, 0.9378]],

        [[0.2860, 0.5691, 0.8432],
         [0.1798, 0.7486, 0.2296]],

        [[0.0188, 0.9081, 0.5145],
         [0.0465, 0.7274, 0.5099]],

        [[0.9027, 0.1038, 0.5196],
         [0.4229, 0.5855, 0.0591]]], requires_grad=True)


Define a model to differentiate:

In [278]:
class DeepSets(torch.nn.Module):
    
    def __init__(self):
        torch.nn.Module.__init__(self)
        
        self.individual_net = torch.nn.ModuleList([
            torch.nn.Linear(n_dim, 16, bias=False),
            torch.nn.Linear(16, 32, bias=False)
        ]
        )
        
        self.aggregate_net = torch.nn.ModuleList([
            torch.nn.Linear(32, 16, bias=False),
            torch.nn.Linear(16, 1, bias=False)
        ]
        )
        
    def forward(self, inputs):
        
        #split the inputs along the particle dimension (1):
        particles = torch.chunk(inputs, n_particles, axis=1)
        
        particles = [torch.reshape(p, (-1,n_dim)) for p in particles]
        
        individuals = []
        for p in particles:
            this_i = p
            for l in self.individual_net:
                this_i = torch.tanh(l(this_i))
            individuals.append(this_i)
#         individuals = [self.individual_net(p) for p in particles]

        concatd = torch.stack(individuals, dim=-1)
        
        
        # Sum over the latent space:
        summed = torch.sum(concatd, dim=-1)
        
        output = summed
        for l in self.aggregate_net:
            output = l(output)
#         output = self.(summed)
        
        return output.reshape((-1))
        

In [279]:
d = DeepSets()

In [280]:
print(x.requires_grad)
o = d(x)

True


In [281]:
o.shape

torch.Size([10])

In [282]:
o.grad_fn

<ReshapeAliasBackward0 at 0x7fcbd1d5ec88>

# How to compute derivatives

In [283]:
# First, get the first derivative at every particle coordinate:

In [284]:
w_of_x, jvp = torch.autograd.functional.vjp(d.forward, x, torch.ones((n_walkers)))

In [285]:
jvp.shape

torch.Size([10, 2, 3])

In [286]:
o - w_of_x

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SubBackward0>)

Check the derivatives with numerical approximation (NOTE: it helps to use double precision for this):

In [287]:
kick = torch.zeros(x.shape)
dim = torch.randint(low=0,high=n_dim, size=()).item()
particle = torch.randint(low=0,high=n_particles, size=()).item()
kick_size = 1e-5
kick[:,particle, dim] = kick_size

In [288]:
numerical_dw_dx = torch.reshape((d(x + kick) - d(x - kick))/(2*kick_size), (-1,))
print(numerical_dw_dx.shape)
print(jvp[:,particle, dim].shape)

torch.Size([10])
torch.Size([10])


In [289]:
print(numerical_dw_dx)
print(jvp[:,particle, dim])
print(numerical_dw_dx - jvp[:,particle, dim] )

tensor([0.0550, 0.0598, 0.0646, 0.0570, 0.0648, 0.0648, 0.0643, 0.0523, 0.0540,
        0.0607], grad_fn=<ReshapeAliasBackward0>)
tensor([0.0550, 0.0598, 0.0646, 0.0570, 0.0648, 0.0648, 0.0643, 0.0523, 0.0540,
        0.0607])
tensor([ 1.5443e-12,  3.0219e-14,  3.2294e-12,  1.6786e-13,  2.2961e-12,
         1.0515e-12, -1.6431e-12, -7.0966e-13, -9.7358e-13, -1.8749e-12],
       grad_fn=<SubBackward0>)


In [290]:
print(x.requires_grad)
o = d(x)
dw_dx = torch.autograd.grad(o, x, grad_outputs=torch.ones((n_walkers)), retain_graph=True, create_graph=True)[0]

True


In [291]:
dw_dx

tensor([[[-0.0445,  0.0355,  0.0550],
         [-0.0443,  0.0454,  0.0541]],

        [[-0.0479,  0.0404,  0.0598],
         [-0.0482,  0.0376,  0.0582]],

        [[-0.0490,  0.0436,  0.0646],
         [-0.0495,  0.0417,  0.0646]],

        [[-0.0473,  0.0411,  0.0570],
         [-0.0478,  0.0430,  0.0611]],

        [[-0.0452,  0.0557,  0.0648],
         [-0.0482,  0.0411,  0.0609]],

        [[-0.0484,  0.0458,  0.0648],
         [-0.0489,  0.0343,  0.0622]],

        [[-0.0452,  0.0551,  0.0643],
         [-0.0462,  0.0427,  0.0523]],

        [[-0.0449,  0.0424,  0.0523],
         [-0.0470,  0.0418,  0.0607]],

        [[-0.0438,  0.0346,  0.0540],
         [-0.0440,  0.0404,  0.0559]],

        [[-0.0442,  0.0527,  0.0607],
         [-0.0481,  0.0465,  0.0647]]], grad_fn=<CatBackward0>)

In [292]:
d2w_dx2 = torch.autograd.grad(dw_dx, x, grad_outputs=torch.ones_like(x))[0]

In [293]:
# How about a second order derivative?  Recompute the first order and create a graph:
# l = lambda x : d.forward(x).sum()
# d2w_dx2, hvp = torch.autograd.functional.hvp(l, x, torch.ones_like(x))

In [294]:
print(d2w_dx2)

tensor([[[-0.0091, -0.0434, -0.0178],
         [-0.0039, -0.0328, -0.0233]],

        [[-0.0053, -0.0441, -0.0187],
         [-0.0038, -0.0449, -0.0197]],

        [[-0.0055, -0.0408, -0.0111],
         [-0.0051, -0.0422, -0.0118]],

        [[-0.0022, -0.0417, -0.0208],
         [-0.0040, -0.0418, -0.0182]],

        [[-0.0039, -0.0115, -0.0055],
         [-0.0047, -0.0436, -0.0181]],

        [[-0.0054, -0.0382, -0.0108],
         [ 0.0004, -0.0447, -0.0175]],

        [[-0.0041, -0.0147, -0.0081],
         [-0.0007, -0.0373, -0.0219]],

        [[-0.0049, -0.0359, -0.0222],
         [-0.0069, -0.0432, -0.0164]],

        [[-0.0094, -0.0428, -0.0176],
         [-0.0072, -0.0404, -0.0195]],

        [[-0.0027, -0.0202, -0.0148],
         [-0.0055, -0.0375, -0.0109]]])


In [298]:
d2w_dx2.shape

torch.Size([10, 2, 3])

In [299]:
# d2w_dx2, vhp = torch.autograd.functional.vhp(l, x, torch.ones_like(x))

In [300]:
# torch.max(hvp - hvp)

In [301]:
# torch.min(hvp - hvp)

In [302]:
kick = torch.zeros(x.shape)
dim = torch.randint(low=0,high=n_dim, size=()).item()
particle = torch.randint(low=0,high=n_particles, size=()).item()
kick_size = 1e-4
kick[:,particle, dim] = kick_size


In [303]:
numerical_d2w_dx2_order1 = (d(x + kick) - 2*d(x) + d(x-kick)) / (kick_size * kick_size)
numerical_d2w_dx2_order2 = (-d(x-2*kick) + 16*d(x-kick) - 30*d(x) + 16*d(x+kick) - d(x+2*kick)) / (12*kick_size * kick_size)
print(numerical_d2w_dx2_order1)
print(numerical_d2w_dx2_order2)

print(numerical_d2w_dx2_order1 / numerical_d2w_dx2_order2)



tensor([-0.0135, -0.0043, -0.0024,  0.0005,  0.0027, -0.0015,  0.0028, -0.0079,
        -0.0143,  0.0050], grad_fn=<DivBackward0>)
tensor([-0.0135, -0.0043, -0.0024,  0.0005,  0.0027, -0.0015,  0.0028, -0.0079,
        -0.0143,  0.0050], grad_fn=<DivBackward0>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000], grad_fn=<DivBackward0>)


In [304]:
numerical_d2w_dx2 = numerical_d2w_dx2_order2.reshape(-1)
print(numerical_d2w_dx2.shape)
print(d2w_dx2[:,particle, dim].shape)
print(d2w_dx2[:,particle,dim])

torch.Size([10])
torch.Size([10])
tensor([-0.0091, -0.0053, -0.0055, -0.0022, -0.0039, -0.0054, -0.0041, -0.0049,
        -0.0094, -0.0027])


In [305]:

print(numerical_d2w_dx2)
print(d2w_dx2[:,particle, dim])
print(numerical_d2w_dx2 - d2w_dx2[:,particle, dim] )
print(numerical_d2w_dx2 / d2w_dx2[:,particle, dim] )

tensor([-0.0135, -0.0043, -0.0024,  0.0005,  0.0027, -0.0015,  0.0028, -0.0079,
        -0.0143,  0.0050], grad_fn=<ReshapeAliasBackward0>)
tensor([-0.0091, -0.0053, -0.0055, -0.0022, -0.0039, -0.0054, -0.0041, -0.0049,
        -0.0094, -0.0027])
tensor([-0.0044,  0.0010,  0.0032,  0.0026,  0.0066,  0.0039,  0.0069, -0.0031,
        -0.0049,  0.0078], grad_fn=<SubBackward0>)
tensor([ 1.4898,  0.8075,  0.4285, -0.2112, -0.6737,  0.2779, -0.6931,  1.6363,
         1.5185, -1.8507], grad_fn=<DivBackward0>)


In [306]:
print(x.requires_grad)
o = d(x)
dw_dx = torch.autograd.grad(o,x,grad_outputs=torch.ones((n_walkers)), retain_graph=True, create_graph=True)[0]
print(dw_dx.shape)
print(dw_dx)

True
torch.Size([10, 2, 3])
tensor([[[-0.0445,  0.0355,  0.0550],
         [-0.0443,  0.0454,  0.0541]],

        [[-0.0479,  0.0404,  0.0598],
         [-0.0482,  0.0376,  0.0582]],

        [[-0.0490,  0.0436,  0.0646],
         [-0.0495,  0.0417,  0.0646]],

        [[-0.0473,  0.0411,  0.0570],
         [-0.0478,  0.0430,  0.0611]],

        [[-0.0452,  0.0557,  0.0648],
         [-0.0482,  0.0411,  0.0609]],

        [[-0.0484,  0.0458,  0.0648],
         [-0.0489,  0.0343,  0.0622]],

        [[-0.0452,  0.0551,  0.0643],
         [-0.0462,  0.0427,  0.0523]],

        [[-0.0449,  0.0424,  0.0523],
         [-0.0470,  0.0418,  0.0607]],

        [[-0.0438,  0.0346,  0.0540],
         [-0.0440,  0.0404,  0.0559]],

        [[-0.0442,  0.0527,  0.0607],
         [-0.0481,  0.0465,  0.0647]]], grad_fn=<CatBackward0>)


In [307]:
d2w_dx2_slow = torch.zeros_like(x)
print(x.shape)
for i_part in range(n_particles):
    for i_dim in range(n_dim):
        dw_dx_ij = dw_dx[:,i_part,i_dim]

        d2w_dx2_ij = torch.autograd.grad(dw_dx_ij, x, grad_outputs=torch.ones([n_walkers,]), retain_graph=True)[0]
        
        d2w_dx2_slow[:,i_part, i_dim] = d2w_dx2_ij[:,i_part,i_dim]

torch.Size([10, 2, 3])


In [308]:
d2w_dx2_slow.shape

torch.Size([10, 2, 3])

In [309]:
numerical_d2w_dx2_order2

tensor([-0.0135, -0.0043, -0.0024,  0.0005,  0.0027, -0.0015,  0.0028, -0.0079,
        -0.0143,  0.0050], grad_fn=<DivBackward0>)

In [310]:
d2w_dx2_slow[:,particle,dim]

tensor([-0.0135, -0.0043, -0.0024,  0.0005,  0.0027, -0.0015,  0.0028, -0.0079,
        -0.0143,  0.0050])

In [318]:
d2w_dx2 = torch.autograd.grad(dw_dx, x, grad_outputs=torch.ones((n_walkers, n_particles, n_dim)), retain_graph=True)

In [313]:
d2w_dx2_slow

tensor([[[-0.0135, -0.0326, -0.0155],
         [-0.0084, -0.0257, -0.0183]],

        [[-0.0043, -0.0327, -0.0155],
         [-0.0019, -0.0339, -0.0165]],

        [[-0.0024, -0.0292, -0.0095],
         [-0.0016, -0.0300, -0.0099]],

        [[ 0.0005, -0.0319, -0.0175],
         [-0.0011, -0.0311, -0.0150]],

        [[ 0.0027, -0.0074, -0.0048],
         [-0.0026, -0.0322, -0.0150]],

        [[-0.0015, -0.0277, -0.0095],
         [ 0.0048, -0.0332, -0.0136]],

        [[ 0.0028, -0.0096, -0.0068],
         [ 0.0013, -0.0286, -0.0184]],

        [[-0.0079, -0.0285, -0.0184],
         [-0.0085, -0.0314, -0.0136]],

        [[-0.0143, -0.0324, -0.0155],
         [-0.0121, -0.0310, -0.0163]],

        [[ 0.0050, -0.0129, -0.0119],
         [-0.0018, -0.0272, -0.0096]]])

In [319]:
d2w_dx2

(tensor([[[-0.0091, -0.0434, -0.0178],
          [-0.0039, -0.0328, -0.0233]],
 
         [[-0.0053, -0.0441, -0.0187],
          [-0.0038, -0.0449, -0.0197]],
 
         [[-0.0055, -0.0408, -0.0111],
          [-0.0051, -0.0422, -0.0118]],
 
         [[-0.0022, -0.0417, -0.0208],
          [-0.0040, -0.0418, -0.0182]],
 
         [[-0.0039, -0.0115, -0.0055],
          [-0.0047, -0.0436, -0.0181]],
 
         [[-0.0054, -0.0382, -0.0108],
          [ 0.0004, -0.0447, -0.0175]],
 
         [[-0.0041, -0.0147, -0.0081],
          [-0.0007, -0.0373, -0.0219]],
 
         [[-0.0049, -0.0359, -0.0222],
          [-0.0069, -0.0432, -0.0164]],
 
         [[-0.0094, -0.0428, -0.0176],
          [-0.0072, -0.0404, -0.0195]],
 
         [[-0.0027, -0.0202, -0.0148],
          [-0.0055, -0.0375, -0.0109]]]),)

In [330]:
dw_dx_ij_test = dw_dx[:,0,0]
print(dw_dx.shape)
print(dw_dx_ij_test.shape)
print(x.shape)
d2w_dx2_test = torch.autograd.grad(dw_dx_ij_test, x, grad_outputs=torch.ones([n_walkers,2,3]), retain_graph=True)[0]


torch.Size([10, 2, 3])
torch.Size([10])
torch.Size([10, 2, 3])


RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([10, 2, 3]) and output[0] has a shape of torch.Size([10]).

In [347]:
print(dw_dx.shape)
print(x.shape)




torch.Size([10, 2, 3])
torch.Size([10, 2, 3])


In [349]:
dw_dx.shape[-2]

2

In [423]:
inputs = x
o = d(inputs)
dw_dx = torch.autograd.grad(o,inputs,grad_outputs=torch.ones([n_walkers]), retain_graph=True, create_graph=True)[0]
summed = torch.sum(dw_dx)

In [425]:
f = summed.backward()

In [426]:
print(f)

None


In [427]:
print(d2w_dx2_ein / d2w_dx2_slow[0])
print(inputs)

tensor([[[0.6712, 1.3317, 1.1502],
         [0.4664, 1.2762, 1.2711]]])
tensor([[[0.0466, 0.8930, 0.4716],
         [0.1755, 0.4794, 0.7658]],

        [[0.4399, 0.7239, 0.3771],
         [0.6025, 0.7637, 0.4979]],

        [[0.4121, 0.7012, 0.0066],
         [0.4676, 0.7512, 0.0072]],

        [[0.7359, 0.6154, 0.6321],
         [0.5837, 0.6246, 0.3641]],

        [[0.4467, 0.0756, 0.1534],
         [0.5131, 0.7008, 0.3360]],

        [[0.4468, 0.6092, 0.0456],
         [0.8910, 0.8574, 0.2396]],

        [[0.5234, 0.1006, 0.2207],
         [0.8233, 0.5003, 0.9378]],

        [[0.2860, 0.5691, 0.8432],
         [0.1798, 0.7486, 0.2296]],

        [[0.0188, 0.9081, 0.5145],
         [0.0465, 0.7274, 0.5099]],

        [[0.9027, 0.1038, 0.5196],
         [0.4229, 0.5855, 0.0591]]], requires_grad=True)


In [362]:
f = lambda x : d(x)

In [370]:
_, g = torch.autograd.functional.jvp(f, x, v=torch.ones_like(x))

In [372]:
g.shape

torch.Size([10])

In [366]:
dw_dx

tensor([[[-0.0445,  0.0355,  0.0550],
         [-0.0443,  0.0454,  0.0541]],

        [[-0.0479,  0.0404,  0.0598],
         [-0.0482,  0.0376,  0.0582]],

        [[-0.0490,  0.0436,  0.0646],
         [-0.0495,  0.0417,  0.0646]],

        [[-0.0473,  0.0411,  0.0570],
         [-0.0478,  0.0430,  0.0611]],

        [[-0.0452,  0.0557,  0.0648],
         [-0.0482,  0.0411,  0.0609]],

        [[-0.0484,  0.0458,  0.0648],
         [-0.0489,  0.0343,  0.0622]],

        [[-0.0452,  0.0551,  0.0643],
         [-0.0462,  0.0427,  0.0523]],

        [[-0.0449,  0.0424,  0.0523],
         [-0.0470,  0.0418,  0.0607]],

        [[-0.0438,  0.0346,  0.0540],
         [-0.0440,  0.0404,  0.0559]],

        [[-0.0442,  0.0527,  0.0607],
         [-0.0481,  0.0465,  0.0647]]], grad_fn=<CatBackward0>)

In [408]:
func = lambda x : d(x)
jacf = lambda x : torch.autograd.functional.jacobian(func, x)

In [409]:
jacf(x)

tensor([[[[-0.0445,  0.0355,  0.0550],
          [-0.0443,  0.0454,  0.0541]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[-0.0479,  0.0404,  0.0598],
          [-0.0482,  0.0376,  0.0582]],

         [[ 0.0000,  0.0000,  0.0000],