In [7]:
# UnetVAE Encoder that Downsamples till the bottleneck

import torch.nn as nn
class UnetVAEEncoder(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, latent_dim):
        super(UnetVAEEncoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.latent_dim = latent_dim
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_channels, out_channels*2, kernel_size, stride, padding)
        self.conv3 = nn.Conv2d(out_channels*2, out_channels*4, kernel_size, stride, padding)
        self.pool = nn.MaxPool2d(2, 2)
        self.batchnorm1 = nn.BatchNorm2d(out_channels)
        self.batchnorm2 = nn.BatchNorm2d(out_channels*2)
        self.batchnorm3 = nn.BatchNorm2d(out_channels*4)
        self.fc1 = nn.Linear(out_channels*4*7*7, latent_dim)
        self.fc2 = nn.Linear(out_channels*4*7*7, latent_dim)
        
    def forward(self, x):
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = self.pool(x)
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.batchnorm3(self.conv3(x)))
        x = self.pool(x)
        x = x.view(-1, self.out_channels*4*7*7)

        mu = self.fc1(x)
        log_var = self.fc2(x)
        return mu, log_var

# A decoder that upsamples from the bottleneck to the original image size but we take input of encoder layer and concatenate it with input from Latent Space (skip connections)

# Dont forget to add skip connection from Encoder to Decoder, for this the decoder will have to take outputs of encoder as inputs

class UnetVAEDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, latent_dim):
        super(UnetVAEDecoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(latent_dim, out_channels*4*7*7)
        self.fc2 = nn.Linear(latent_dim, out_channels*4*7*7)
        self.conv1 = nn.Conv2d(out_channels*4, out_channels*2, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_channels*2, out_channels, kernel_size, stride, padding)
        self.conv3 = nn.Conv2d(out_channels, in_channels, kernel_size, stride, padding)
        self.batchnorm1 = nn.BatchNorm2d(out_channels*2)
        self.batchnorm2 = nn.BatchNorm2d(out_channels)
        self.batchnorm3 = nn.BatchNorm2d(in_channels)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
    def forward(self, x, mu, log_var):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = x.view(-1, self.out_channels*4, 7, 7)
        x = self.upsample(x)
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = self.upsample(x)
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = self.upsample(x)
        x = F.relu(self.batchnorm3(self.conv3(x)))
        return x

# The final UnetVAE model that combines the encoder and decoder

# Writing encoder and decoder inside class VAE UNET so that we can store outputs of encoder layers as self.encoder_outputs and use them in decoder

class UnetVAE(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, latent_dim):
        super(UnetVAE, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.latent_dim = latent_dim
        self.encoder = UnetVAEEncoder(in_channels, out_channels, kernel_size, stride, padding, latent_dim)
        self.decoder = UnetVAEDecoder(in_channels, out_channels, kernel_size, stride, padding, latent_dim)
        
    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z, mu, log_var)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + eps*std

# The loss function for the VAE




UnetVAE(
  (encoder): UnetVAEEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (batchnorm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (batchnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (batchnorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc1): Linear(in_features=12544, out_features=20, bias=True)
    (fc2): Linear(in_features=12544, out_features=20, bias=True)
  )
  (decoder): UnetVAEDecoder(
    (fc1): Linear(in_features=20, out_features=12544, bias=True)
    (fc2): Linear(in_features=20, out_features=12544, bias=True)
    (conv1): Conv2d(256, 128, kernel_size=(3, 3