In [99]:
import torch

In [100]:
class NN(torch.nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.fc1 = torch.nn.Linear(3,3)
        self.fc2 = torch.nn.Linear(3,2)
    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        return self.fc2(x)

In [101]:
x = torch.Tensor([[.1,.1,0.],[.1,.1,0],[.1,.1,1],[.1,.1,1]]).float()
y = torch.Tensor([0,0,1,1]).long()
x.requires_grad = True

In [102]:
torch.manual_seed(4)
nn = NN()
optimizer = torch.optim.SGD(nn.parameters(), lr=.9)

for i in range(50):
    output = nn(x)
    loss = torch.nn.functional.cross_entropy(output, y)
    loss.backward()
    optimizer.step()

In [103]:
n_out = 2
batch_size = 3
x_ = torch.Tensor([[.3,-.1,0],
                   [.5,-2,1],
                   [-.8,2,1]])
# x_ = x_.repeat(1, n_out).reshape(x_.shape[0]*n_out, x_.shape[1])
x_ = x_.repeat(n_out, 1)
x_.requires_grad = True
x_

tensor([[ 0.3000, -0.1000,  0.0000],
        [ 0.5000, -2.0000,  1.0000],
        [-0.8000,  2.0000,  1.0000],
        [ 0.3000, -0.1000,  0.0000],
        [ 0.5000, -2.0000,  1.0000],
        [-0.8000,  2.0000,  1.0000]], requires_grad=True)

In [104]:
# gradients = torch.eye(n_out).repeat(batch_size, 1)
idx = torch.arange(n_out).reshape(n_out,1).repeat(1,batch_size).reshape(batch_size*n_out,)
gradients = torch.zeros(len(idx), n_out).scatter_(1, idx.unsqueeze(1), 1.)
gradients

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.]])

In [105]:
optimizer.zero_grad()
y_ = nn(x_)
y_.backward(gradients, create_graph=True)
x_.grad

tensor([[ 4.2701e+01,  4.2685e+01, -5.6636e+02],
        [ 1.0240e-02,  1.7401e-02,  1.6279e-02],
        [ 6.2398e-02,  5.2298e-02,  1.5011e-02],
        [-4.1323e+01, -4.1308e+01,  5.4808e+02],
        [-1.0188e-01, -1.7313e-01, -1.6197e-01],
        [-1.8815e-01, -1.5770e-01, -4.5265e-02]], grad_fn=<CloneBackward>)

In [117]:
optimizer.zero_grad()
y_ = nn(x_)
torch.autograd.grad(y_, x_, grad_outputs=gradients, create_graph=True)[0]

tensor([[ 4.2701e+01,  4.2685e+01, -5.6636e+02],
        [ 1.0240e-02,  1.7401e-02,  1.6279e-02],
        [ 6.2398e-02,  5.2298e-02,  1.5011e-02],
        [-4.1323e+01, -4.1308e+01,  5.4808e+02],
        [-1.0188e-01, -1.7313e-01, -1.6197e-01],
        [-1.8815e-01, -1.5770e-01, -4.5265e-02]], grad_fn=<MmBackward>)

In [107]:
jacobians = x_.grad.reshape(batch_size, -1)

In [108]:
jacobian_norms = torch.pow(jacobians.norm(p='fro', dim=1), 2)

In [109]:
jacobian_norms

tensor([3.2441e+05, 3.0381e+05, 1.2891e-01], grad_fn=<PowBackward0>)

In [110]:
x1 = torch.Tensor([[.1,.05,0],[.1,.05,0]])
x1.requires_grad = True
y1 = nn(x1)
y1

tensor([[ 410.3284, -397.2157],
        [ 410.3284, -397.2157]], grad_fn=<AddmmBackward>)

In [111]:
y1.backward(torch.Tensor([[0,1],[1,0]]))
x1.grad

tensor([[ -41.3233,  -41.3075,  548.0845],
        [  42.7013,   42.6850, -566.3608]])

In [112]:
a = torch.Tensor([[[[1,2,3],
                   [4,5,6],
                   [7,8,9]]],
                  [[[10,11,12],
                   [13,14,15],
                   [16,17,18]]]])
print(a.shape)
a

torch.Size([2, 1, 3, 3])


tensor([[[[ 1.,  2.,  3.],
          [ 4.,  5.,  6.],
          [ 7.,  8.,  9.]]],


        [[[10., 11., 12.],
          [13., 14., 15.],
          [16., 17., 18.]]]])

In [113]:
print(a.repeat(3,1,1,1).shape)
a.repeat(3,1,1,1)

torch.Size([6, 1, 3, 3])


tensor([[[[ 1.,  2.,  3.],
          [ 4.,  5.,  6.],
          [ 7.,  8.,  9.]]],


        [[[10., 11., 12.],
          [13., 14., 15.],
          [16., 17., 18.]]],


        [[[ 1.,  2.,  3.],
          [ 4.,  5.,  6.],
          [ 7.,  8.,  9.]]],


        [[[10., 11., 12.],
          [13., 14., 15.],
          [16., 17., 18.]]],


        [[[ 1.,  2.,  3.],
          [ 4.,  5.,  6.],
          [ 7.,  8.,  9.]]],


        [[[10., 11., 12.],
          [13., 14., 15.],
          [16., 17., 18.]]]])

In [114]:
h = torch.arange(5).reshape(5,1).repeat(1,3).reshape(5*3,1)
y_onehot = torch.FloatTensor(5*3, 5)
y_onehot.zero_()
y_onehot.scatter_(1, h, 1)
y_onehot

tensor([[1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.]])

In [115]:
batch_size = 5
nb_digits = 10
# Dummy input that HAS to be 2D for the scatter (you can use view(-1,1) if needed)
y = torch.LongTensor(batch_size,1).random_() % nb_digits
# One hot encoding buffer that you create out of the loop and just keep reusing
y_onehot = torch.FloatTensor(batch_size, nb_digits)

# In your for loop
y_onehot.zero_()
y_onehot.scatter_(1, y, 1)

print(y.shape)
print(y_onehot)

torch.Size([5, 1])
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]])


In [116]:
x_.requires_grad = True
output = nn(x_)
print(f'{output.data[0]} ==> it\'s a {output.data[0].max(0)[1].item()}!')
# loss = torch.nn.functional.cross_entropy(output, )
output.backward(torch.eye(2))
out0_grad = torch.autograd.grad(output[0,:].sum(), x_, retain_graph=True)

tensor([ 412.4659, -399.2842]) ==> it's a 0!


RuntimeError: invalid gradient at index 0 - got [2, 2] but expected shape compatible with [6, 2]