## Parameter Management


In [2]:
import torch
from torch import nn

In [3]:
# MLP with one hidden layer of 8 units and a single output unit. The lazy
# layers take only their output width here; the first forward pass below
# fixes their input width.
net = nn.Sequential(
    nn.LazyLinear(8),
    nn.ReLU(),
    nn.LazyLinear(1),
)

# Random input of shape (2, 4), pushed through the network once.
X = torch.rand(2, 4)
net(X).shape



torch.Size([2, 1])

In [10]:
# Access the parameters of one layer by indexing into the Sequential
# container. Index 2 selects the layer at position 2; state_dict() displays
# that layer's 'weight' and 'bias' entries.
net[2].state_dict() # output layer [2]

OrderedDict([('weight',
              tensor([[-0.2412,  0.2914, -0.2541, -0.1560,  0.1855,  0.1391,  0.3530, -0.2231]])),
             ('bias', tensor([0.0117]))])

In [12]:
# Show the class of the layer's bias next to the bias object itself: the bias
# is a Parameter object, not a bare tensor.
type(net[2].bias), net[2].bias

(torch.nn.parameter.Parameter,
 Parameter containing:
 tensor([0.0117], requires_grad=True))

In [15]:
# Backprop hasn't been invoked yet, so the weight has no gradient.
# Use the identity test `is None`: `== None` would dispatch to the operand's
# `__eq__` instead of checking for the None singleton.
net[2].weight.grad is None


True

In [17]:
# Collect every (name, parameter) pair of the network into a list; names are
# prefixed with the index of the layer that owns the parameter.
list(net.named_parameters())

[('0.weight',
  Parameter containing:
  tensor([[-0.3788,  0.1551,  0.3746,  0.3354],
          [-0.4126, -0.0341, -0.0814, -0.0928],
          [-0.4628, -0.1518, -0.3645,  0.0715],
          [-0.0437, -0.2254, -0.0970, -0.2185],
          [-0.3938, -0.0218,  0.2459, -0.3831],
          [ 0.4729, -0.4725,  0.2638,  0.0982],
          [-0.2333, -0.2876, -0.0968,  0.1695],
          [-0.0863,  0.2315,  0.2779,  0.4808]], requires_grad=True)),
 ('0.bias',
  Parameter containing:
  tensor([-0.3403, -0.0216,  0.1643,  0.1851, -0.0497, -0.4958, -0.3582,  0.0541],
         requires_grad=True)),
 ('2.weight',
  Parameter containing:
  tensor([[-0.2412,  0.2914, -0.2541, -0.1560,  0.1855,  0.1391,  0.3530, -0.2231]],
         requires_grad=True)),
 ('2.bias',
  Parameter containing:
  tensor([0.0117], requires_grad=True))]

In [26]:
# Bind the shared layer to a name so the very same module instance can be
# placed at several positions in the network and its parameters referred to.
shared = nn.LazyLinear(8)
net = nn.Sequential(
    nn.LazyLinear(8), nn.ReLU(),
    shared, nn.ReLU(),
    shared, nn.ReLU(),
    nn.LazyLinear(1),
)
net(X)  # one forward pass so the lazy layers' weights exist

# Positions 2 and 4 both hold `shared`: compare the first row of their weights.
print(net[2].weight.data[0] == net[4].weight.data[0])
# Equal values alone do not prove the parameters are tied. Overwrite the row
# through position 2 and compare again: if position 4 still matches, both
# positions are backed by the same tensor.
net[2].weight.data[0] = 100
print(net[2].weight.data[0] == net[4].weight.data[0])


tensor([True, True, True, True, True, True, True, True])
tensor([True, True, True, True, True, True, True, True])
