In [1]:
import torch
import torchsummary

In [2]:
# Dummy input tensor of shape (3, 64, 64); later cells use its element count
# (x.numel()) as the input size of the network. Seed the global RNG first so the
# tensor values and the subsequent weight initialisations are reproducible
# under Restart & Run All.
torch.manual_seed(0)
x = torch.randn(3, 64, 64)

In [3]:
# Total number of scalar elements in x, i.e. the product of its dimensions.
x.shape.numel()

12288

In [4]:
# Layer widths used in the parameter-count arithmetic below.
num_hidden, num_output = 10, 2

In [5]:
# Hidden layer: one weight per (hidden unit, input element) pair plus one bias per hidden unit.
num_params_hidden = (x.numel() + 1) * num_hidden
num_params_hidden

122890

In [6]:
# Output layer: one weight per (output unit, hidden unit) pair plus one bias per output unit.
num_params_out = (num_hidden + 1) * num_output
num_params_out

22

In [7]:
# Network built by subclassing torch.nn.Module: layers are registered as
# attributes in __init__ and wired together explicitly in forward().
class NN2(torch.nn.Module):
    """Two-layer fully connected network with a sigmoid after each layer.

    Args:
        in_features: Size of the input's last dimension. Defaults to
            3 * 64 * 64, the element count of the example tensor ``x`` created
            above. The class therefore no longer reads the global ``x`` when it
            is constructed.
        num_hidden: Width of the hidden layer (default 10, matching
            ``num_hidden`` set earlier).
        num_output: Number of outputs (default 2, matching ``num_output``
            set earlier).

    ``forward`` applies the layers to the last dimension of its input, which
    must equal ``in_features``. It does not flatten the input, so an
    unflattened (3, 64, 64) tensor will not work. The result has last
    dimension ``num_output`` and has been passed through a sigmoid.
    """

    def __init__(self, in_features: int = 3 * 64 * 64, num_hidden: int = 10,
                 num_output: int = 2):
        super().__init__()
        self.hidden_layer = torch.nn.Linear(in_features, num_hidden)
        self.output_layer = torch.nn.Linear(num_hidden, num_output)

    def forward(self, x):
        # torch.sigmoid is the canonical spelling of torch.nn.functional.sigmoid.
        x = torch.sigmoid(self.hidden_layer(x))
        return torch.sigmoid(self.output_layer(x))

In [8]:
# Instantiate the nn.Module subclass defined above.
model = NN2()

In [9]:
# Show the shape of every learnable tensor: a weight and a bias for each Linear layer.
for tensor in model.parameters():
    print(tensor.shape)

torch.Size([10, 12288])
torch.Size([10])
torch.Size([2, 10])
torch.Size([2])


In [10]:
# Print a per-layer parameter summary. The trailing ';' stops the notebook
# from also echoing the value returned by summary().
torchsummary.summary(model);

Layer (type:depth-idx)                   Param #
├─Linear: 1-1                            122,890
├─Linear: 1-2                            22
Total params: 122,912
Trainable params: 122,912
Non-trainable params: 0


In [11]:
class Charles:
    """Minimal base class used to demonstrate ``super().__init__()``.

    Stores ``name`` on the instance, falling back to ``"charles"`` when no
    name (``None``) is given.
    """

    def __init__(self, name=None):
        self.name = name if name is not None else "charles"
        # Python's conditional expression plays the role of the ternary operator
        # in C-family languages, e.g. Java/JavaScript:
        #   this.name = name == null ? "charles" : name;

In [12]:
class Max(Charles):
    """Subclass of Charles that adds a ``shirt_color`` attribute."""

    def __init__(self):
        # Run the parent initializer first. No argument is passed, so
        # self.name gets Charles's default.
        super().__init__()
        self.shirt_color = "black"

In [13]:
# Creating a Max also runs Charles.__init__ (via super()), so m gets both attributes.
m = Max()

In [14]:
# Attribute set in Max.__init__.
m.shirt_color

'black'

In [15]:
# Attribute set by the inherited Charles.__init__ (its default name).
m.name

'charles'

In [17]:
# The same architecture expressed with torch.nn.Sequential. This is idiomatic
# and concise for a plain layer stack; subclassing nn.Module (as NN2 does) is
# only needed when forward() requires custom logic.
# Layer widths reuse num_hidden / num_output, so the parameter counts computed
# earlier stay in sync with this model.
model = torch.nn.Sequential(
    torch.nn.Linear(x.numel(), num_hidden),
    torch.nn.Sigmoid(),
    torch.nn.Linear(num_hidden, num_output),
    torch.nn.Sigmoid(),
)
torchsummary.summary(model);

Layer (type:depth-idx)                   Param #
├─Linear: 1-1                            122,890
├─Sigmoid: 1-2                           --
├─Linear: 1-3                            22
├─Sigmoid: 1-4                           --
Total params: 122,912
Trainable params: 122,912
Non-trainable params: 0
