In [1]:
import torch
import torch.nn as nn

# Fix the RNG seed so the randomly initialized weights printed further down are
# reproducible on Restart & Run All.
torch.manual_seed(0)

# "Wide" network: 2 inputs -> one hidden layer of 4 units -> 3 outputs.
# NOTE: there is no nonlinearity between the Linear layers, so each stack is
# mathematically equivalent to a single linear map. The cells below only count
# nodes and parameters, so that does not affect this comparison.
wide_nn = nn.Sequential(
    nn.Linear(2, 4),
    nn.Linear(4, 3),
)

# "Deep" network: 2 inputs -> two hidden layers of 2 units each -> 3 outputs.
deep_nn = nn.Sequential(
    nn.Linear(2, 2),
    nn.Linear(2, 2),
    nn.Linear(2, 3),
)

In [2]:
# Display the wide network's layer structure (the Sequential repr lists each Linear layer).
wide_nn

Sequential(
  (0): Linear(in_features=2, out_features=4, bias=True)
  (1): Linear(in_features=4, out_features=3, bias=True)
)

In [3]:
# Display the deep network's layer structure (the Sequential repr lists each Linear layer).
deep_nn

Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
  (2): Linear(in_features=2, out_features=3, bias=True)
)

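The number of bias terms equals the number of nodes: each `nn.Linear` layer has one bias per output node, so summing the bias lengths counts every hidden and output node (input features have no bias and are not counted).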

In [9]:
# Print each (name, Parameter) pair of the wide network, followed by a spacer line.
for param_name, param_tensor in wide_nn.named_parameters():
    print((param_name, param_tensor))
    print(' ')

('0.weight', Parameter containing:
tensor([[ 0.6251,  0.1796],
        [-0.3793, -0.0835],
        [-0.0799, -0.2139],
        [ 0.6023, -0.2981]], requires_grad=True))
 
('0.bias', Parameter containing:
tensor([ 0.2136, -0.1172,  0.2474, -0.3045], requires_grad=True))
 
('1.weight', Parameter containing:
tensor([[-0.1461,  0.4114,  0.2855, -0.0397],
        [ 0.2280,  0.4960, -0.0360, -0.2196],
        [-0.4473, -0.4715,  0.3943,  0.4785]], requires_grad=True))
 
('1.bias', Parameter containing:
tensor([ 0.0252, -0.1940,  0.1095], requires_grad=True))
 


In [20]:
# Print each (name, Parameter) pair of the deep network, followed by a spacer line.
for param_name, param_tensor in deep_nn.named_parameters():
    print((param_name, param_tensor))
    print(' ')

('0.weight', Parameter containing:
tensor([[-0.3379,  0.1757],
        [ 0.6381,  0.5820]], requires_grad=True))
 
('0.bias', Parameter containing:
tensor([-0.6913, -0.4261], requires_grad=True))
 
('1.weight', Parameter containing:
tensor([[ 0.6231, -0.3142],
        [-0.6938, -0.3392]], requires_grad=True))
 
('1.bias', Parameter containing:
tensor([-0.5491, -0.3610], requires_grad=True))
 
('2.weight', Parameter containing:
tensor([[-0.0524,  0.2807],
        [-0.4894,  0.1545],
        [ 0.4075, -0.0660]], requires_grad=True))
 
('2.bias', Parameter containing:
tensor([0.6724, 0.2661, 0.0362], requires_grad=True))
 


In [28]:
def count_nodes(model):
    """Return the number of nodes in `model`, counted as the total length of its bias vectors.

    Each `nn.Linear` layer holds one bias entry per output node, so summing the
    bias lengths counts every hidden and output node. Input features carry no
    bias and are therefore not counted; layers built with `bias=False` would
    contribute nothing either.
    """
    return sum(
        param.numel()
        for name, param in model.named_parameters()
        # Compare the parameter's own leaf name (e.g. '0.bias' -> 'bias') rather than
        # a substring test, which would also match module paths containing "bias".
        if name.rsplit('.', 1)[-1] == 'bias'
    )


wideNodesCount = count_nodes(wide_nn)
deepNodesCount = count_nodes(deep_nn)

print(f'Wide NN nodes: {wideNodesCount}')
print(f'Deep NN nodes: {deepNodesCount}')

Wide NN nodes: 7
Deep NN nodes: 7


Count the total number of trainable parameters (all weights and biases with `requires_grad=True`).

In [29]:
# Tally trainable parameters of the wide network, reporting each tensor's size.
nparams = 0
trainable_pieces = (t for t in wide_nn.parameters() if t.requires_grad)
for piece in trainable_pieces:
    piece_size = piece.numel()
    print('This piece has %s parameters' % piece_size)
    nparams += piece_size

print('\n\nTotal of %s parameters' % nparams)

This piece has 8 parameters
This piece has 4 parameters
This piece has 12 parameters
This piece has 3 parameters


Total of 27 parameters


In [30]:
# Tally trainable parameters of the deep network, reporting each tensor's size.
nparams = 0
trainable_pieces = (t for t in deep_nn.parameters() if t.requires_grad)
for piece in trainable_pieces:
    piece_size = piece.numel()
    print('This piece has %s parameters' % piece_size)
    nparams += piece_size

print('\n\nTotal of %s parameters' % nparams)

This piece has 4 parameters
This piece has 2 parameters
This piece has 4 parameters
This piece has 2 parameters
This piece has 6 parameters
This piece has 3 parameters


Total of 21 parameters


Summarize both networks with `torchsummary`, which reports each layer's output shape and parameter count.

In [32]:
from torchsummary import summary

# torchsummary's `summary` defaults to device="cuda" and builds CUDA input tensors
# whenever a GPU is available. wide_nn is never moved off the CPU in this notebook,
# so request CPU explicitly to avoid a device-mismatch error on GPU machines.
# input_size (1, 2) excludes the batch dimension, which summary adds itself.
summary(wide_nn, (1, 2), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1, 4]              12
            Linear-2                 [-1, 1, 3]              15
Total params: 27
Trainable params: 27
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------


In [38]:
# deep_nn lives on the CPU; pass device='cpu' so torchsummary does not default to
# CUDA input tensors when a GPU happens to be available.
summary(deep_nn, (1, 2), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1, 2]               6
            Linear-2                 [-1, 1, 2]               6
            Linear-3                 [-1, 1, 3]               9
Total params: 21
Trainable params: 21
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
