In [1]:
import numpy as np
import torch
import torch.nn as nn

In [2]:
# build two models with the same input/output sizes (2 in, 3 out)

# wide model: a single 4-unit layer feeding the output layer
wide_layers = [
    nn.Linear(2, 4),  # input layer
    nn.Linear(4, 3),  # output layer
]
widenet = nn.Sequential(*wide_layers)

# deep model: an extra 2-unit hidden layer between input and output
deep_layers = [
    nn.Linear(2, 2),  # input layer
    nn.Linear(2, 2),  # hidden layer
    nn.Linear(2, 3),  # output layer
]
deepnet = nn.Sequential(*deep_layers)

# show both architectures, separated by a line holding a single space
print(widenet, " ", deepnet, sep="\n")

Sequential(
  (0): Linear(in_features=2, out_features=4, bias=True)
  (1): Linear(in_features=4, out_features=3, bias=True)
)
 
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
  (2): Linear(in_features=2, out_features=3, bias=True)
)


In [3]:
# inspect every (name, tensor) pair that the deep model exposes
for named_entry in deepnet.named_parameters():
    # one print call: the tuple itself, then a line holding a single space
    print(named_entry, " ", sep="\n")

('0.weight', Parameter containing:
tensor([[ 0.5365, -0.5862],
        [ 0.6011, -0.4860]], requires_grad=True))
 
('0.bias', Parameter containing:
tensor([ 0.2848, -0.3389], requires_grad=True))
 
('1.weight', Parameter containing:
tensor([[-0.3340,  0.3096],
        [-0.2939,  0.1643]], requires_grad=True))
 
('1.bias', Parameter containing:
tensor([ 0.6291, -0.3056], requires_grad=True))
 
('2.weight', Parameter containing:
tensor([[ 0.2913, -0.4547],
        [-0.4417, -0.2405],
        [ 0.0106, -0.2082]], requires_grad=True))
 
('2.bias', Parameter containing:
tensor([-0.3780, -0.6103,  0.5099], requires_grad=True))
 


In [7]:
# count the no. of nodes ( = the no. of biases)

# named_parameters() yields (name, tensor) tuples, e.g. ("0.bias", <tensor>);
# only the entries whose name contains "bias" are kept, and the lengths of
# those bias vectors are added up to give the node count
numNodesInWide = sum(
    len(values) for name, values in widenet.named_parameters() if "bias" in name
)

numNodesInDeep = sum(
    len(values) for name, values in deepnet.named_parameters() if "bias" in name
)

print("There are %s nodes in the wide network." %numNodesInWide)
print("There are %s nodes in the deep network." %numNodesInDeep)

There are 7 nodes in the wide network.
There are 7 nodes in the deep network.


In [9]:
# same tensors as named_parameters(), but without the names
for tensor in deepnet.parameters():
    # one print call: the parameter, then a line holding a single space
    print(tensor, " ", sep="\n")

Parameter containing:
tensor([[ 0.5365, -0.5862],
        [ 0.6011, -0.4860]], requires_grad=True)
 
Parameter containing:
tensor([ 0.2848, -0.3389], requires_grad=True)
 
Parameter containing:
tensor([[-0.3340,  0.3096],
        [-0.2939,  0.1643]], requires_grad=True)
 
Parameter containing:
tensor([ 0.6291, -0.3056], requires_grad=True)
 
Parameter containing:
tensor([[ 0.2913, -0.4547],
        [-0.4417, -0.2405],
        [ 0.0106, -0.2082]], requires_grad=True)
 
Parameter containing:
tensor([-0.3780, -0.6103,  0.5099], requires_grad=True)
 


In [11]:
# now count the total number of trainable parameters

# element count of each parameter tensor that has requires_grad set
trainable_sizes = [w.numel() for w in widenet.parameters() if w.requires_grad]

# report each piece, then the grand total
for size in trainable_sizes:
    print("This piece has %s parameters" %size)

nparams = sum(trainable_sizes)
print("\n\nTotal of %s parameters" %nparams)

This piece has 8 parameters
This piece has 4 parameters
This piece has 12 parameters
This piece has 3 parameters


Total of 27 parameters


In [14]:
# using list comprehension

# trainable-parameter count of the wide model: sum of numel() over every
# parameter tensor with requires_grad set
nparams = np.sum([p.numel() for p in widenet.parameters() if p.requires_grad])
print("Widenet has %s parameters" %nparams)

# same count for the deep model; the label must name the deep model so the
# two printed lines can be told apart
nparams = np.sum([p.numel() for p in deepnet.parameters() if p.requires_grad])
print("Deepnet has %s parameters" %nparams)

Widenet has 27 parameters
Widenet has 21 parameters


In [17]:
# a nice simple way to print out model info
# torchsummary is not among the imports at the top of the notebook, so this
# cell brings in `summary` itself before using it
from torchsummary import summary
# (1,2) is passed as the second argument; its last entry matches the 2 input
# features of widenet's first layer. The printed table lists each Linear
# layer's output shape and parameter count, followed by the totals.
# NOTE(review): the leading -1 in the printed output shapes is presumably the
# batch dimension — confirm against the torchsummary docs
summary(widenet, (1,2))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1, 4]              12
            Linear-2                 [-1, 1, 3]              15
Total params: 27
Trainable params: 27
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
