In [2]:
import numpy
import webdataset
import torch
import torch.nn as nn
import torch.nn.functional as F

In [115]:
class ResidualBlock(nn.Module):
    """Three-convolution residual block that doubles the channel count.

    Main path: a 2x2 conv (stride 1, keeps ``in_channels``), then a 1x1 conv
    applying ``stride`` and widening to ``2 * in_channels``, then a 1x1 conv.
    Each conv is followed by BatchNorm; the first two also by leaky ReLU.
    The shortcut is a 2x2 conv with the same ``stride``, so both paths end
    with identical shapes before they are summed and passed through a final
    leaky ReLU.

    Args:
        in_channels: channels of the input feature map.
        stride: stride used by the widening conv and by the shortcut conv.
    """

    def __init__(self, in_channels, stride=1):
        super().__init__()
        widened = in_channels * 2
        self.stride = stride
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=2, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, widened, kernel_size=1, stride=stride, padding=0)
        self.bn2 = nn.BatchNorm2d(widened)
        self.conv3 = nn.Conv2d(widened, widened, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(widened)
        # Shortcut projection: the 2x2 kernel with no padding and the same
        # stride reproduces the main path's spatial shrinkage, and the output
        # width matches the widened channel count.
        self.match_conv = nn.Conv2d(in_channels, widened, kernel_size=2, stride=stride, padding=0)

    def match_input(self, x):
        """Project the block input onto the main path's output shape."""
        return self.match_conv(x)

    def forward(self, x):
        residual = self.match_input(x)
        out = F.leaky_relu(self.bn1(self.conv1(x)))
        out = F.leaky_relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out)) + residual
        return F.leaky_relu(out)

In [170]:
class Encoder(nn.Module):
    """Residual convolutional encoder mapping an image to a 256-d vector.

    Three ``ResidualBlock`` stages take the channel count from
    ``in_channels`` to ``8 * in_channels`` (each block doubles it); the
    second and third stages use stride 2. The flattened feature map is then
    projected to 512 and finally 256 features.

    Args:
        in_channels: number of channels of the input image.
        input_size: spatial size of the input, an int for square inputs or an
            ``(H, W)`` pair. Only used to infer the flattened feature size
            for ``fc1``. Defaults to 33 (33x33 inputs).
    """

    def __init__(self, in_channels, input_size=33):
        super(Encoder, self).__init__()
        # Residual feature extraction; each block doubles the channel count.
        self.block_1 = ResidualBlock(in_channels, stride=1)
        self.block_2 = ResidualBlock(in_channels * 2, stride=2)
        self.block_3 = ResidualBlock(in_channels * 4, stride=2)

        # Size of the flattened conv output, i.e. the input width of fc1.
        self.flatten_shape = self._infer_flatten_size(in_channels, input_size)

        self.fc1 = nn.Linear(self.flatten_shape, 512)
        self.fc2 = nn.Linear(512, 256)

    def _infer_flatten_size(self, in_channels, input_size):
        """Return the flattened conv-feature size for one ``input_size`` sample."""
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        was_training = self.training
        # Probe in eval mode: in train mode BatchNorm would fold the dummy
        # all-zeros batch into its running statistics.
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, in_channels, *input_size)
                return self.convolutions(probe).shape[1]
        finally:
            self.train(was_training)

    def convolutions(self, x):
        """Apply the three residual stages and flatten each sample to a vector."""
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        return torch.flatten(x, start_dim=1)

    def forward(self, x):
        x = self.convolutions(x)
        # NOTE(review): there is no activation between fc1 and fc2, so the two
        # layers compose to a single affine map; confirm this is intended.
        x = self.fc1(x)
        x = self.fc2(x)
        return x


enc = Encoder(4)
# Sanity check: with 4 input channels and a 33x33 input, the three residual
# stages leave a 32-channel 8x8 feature map for fc1.
assert enc.flatten_shape == 32 * 8 * 8, enc.flatten_shape
enc.flatten_shape

In [181]:
class Decoder(nn.Module):
    """Map a latent vector back to an image.

    Two linear layers expand the vector into a ``32 x 8 x 8`` feature map.
    Three stride-2 transposed convolutions upsample it to
    ``out_channels x 64 x 64``. A final resize (``F.interpolate`` default
    mode, nearest) produces the requested ``output_size``.

    Args:
        input_nodes: length of the latent vector.
        out_channels: channels of the produced image. Defaults to 4.
        output_size: spatial size of the output, an int for square outputs
            or an ``(H, W)`` pair. Defaults to 33 (33x33).
    """

    def __init__(self, input_nodes, out_channels=4, output_size=33):
        super(Decoder, self).__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.output_size = tuple(output_size)
        self.fc1 = nn.Linear(input_nodes, 512)
        self.fc2 = nn.Linear(512, 32 * 8 * 8)
        # kernel_size=2, stride=2, padding=0 exactly doubles H and W.
        self.convt1 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2, padding=0)
        self.convt2 = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=2, stride=2, padding=0)
        self.convt3 = nn.ConvTranspose2d(in_channels=8, out_channels=out_channels, kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        # NOTE(review): no non-linearity is applied anywhere here, so the
        # whole decoder is an affine function of ``x``; confirm intended.
        x = self.fc1(x)
        x = self.fc2(x).view(-1, 32, 8, 8)
        x = self.convt1(x)  # 8x8   -> 16x16
        x = self.convt2(x)  # 16x16 -> 32x32
        x = self.convt3(x)  # 32x32 -> 64x64
        return F.interpolate(x, size=self.output_size)

In [182]:
class AutoEncoder(nn.Module):
    """Encoder/decoder pair for 4-channel images.

    ``Encoder(in_channels=4)`` compresses the input to a 256-d vector and
    ``Decoder(input_nodes=256)`` maps it back to an image.
    """

    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(in_channels=4)
        self.decoder = Decoder(input_nodes=256)

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction."""
        code = self.encoder(x)
        return self.decoder(code)


ae = AutoEncoder()
zero = torch.zeros(size=(1, 4, 33, 33))
# Smoke test: the reconstruction must have the same shape as the input.
with torch.no_grad():
    reconstruction = ae(zero)
assert reconstruction.shape == zero.shape, reconstruction.shape
reconstruction.shape

torch.Size([1, 4, 33, 33])
