In [None]:
class VAE(nn.Module):
    """Convolutional VAE with U-Net-style skip connections.

    The encoder applies four ``kernel_size=4, stride=2, padding=1`` convs and
    flattens to a fixed ``96 * 3 * 5`` vector. Inputs must therefore reduce
    to a 3x5 bottleneck; the reference input is
    ``(batch, in_channels, 54, 81)``. Each decoder stage concatenates the
    matching encoder activation onto its input. The last stage concatenates
    the raw input before ``last_conv``.

    Args:
        in_channels: channels of the input image (also the default latent size).
        out__channels: channels of the reconstruction.
        latent_dim: size of the latent code; ``None`` (default) uses
            ``in_channels``, matching the original behaviour.
    """

    def __init__(self, in_channels, out__channels, latent_dim=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = in_channels

        # Encoder. For a 54x81 input the spatial sizes are
        # (27, 40) -> (13, 20) -> (6, 10) -> (3, 5).
        self.enc_conv1 = nn.Conv2d(in_channels, 21, kernel_size=4, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(21, 32, kernel_size=4, stride=2, padding=1)
        self.enc_conv3 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.enc_conv4 = nn.Conv2d(64, 96, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(96 * 3 * 5, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc2_mu = nn.Linear(128, latent_dim)
        self.fc2_logvar = nn.Linear(128, latent_dim)

        # Decoder. Every transposed conv receives its input concatenated with
        # the matching encoder activation, hence the doubled input channels;
        # each one doubles the spatial size.
        self.fc3 = nn.Linear(latent_dim, 128)
        self.fc4 = nn.Linear(128, 96 * 3 * 5)
        self.dec_conv1 = nn.ConvTranspose2d(96 * 2, 64, kernel_size=4, stride=2, padding=1)
        self.dec_conv2 = nn.ConvTranspose2d(64 * 2, 32, kernel_size=4, stride=2, padding=1)
        self.dec_conv3 = nn.ConvTranspose2d(32 * 2, 21, kernel_size=4, stride=2, padding=1)
        self.dec_conv4 = nn.ConvTranspose2d(21 * 2, 3, kernel_size=4, stride=2, padding=1)
        # last_conv sees dec_conv4's 3 channels concatenated with the raw input
        # (previously hard-coded to 20, which only worked for in_channels=17).
        self.last_conv = nn.ConvTranspose2d(
            3 + in_channels, out__channels, kernel_size=3, stride=1, padding=1
        )

    def encode(self, x):
        """Encode ``x``; return ``(mu, logvar, [h4, h3, h2, h1])``.

        The list holds the encoder activations, deepest first, for the
        decoder's skip connections.
        """
        h1 = torch.relu(self.enc_conv1(x))
        h2 = torch.relu(self.enc_conv2(h1))
        h3 = torch.relu(self.enc_conv3(h2))
        h4 = torch.relu(self.enc_conv4(h3))
        # flatten(1) keeps the batch dimension explicit: a wrong bottleneck
        # size now fails in fc1 instead of silently changing the batch size.
        h = torch.relu(self.fc1(h4.flatten(1)))
        h = torch.relu(self.fc2(h))
        return self.fc2_mu(h), self.fc2_logvar(h), [h4, h3, h2, h1]

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    @staticmethod
    def _pad_to(h, ref):
        """Zero-pad ``h``'s last two dims up to ``ref``'s.

        The extra height is split top/bottom, with any odd row at the bottom.
        The extra width goes on the right.
        """
        dh = ref.shape[2] - h.shape[2]
        dw = ref.shape[3] - h.shape[3]
        return F.pad(h, (0, dw, dh // 2, dh - dh // 2))

    def decode(self, z, in_x, layers):
        """Decode latent ``z`` into a reconstruction shaped like ``in_x``.

        Args:
            z: latent codes, ``(batch, latent_dim)``.
            in_x: the original input, concatenated before ``last_conv``.
            layers: encoder activations ``[h4, h3, h2, h1]`` from ``encode``.
        """
        h4, h3, h2, h1 = layers
        h = torch.relu(self.fc3(z))
        h = torch.relu(self.fc4(h))
        h = h.view(-1, 96, 3, 5)
        h = torch.relu(self.dec_conv1(torch.cat((h, h4), dim=1)))
        # Stride-2 downsampling floors odd sizes, so upsampled maps can be one
        # row/column short of their skip tensor; pad to match before concat.
        h = torch.relu(self.dec_conv2(torch.cat((self._pad_to(h, h3), h3), dim=1)))
        h = torch.relu(self.dec_conv3(torch.cat((self._pad_to(h, h2), h2), dim=1)))
        h = self.dec_conv4(torch.cat((self._pad_to(h, h1), h1), dim=1))
        h = torch.cat((self._pad_to(h, in_x), in_x), dim=1)
        return self.last_conv(h)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for input ``x``."""
        # encode() does not modify x in place, so x can be reused directly.
        mu, logvar, layers = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, x, layers), mu, logvar

def loss_function_mse(recon_x, x, mu, logvar, *, beta=1.0):
    """Return the VAE loss as ``(mse + beta * kld, mse)``.

    Args:
        recon_x: reconstruction produced by the model; same shape as ``x``.
        x: target tensor.
        mu: latent means.
        logvar: latent log-variances, same shape as ``mu``.
        beta: keyword-only weight on the KL term; the default ``1.0``
            reproduces the unweighted loss exactly.

    Returns:
        ``(total, mse)`` where ``total = mse + beta * kld``.

    Note:
        ``mse`` is averaged over every element (``mse_loss`` default
        ``reduction="mean"``). ``kld`` is summed over batch and latent
        dimensions. With ``beta=1`` the KL term therefore grows with batch
        size and latent size relative to the reconstruction term; use
        ``beta`` to rebalance.
    """
    mse = nn.functional.mse_loss(recon_x, x)
    # Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over all elements.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kld, mse