In [141]:
import torch
from torch import nn
import pytorch_lightning as pl

In [166]:
class NLBranchNet(nn.Module):
    """Small fully-convolutional network applied to single-channel 2-D inputs.

    The stack is Conv2d -> ReLU -> BatchNorm2d (twice), then Conv2d -> ReLU ->
    Conv2d, going 1 -> 16 -> 32 -> 16 -> 1 channels with 3x3 kernels and
    stride 1.

    The first convolution uses ``padding=2`` with a 3x3 kernel, so height and
    width each grow by 2 (e.g. 12x12 -> 14x14); the remaining convolutions use
    ``padding=1`` and keep the spatial size.
    NOTE(review): confirm the size growth from ``padding=2`` is intended.
    """

    def __init__(self):
        super().__init__()
        # Kept as a ModuleList (not Sequential) so the module repr and the
        # ``layers.<idx>`` parameter names stay exactly as before.
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=16),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=32),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1),
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolutional stack on a batch of 2-D maps.

        Args:
            x: Tensor without a channel axis; a channel axis of size 1 is
               inserted at position 1 before the first convolution
               (assumed layout ``(batch, height, width)``).

        Returns:
            Tensor with the size-1 channel axis removed again, i.e.
            ``(batch, height + 2, width + 2)`` for a ``(batch, height, width)``
            input.
        """
        x = x.unsqueeze(1)  # add the channel axis expected by Conv2d
        for layer in self.layers:
            x = layer(x)
        # Remove only the channel axis. A bare ``squeeze()`` would also drop
        # the batch axis for a batch of one (and any size-1 spatial axis),
        # silently changing the rank of the result.
        return x.squeeze(1)

In [167]:
# Random sample batch: 100 maps of size 12x12, values drawn uniformly from [0, 1).
x = torch.rand((100, 12, 12))

In [168]:
# Instantiate the NLBranchNet class defined in the cell above (no constructor arguments).
model = NLBranchNet()

In [169]:
# Bare expression: the notebook displays the module repr, listing each layer in `layers`.
model

NLBranchNet(
  (layers): ModuleList(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [170]:
# Total number of scalar entries across all parameter tensors of `model`.
sum(param.numel() for param in model.parameters())

9665

In [160]:
class NLBranchNet(nn.Module):
    """Transposed-convolution network that upsamples single-channel 2-D inputs.

    The stack is ConvTranspose2d -> ReLU -> BatchNorm2d (twice), then
    ConvTranspose2d -> ReLU -> ConvTranspose2d, going 1 -> 16 -> 32 -> 16 -> 1
    channels. The first two transposed convolutions (kernel 4, stride 1) each
    add 3 to height and width; the last two (kernel 2, stride 2) each double
    them, so a 12x12 input becomes 72x72.

    NOTE(review): this class reuses the name ``NLBranchNet`` of a class defined
    in an earlier cell and shadows it once this cell runs; consider giving the
    two variants distinct names.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        # self.layers.append(ResizeLayer2D(params, input_dim=self.hparams['input_dim'], output_dim=12))
        # self.layers.append(UnsqueezeLayer(params))
        self.layers.append(nn.ConvTranspose2d(in_channels=1, out_channels=16, kernel_size=4, stride=1))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.BatchNorm2d(num_features=16))
        self.layers.append(nn.ConvTranspose2d(in_channels=16, out_channels=32, kernel_size=4, stride=1))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.BatchNorm2d(num_features=32))
        self.layers.append(nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=2, stride=2))
        # self.layers.append(SqueezeLayer(params))
        # self.layers.append(ResizeLayer2D(params, input_dim=72, output_dim=self.hparams['latent_dim']))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the transposed-convolution stack on a batch of 2-D maps.

        Args:
            x: Tensor without a channel axis; a channel axis of size 1 is
               inserted at position 1 before the first layer
               (assumed layout ``(batch, height, width)``).

        Returns:
            Tensor with the size-1 channel axis removed again, i.e.
            ``(batch, 4 * (height + 6), 4 * (width + 6))`` for a
            ``(batch, height, width)`` input.
        """
        x = x.unsqueeze(1)  # add the channel axis expected by ConvTranspose2d
        for layer in self.layers:
            x = layer(x)
        # Remove only the channel axis. A bare ``squeeze()`` would also drop
        # the batch axis for a batch of one, silently changing the rank of
        # the result.
        return x.squeeze(1)

In [148]:
# Instantiate whichever NLBranchNet definition is currently bound in the kernel.
# NOTE(review): this cell's execution count (148) is lower than that of the class
# cell above (160), so it ran before that cell was last executed; re-run top to
# bottom on a fresh kernel to confirm which definition is used.
model2 = NLBranchNet()

In [149]:
# Total number of scalar entries across all parameter tensors of `model2`.
sum(param.numel() for param in model2.parameters())

10721