In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
# Leaf input: 20 ones that autograd will track (requires_grad=True).
x = torch.ones((20,), requires_grad=True)

In [3]:
# Show the tensor via the cell's last expression (Jupyter display) instead of print().
x

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.], requires_grad=True)


In [4]:
# Turn the (20,) vector into a (20, 1) column: nn.Linear treats the last
# dimension as the feature dimension. Guarded so re-running this cell does not
# keep adding dimensions ((20, 1, 1), ...). The result is a non-leaf tensor
# (grad_fn=UnsqueezeBackward0); the original leaf is no longer reachable via `x`.
if x.dim() == 1:
    x = torch.unsqueeze(x, dim=1)

In [5]:
# `x` is now the output of unsqueeze (a non-leaf, grad_fn=UnsqueezeBackward0).
# Autograd only populates `.grad` on leaf tensors by default, so ask it to keep
# x.grad as well; otherwise it would be None after backward().
x.retain_grad()

In [6]:
# Show the reshaped column via the cell's last expression instead of print().
x

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], grad_fn=<UnsqueezeBackward0>)


In [7]:
class Net(nn.Module):
    """Fully connected network with a single hidden layer.

    Computes ``output_layer(activation(input_layer(x)))``.

    Args:
        n_independent: number of input features (last dimension of ``x``).
        n_dependent: number of output features.
        n_hidden: width of the hidden layer.
        activation: callable applied to the hidden layer's output,
            e.g. ``F.celu``.
    """

    def __init__(self, n_independent, n_dependent,
                 n_hidden, activation):
        super().__init__()

        self.activation = activation
        # Input layer is built first, then the output layer: this keeps the
        # parameter registration order and the order in which the random
        # initialisation consumes the RNG.
        self.input_layer = nn.Linear(n_independent, n_hidden)
        self.output_layer = nn.Linear(n_hidden, n_dependent)

    def forward(self, x):
        """Map ``x`` (..., n_independent) to (..., n_dependent).

        Output shape assumes ``activation`` preserves shape.
        """
        hidden = self.activation(self.input_layer(x))
        return self.output_layer(hidden)

In [8]:
# Seed torch's RNG so the randomly initialised weights, and therefore y and
# x.grad computed below, are identical on every Restart & Run All.
# (The stored outputs came from an unseeded run; re-run to refresh them.)
torch.manual_seed(0)
net = Net(n_independent=1, n_dependent=1,
          n_hidden=1, activation=F.celu)

In [9]:
# Forward pass through the module (__call__, so hooks run). With x of shape
# (20, 1), y is (20, 1): one prediction per input row.
y = net(x)

In [10]:
# Clear stale gradients before the backward pass. Autograd *accumulates* into
# .grad, so leftovers from an earlier run of the cells below would otherwise be
# added to the new result.
net.zero_grad()
if x.grad is not None:
    # Zero through the tensor API itself; writing via `.grad.data` hides the
    # in-place change from autograd and is discouraged in modern PyTorch.
    x.grad.zero_()

In [11]:
# Show the network output via the cell's last expression instead of print().
y

tensor([[0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046],
        [0.2046]], grad_fn=<AddmmBackward>)


In [12]:
# y is a (20, 1) tensor, not a scalar, so backward needs an explicit upstream
# gradient. Passing ones gives x.grad[i] = dy[i]/dx[i] for each row
# (equivalently, the gradient of y.sum()).
torch.autograd.backward(y, grad_tensors=torch.ones_like(y))

In [13]:
# Per-row derivative of the network output w.r.t. its input, shown via the
# cell's last expression instead of print().
x.grad

tensor([[0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043],
        [0.0043]])
