## Counting the number of parameters

Deep learning models are famous for having a lot of parameters; recent language models have billions of them. More parameters mean more computation, more memory, and longer training times, so a deep learning practitioner should know how many parameters their model has.

In this exercise, you will calculate the number of parameters in your model, first using PyTorch, then manually.

The `torch.nn` package has been imported as `nn`.

In [1]:
import torch
import torch.nn as nn

- Iterate through the model's parameters to update the total variable with the total number of parameters in the model.

In [5]:
model = nn.Sequential(nn.Linear(16, 4),
                      nn.Linear(4, 2),
                      nn.Linear(2, 1))

total = 0

# Calculate the number of parameters in the model
for parameter in model.parameters():
    # numel() returns the number of elements in the parameter tensor
    total += parameter.numel()
print(total)

81
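The loop above can also be written as a single generator expression. The sketch below additionally counts only tensors with `requires_grad=True` (the ones an optimizer would update) and prints a per-tensor breakdown with `named_parameters()`. The variable name `trainable` is an illustrative choice, not part of the exercise. Since every parameter in this model is trainable, both counts should come out the same.

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 4),
                      nn.Linear(4, 2),
                      nn.Linear(2, 1))

# Same total as the explicit loop, written as a generator expression
total = sum(p.numel() for p in model.parameters())

# Count only parameters that an optimizer would update
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Per-tensor breakdown; nn.Sequential names its submodules "0", "1", "2"
# and nn.Linear stores its weight with shape (out_features, in_features)
for name, p in model.named_parameters():
    print(name, tuple(p.shape), p.numel())

print(total, trainable)
```

The breakdown makes it easy to see which layer dominates the count: here the first `nn.Linear` contributes 64 of the 81 parameters through its weight matrix alone.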


Manually:

In [6]:
"""
Input layer has 16 neurons, no parameters
First hidden layer has 4 neurons, 
each neuron will have 16 weights and 1 bias
in total 16 + 1 = 17 parameters
So, first hidden layer will have 4 * 17 = 68 parameters
Second hidden layer has 2 neurons,
each neuron will have 4 weights and 1 bias
in total 4 + 1 = 5 parameters
So, second hidden layer will have 2 * 5 = 10 parameters
Output layer has 1 neuron,
each neuron will have 2 weights and 1 bias
in total 2 + 1 = 3 parameters
So, output layer will have 1 * 3 = 3 parameters

Total parameters = 68 + 10 + 3 = 81
"""

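The manual reasoning boils down to one formula per fully connected layer with bias: `in_features * out_features` weights plus `out_features` biases. A minimal plain-Python sketch of that arithmetic, where `linear_params` and `layer_sizes` are hypothetical names chosen for illustration:

```python
# Hypothetical helper: parameter count of a fully connected layer with bias
def linear_params(in_features: int, out_features: int) -> int:
    # One weight per (input, output) pair plus one bias per output neuron
    return in_features * out_features + out_features

# Input width followed by each layer's output width, matching the model above
layer_sizes = [16, 4, 2, 1]

# Pair consecutive widths: (16, 4), (4, 2), (2, 1)
per_layer = [linear_params(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]

print(per_layer)       # [68, 10, 3]
print(sum(per_layer))  # 81
```

The result agrees with the PyTorch count, which is a quick sanity check that the manual breakdown accounts for every weight and bias.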