In [66]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [67]:
class Conv_block(nn.Module):
    def __init__(self,dim,dConv_kernel_size=7):
        super().__init__()
        self.depth_conv = nn.Conv2d(dim,dim,kernel_size=dConv_kernel_size,padding=int((dConv_kernel_size-1)/2),groups=dim)
        self.norm = nn.BatchNorm2d(dim)
        self.conv_1 = nn.Conv2d(dim,dim*4,kernel_size=1)
        self.act = nn.GELU()
        self.conv_2 = nn.Conv2d(dim*4,dim,kernel_size=1)

    def forward(self,x):
        input = x
        x = self.depth_conv(x)
        x = self.norm(x)
        x = self.conv_1(x)
        x = self.act(x)
        x = self.conv_2(x)
        return(x+input)

In [68]:
class Encoder(nn.Module):
    def __init__(self,in_chans= 3, depths=[3, 3, 9, 3,3,1],dims=[96, 192, 384, 768,768,1536],dConv_kernel_size=7):
        super().__init__()
        self.layers = nn.ModuleList()
      
        for layer_n,depth in enumerate(depths):
            for sublayer_n in range(depth):
                self.layers.append(Conv_block(dims[layer_n],dConv_kernel_size))
            if layer_n < len(depths)-1:
                self.layers.append(nn.Conv2d(dims[layer_n],dims[layer_n+1],kernel_size= 2, stride = 2))

    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
            print(x.shape)
        return(x)
    

class Decoder(nn.Module):
    def __init__(self,in_chans=1 ,out_chans=3 ,depths=[3, 3, 9, 3,3,1],dims=[96, 192, 384, 768,768,1536],dConv_kernel_size=7):
        super().__init__()
        self.depths = list(reversed(depths))
        self.dims = list(reversed(dims))
        self.layers = nn.ModuleList()
        for layer_n,depth in enumerate(self.depths):

            for _ in range(depth):
                self.layers.append(Conv_block(self.dims[layer_n],dConv_kernel_size))
            if layer_n < len(depths)-1:     
                self.layers.append(nn.ConvTranspose2d(self.dims[layer_n],self.dims[layer_n+1],kernel_size=2,stride=2))

    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return(x)

In [69]:
depths=[3, 3, 3, 9, 3, 3, 3]
dims=[3, 6, 12, 24, 48,96,192]
dConv_kernel_size = 3
enc = Encoder(depths=depths,dims = dims, dConv_kernel_size = dConv_kernel_size)
dec = Decoder(depths=depths,dims = dims, dConv_kernel_size = dConv_kernel_size)

In [70]:
seed = torch.rand([7,3,64,64])

x = enc(seed)

torch.Size([7, 3, 64, 64])
torch.Size([7, 3, 64, 64])
torch.Size([7, 3, 64, 64])
torch.Size([7, 6, 32, 32])
torch.Size([7, 6, 32, 32])
torch.Size([7, 6, 32, 32])
torch.Size([7, 6, 32, 32])
torch.Size([7, 12, 16, 16])
torch.Size([7, 12, 16, 16])
torch.Size([7, 12, 16, 16])
torch.Size([7, 12, 16, 16])
torch.Size([7, 24, 8, 8])
torch.Size([7, 24, 8, 8])
torch.Size([7, 24, 8, 8])
torch.Size([7, 24, 8, 8])
torch.Size([7, 24, 8, 8])
torch.Size([7, 24, 8, 8])
torch.Size([7, 24, 8, 8])
torch.Size([7, 24, 8, 8])
torch.Size([7, 24, 8, 8])
torch.Size([7, 24, 8, 8])
torch.Size([7, 48, 4, 4])
torch.Size([7, 48, 4, 4])
torch.Size([7, 48, 4, 4])
torch.Size([7, 48, 4, 4])
torch.Size([7, 96, 2, 2])
torch.Size([7, 96, 2, 2])
torch.Size([7, 96, 2, 2])
torch.Size([7, 96, 2, 2])
torch.Size([7, 192, 1, 1])
torch.Size([7, 192, 1, 1])
torch.Size([7, 192, 1, 1])
torch.Size([7, 192, 1, 1])
