In [1]:
from torchinfo import summary
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SiameseNetwork(nn.Module):
    """Siamese MLP encoder: one set of weights applied to both inputs.

    Each input is flattened to ``(batch, -1)`` and passed through two hidden
    ``Linear -> Dropout -> ReLU`` stages followed by a final linear
    projection, producing an embedding of size ``hidden_dim``.

    Args:
        input_dim: Number of features per sample after flattening; must match
            the flattened size of the tensors passed to ``forward``.
        hidden_dim: Width of both hidden layers and of the output embedding.
        dropout_p: Dropout probability used after each hidden linear layer.
            Defaults to 0.1.
    """

    def __init__(self, input_dim, hidden_dim, dropout_p=0.1):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.layer1 = nn.Linear(in_features=self.input_dim, out_features=self.hidden_dim)
        self.layer2 = nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim)
        self.out = nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim)
        # In-place dropout only ever touches the fresh tensor returned by a
        # Linear layer in forward_once, never the caller's input tensor.
        self.dropout = nn.Dropout(p=dropout_p, inplace=True)

    def forward_once(self, x):
        """Embed a single batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into one feature dimension.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # permuted image batches, are accepted instead of raising.
        x = x.reshape(x.shape[0], -1)
        x = F.relu(self.dropout(self.layer1(x)))
        x = F.relu(self.dropout(self.layer2(x)))
        return self.out(x)

    def forward(self, x1, x2):
        """Embed both inputs with the same weights.

        Returns:
            Tuple ``(out1, out2)`` of embeddings for ``x1`` and ``x2``.
        """
        out1 = self.forward_once(x1)
        out2 = self.forward_once(x2)

        return out1, out2

In [3]:
def initialize_base_network():
    """Create the Siamese network with 784 input features and 128 hidden units."""
    return SiameseNetwork(input_dim=784, hidden_dim=128)

In [4]:
# Instantiate the network and print its module tree (layer names and sizes).
model = initialize_base_network()
print(model)

SiameseNetwork(
  (layer1): Linear(in_features=784, out_features=128, bias=True)
  (layer2): Linear(in_features=128, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=128, bias=True)
  (dropout): Dropout(p=0.1, inplace=True)
)


In [5]:
# torchinfo summary of the model; called without an input size, so the
# recorded output lists per-layer parameter counts only.
summary(model)

Layer (type:depth-idx)                   Param #
SiameseNetwork                           --
├─Linear: 1-1                            100,480
├─Linear: 1-2                            16,512
├─Linear: 1-3                            16,512
├─Dropout: 1-4                           --
Total params: 133,504
Trainable params: 133,504
Non-trainable params: 0