In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torchvision.datasets import svhn
import numpy as np
import matplotlib.pyplot as plt

In [77]:
class Encoder_Block(nn.Module):
    def __init__ (self,num_layers=2,in_channel=3,initial_filter=32):
        super(Encoder_Block,self).__init__()
        self.relu = nn.ReLU()
        self.layers = []
        out_channel = initial_filter*2
        for i in range(num_layers):
            self.layers.append(nn.Conv2d(in_channel,out_channel,kernel_size=3))
            self.layers.append(nn.ReLU())
            in_channel=out_channel
        self.block = nn.Sequential(*self.layers)

    def forward(self,x):
        x = self.block(x)
        print(x.shape)
        return x

In [82]:
class Encoder(nn.Module):
    def __init__(self,num_blocks=5,in_channel=3,final_channel=512):
        super(Encoder,self).__init__()
        self.blocks = []
        for i in reversed(range(num_blocks)):
            initial_filter = final_channel//(2**(i))
            self.blocks.append(Encoder_Block(in_channel=in_channel,initial_filter=initial_filter,num_layers=2))
            if num_blocks-i != num_blocks:
                self.blocks.append(nn.MaxPool2d(kernel_size=2))
            
            in_channel=initial_filter*2
            
        self.encoder = nn.Sequential(*self.blocks)
    def forward(self,x):
        x = self.encoder(x)


In [83]:
test_tensor = torch.randn(1, 3, 572, 572)
model = Encoder_Block(2)
output = model(test_tensor)
model

torch.Size([1, 64, 568, 568])


Encoder_Block(
  (relu): ReLU()
  (block): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
  )
)

In [84]:
model = Encoder()
out =model(test_tensor)
print(model)

torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 28, 28])
Encoder(
  (encoder): Sequential(
    (0): Encoder_Block(
      (relu): ReLU()
      (block): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU()
      )
    )
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Encoder_Block(
      (relu): ReLU()
      (block): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU()
      )
    )
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Encoder_Block(
      (relu): ReLU()
      (block): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    