In [49]:
import torch
from torch import nn
import pytorch_lightning as pl
import math

In [50]:
Q = 72
x = torch.rand(100,Q,Q)

In [71]:
class NLBranchNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rel_kernel_size = 1/3
        self.layers = nn.ModuleList()
        self.layers.append(nn.Conv2d(in_channels=1, out_channels=4, kernel_size=int(self.rel_kernel_size*Q), stride=1, padding=int((self.rel_kernel_size*Q - 1)/2)))
        self.layers.append(nn.ConvTranspose2d(in_channels=4, out_channels=4, kernel_size=int(self.rel_kernel_size*Q), stride=1, padding=int((self.rel_kernel_size*Q - 1)/2)))
        self.layers.append(nn.Conv2d(in_channels=4, out_channels=4, kernel_size=int(self.rel_kernel_size*Q), stride=1, padding=int((self.rel_kernel_size*Q - 1)/2)))
        self.layers.append(nn.ConvTranspose2d(in_channels=4, out_channels=1, kernel_size=int(self.rel_kernel_size*Q), stride=1, padding=int((self.rel_kernel_size*Q - 1)/2)))
    def forward(self, x):
        x = x.unsqueeze(1)
        print(x.shape)
        for layer in self.layers:
            x = layer(x)
            print(x.shape)
        y = x.squeeze()   
        return y

In [72]:
model = NLBranchNet()

In [73]:
sum(p.numel() for p in model.parameters())

23053

In [36]:
model

NLBranchNet(
  (layers): ModuleList(
    (0): Conv2d(1, 4, kernel_size=(24, 24), stride=(1, 1), padding=(11, 11))
    (1): ConvTranspose2d(4, 8, kernel_size=(24, 24), stride=(1, 1), padding=(11, 11))
    (2): Conv2d(8, 4, kernel_size=(24, 24), stride=(1, 1), padding=(11, 11))
    (3): ConvTranspose2d(4, 1, kernel_size=(24, 24), stride=(1, 1), padding=(11, 11))
  )
)

In [37]:
model(x)

torch.Size([100, 1, 72, 72])
torch.Size([100, 4, 71, 71])
torch.Size([100, 8, 72, 72])
torch.Size([100, 4, 71, 71])
torch.Size([100, 1, 72, 72])


tensor([[[-0.0265, -0.0308, -0.0109,  ..., -0.0203, -0.0027, -0.0211],
         [-0.0470, -0.0539, -0.0412,  ..., -0.0175, -0.0095, -0.0205],
         [-0.0496, -0.0240, -0.0412,  ..., -0.0697, -0.0387, -0.0234],
         ...,
         [-0.0342, -0.0323, -0.0422,  ..., -0.0482, -0.0494, -0.0056],
         [-0.0221, -0.0446, -0.0501,  ..., -0.0394, -0.0319, -0.0575],
         [-0.0356, -0.0608, -0.0342,  ..., -0.0375, -0.0298, -0.0138]],

        [[-0.0521, -0.0430, -0.0224,  ..., -0.0195, -0.0146, -0.0312],
         [-0.0359, -0.0437,  0.0063,  ..., -0.0195, -0.0258, -0.0006],
         [-0.0352, -0.0550, -0.0083,  ..., -0.0259, -0.0272, -0.0209],
         ...,
         [-0.0455, -0.0839, -0.0607,  ..., -0.0863, -0.0300, -0.0430],
         [-0.0321, -0.0433, -0.0720,  ..., -0.0502, -0.0224, -0.0313],
         [-0.0464, -0.0477, -0.0748,  ..., -0.0514, -0.0377, -0.0327]],

        [[-0.0348, -0.0323, -0.0195,  ..., -0.0369, -0.0037, -0.0090],
         [-0.0368, -0.0387, -0.0216,  ..., -0

In [7]:
sum(p.numel() for p in model.parameters())

2308

In [8]:
class NLBranchNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        # self.layers.append(ResizeLayer2D(params, input_dim=self.hparams['input_dim'], output_dim=12))
        # self.layers.append(UnsqueezeLayer(params))
        self.layers.append(nn.ConvTranspose2d(in_channels=1, out_channels=16, kernel_size=4, stride=1))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.BatchNorm2d(num_features=16))
        self.layers.append(nn.ConvTranspose2d(in_channels=16, out_channels=32, kernel_size=4, stride=1))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.BatchNorm2d(num_features=32))
        self.layers.append(nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=2, stride=2))
        # self.layers.append(SqueezeLayer(params))
        # self.layers.append(ResizeLayer2D(params, input_dim=72, output_dim=self.hparams['latent_dim']))

    def forward(self, x):
        x = x.unsqueeze(1)
        for layer in self.layers:
            x = layer(x)
        y = x.squeeze()   
        return y

In [9]:
model2 = NLBranchNet()

In [10]:
sum(p.numel() for p in model2.parameters())

10721