In [None]:
import torch.nn as nn
import torch

In [None]:
#Since this is a convolutional variational autoencoder, we need a Flatten class to flatten the final feature maps into a 1D vector per sample,
#and an Unflatten class to reshape a flat vector (the decoder's linear output computed from the latent vector z) back into feature maps.

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    The first dimension is kept as-is; the result is a view with shape
    ``(input.size(0), -1)``.
    """

    def forward(self, input):
        leading = input.shape[0]
        return input.view(leading, -1)
    
class Unflatten(nn.Module):
    """Reshape a flat per-sample vector into ``(channel, height, width)`` maps.

    Args:
        channel: size of the second dimension of the output.
        height: size of the third dimension of the output.
        width: size of the fourth dimension of the output.
    """

    def __init__(self, channel, height, width):
        super().__init__()
        self.channel, self.height, self.width = channel, height, width

    def forward(self, input):
        # The first dimension is carried over unchanged; the rest is viewed
        # with the shape fixed at construction time.
        target_shape = (input.shape[0], self.channel, self.height, self.width)
        return input.view(target_shape)

In [None]:
#Variational autoencoders tend to train best when the layer weights are initialized in the range [-0.08, 0.08].
#So here we sample the weights of the convolutional, transposed-convolutional and batch-norm layers from a uniform distribution with lower bound = -0.08 and upper bound = 0.08 (batch-norm biases are set to zero).

def weights_init(m):
    """Initialize the parameters of a single module in place.

    Meant to be passed to ``nn.Module.apply``, which calls it once for every
    submodule. Modules are matched by class name:

    * a name containing ``Conv`` (this also covers ``ConvTranspose*``) gets
      its weight drawn from U(-0.08, 0.08);
    * a name containing ``BatchNorm`` gets its weight drawn from
      U(-0.08, 0.08) and its bias set to zero.

    Any other module is left untouched.

    Args:
        m: the module to initialize.
    """
    classname = m.__class__.__name__
    is_conv = 'Conv' in classname
    is_batchnorm = (not is_conv) and 'BatchNorm' in classname
    if not (is_conv or is_batchnorm):
        return

    # A matching module may have no usable parameters: norm layers created with
    # affine=False store weight=None / bias=None, and a container that merely
    # has "Conv" in its name has no `weight` attribute at all. Skip those
    # instead of crashing in the middle of `.apply`.
    weight = getattr(m, 'weight', None)
    if weight is not None:
        torch.nn.init.uniform_(weight, -0.08, 0.08)

    if is_batchnorm:
        bias = getattr(m, 'bias', None)
        if bias is not None:
            torch.nn.init.zeros_(bias)

In [None]:
class convVAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder stacks four Conv2d/MaxPool2d/BatchNorm2d stages (each pooling
    halves the spatial size) followed by two linear layers that expect a
    flattened feature map of 4096 values (256 channels x 4 x 4, i.e. a
    3 x 64 x 64 input). Two linear heads turn the 32-dimensional encoder
    output into the mean and log-variance of the latent distribution. The
    decoder maps a latent vector back to a 3 x 64 x 64 tensor with values in
    (0, 1) via four stride-2 transposed convolutions and a final sigmoid.

    Args:
        latent_vector_size: dimensionality of the latent vector ``z``.
    """

    def __init__(self , latent_vector_size):
        super(convVAE , self).__init__()

        self.latent_vector_size = latent_vector_size

        # Note: the ReLU that follows each MaxPool2d is a no-op (max-pooling
        # non-negative activations yields non-negative values). It is kept so
        # the Sequential indices, and hence the state_dict keys, do not shift.
        self.encoder = nn.Sequential(

            nn.Conv2d(3 , 32 , 3 , 1 , 1) ,
            nn.ReLU() ,
            nn.MaxPool2d(2 , 2) ,
            nn.ReLU() ,
            nn.BatchNorm2d(32) ,

            nn.Conv2d(32 , 64 , 3 , 1 , 1) ,
            nn.ReLU() ,
            nn.MaxPool2d(2 , 2) ,
            nn.ReLU() ,
            nn.BatchNorm2d(64) ,

            nn.Conv2d(64 , 128 , 3 , 1 , 1) ,
            nn.ReLU() ,
            nn.MaxPool2d(2 , 2) ,
            nn.ReLU() ,
            nn.BatchNorm2d(128) ,

            nn.Conv2d(128 , 256 , 3 , 1 , 1) ,
            nn.ReLU() ,
            nn.MaxPool2d(2 , 2) ,
            nn.ReLU() ,
            nn.BatchNorm2d(256) ,

            Flatten() ,
            # NOTE(review): there is no activation between these two Linear
            # layers, so they compose into a single affine map - confirm this
            # is intended.
            nn.Linear(4096 , 1024) ,
            nn.Linear(1024 , 32) ,
            nn.ReLU()
            )

        # Heads producing the parameters of the approximate posterior.
        self.mu = nn.Linear(32 , self.latent_vector_size)
        self.logvar = nn.Linear(32 , self.latent_vector_size)

        self.decoder = nn.Sequential(
            # Fixed: this used to read `self.latent_size`, an attribute that is
            # never defined, so constructing the model raised AttributeError.
            nn.Linear(self.latent_vector_size , 1024) ,
            nn.ReLU() ,
            nn.Linear(1024 , 4096) ,
            nn.ReLU() ,
            Unflatten(256 , 4 , 4) ,

            # Each kernel-2 / stride-2 transposed convolution doubles the
            # spatial size: 4 -> 8 -> 16 -> 32 -> 64.
            nn.ConvTranspose2d(256 , 128 , 2 , 2) ,
            nn.ReLU() ,
            nn.BatchNorm2d(128) ,

            nn.ConvTranspose2d(128 , 64 , 2 , 2) ,
            nn.ReLU() ,
            nn.BatchNorm2d(64) ,

            nn.ConvTranspose2d(64 , 32 , 2 , 2) ,
            nn.ReLU() ,
            nn.BatchNorm2d(32) ,

            nn.ConvTranspose2d(32 , 3 , 2 , 2) ,
            nn.Sigmoid()
            )

    def encode(self , x):
        """Run the encoder and return ``(mu, logvar)`` for the input batch."""
        h = self.encoder(x)
        mu , logvar = self.mu(h) , self.logvar(h)
        return mu , logvar

    def reparameterize(self , mu , logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I).

        ``std`` is recovered from the log-variance as ``exp(0.5 * logvar)``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self , z):
        """Map a batch of latent vectors through the decoder."""
        decoded_z = self.decoder(z)
        return decoded_z

    def forward(self , x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``.
        """
        mu , logvar = self.encode(x)
        z = self.reparameterize(mu , logvar)
        return self.decode(z) , mu , logvar
            
# Build the VAE with a 32-dimensional latent vector, move it to the GPU when
# one is available (falling back to the CPU instead of crashing otherwise),
# and initialize its layers with weights_init.
device = "cuda" if torch.cuda.is_available() else "cpu"
vae = convVAE(32)
vae.to(device)
vae.apply(weights_init)