In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Pick the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
class EncMNIST(nn.Module):
    """MLP encoder that maps flattened 28x28 inputs to Gaussian parameters.

    The input's last dimension must have ``28 * 28`` features. Two hidden
    ReLU layers (512 then 128 units) feed two linear heads that each emit
    ``latent_dim_mnist`` values: a mean and a log-variance.

    Args:
        latent_dim_mnist: size of the latent code produced by each head.
    """

    def __init__(self, latent_dim_mnist):
        super().__init__()
        self.latent_dim_mnist = latent_dim_mnist
        self.dim_MNIST = 28 * 28

        # Build Linear/ReLU pairs in order so the Sequential indices
        # (enc.0, enc.2, ...) match the original layout.
        layers = []
        width = self.dim_MNIST
        for hidden in (512, 128):
            layers += [nn.Linear(width, hidden), nn.ReLU(inplace=True)]
            width = hidden
        self.enc = nn.Sequential(*layers)

        # Separate heads for the posterior mean and log-variance.
        self.enc_mu_mnist = nn.Linear(width, latent_dim_mnist)
        self.enc_var_mnist = nn.Linear(width, latent_dim_mnist)

    def forward(self, x):
        """Return ``(mu, log_var)`` computed from the shared hidden features."""
        features = self.enc(x)
        return self.enc_mu_mnist(features), self.enc_var_mnist(features)


    
class DecMNIST(nn.Module):
    """MLP decoder mapping a conditioned latent code to 28x28 pixel intensities.

    The first layer expects ``latent_dim + num_classes`` input features; the
    extra ``num_classes`` inputs are presumably a label encoding appended to
    z by the caller (TODO confirm). The output passes through a sigmoid, so
    every one of the ``28 * 28`` values lies in (0, 1).

    Args:
        latent_dim: size of the latent code z.
        num_classes: size of the conditioning vector; defaults to 10, the
            value that was previously hard-coded.
    """

    def __init__(self, latent_dim, num_classes=10):
        super(DecMNIST, self).__init__()
        self.latent_dim = latent_dim + num_classes
        self.dim_MNIST = 28 * 28

        self.dec = nn.Sequential(nn.Linear(self.latent_dim, 128),
                                 nn.ReLU(inplace=True),
                                 nn.Linear(128, 512),
                                 nn.ReLU(inplace=True),
                                 nn.Linear(512, self.dim_MNIST),
                                 nn.Sigmoid())

    def loss_function(self, x_rec, x, mu, log_var):
        """Return ``(total_loss, rec_loss, kld)``, each summed over the batch.

        Args:
            x_rec: sigmoid outputs of ``self.dec``.
            x: target intensities in [0, 1]; any shape with the same number
                of elements as ``x_rec`` (e.g. (B, 784) or (B, 1, 28, 28)).
            mu, log_var: posterior mean and log-variance for the KL term.
        """
        # The decoder emits per-pixel probabilities, so the matching
        # reconstruction term is a summed Bernoulli log-likelihood (BCE).
        # The previous 256-class cross-entropy could not run on these outputs.
        target = x.reshape(x_rec.shape).to(device=x_rec.device, dtype=x_rec.dtype)
        rec_loss = F.binary_cross_entropy(x_rec, target, reduction='sum')

        # Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed.
        kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return rec_loss + kld, rec_loss, kld

    def forward(self, z, x, generate_mode, mu, log_var):
        """Decode ``z`` and score it against ``x``.

        ``generate_mode`` is accepted for call-compatibility with
        ``pixelcnn_decoder`` and is not used here.

        Returns:
            ``(rec, total_loss, rec_loss, kld)``.
        """
        # The output already lives on the module's device; forcing it to a
        # global device could separate it from ``x``.
        rec = self.dec(z)
        total_loss, rec_loss, kld = self.loss_function(rec, x, mu, log_var)
        return (rec, total_loss, rec_loss, kld)
    
    
class pixelcnn_decoder(nn.Module):
    """Wrap a conditional PixelCNN as a VAE decoder with a 256-level pixel loss.

    Args:
        pixelcnn: module called as ``pixelcnn(img, z)`` when training and as
            ``pixelcnn.sample(img, shape, count, z)`` when generating.
    """

    def __init__(self, pixelcnn):
        super(pixelcnn_decoder, self).__init__()
        self.pixelcnn = pixelcnn

    def loss_function(self, x_rec, x, mu, log_var):
        """Return ``(total_loss, rec_loss, kld)``, each summed over the batch.

        Args:
            x_rec: per-pixel scores whose dim 1 indexes the 256 intensity
                levels; its last three dims are flattened into one pixel dim
                (assumed (B, 256, 1, 28, 28) — TODO confirm against pixelcnn).
            x: intensities in [0, 1], one per pixel, in any layout with the
                same pixel count; quantized to class indices 0..255.
            mu, log_var: posterior mean and log-variance for the KL term.
        """
        color_level = 256
        logits = x_rec.flatten(-3, -1)

        # cross_entropy needs the target on the logits' device and shaped
        # like the logits with the class dim (dim 1) removed.
        target_shape = (logits.shape[0],) + tuple(logits.shape[2:])
        target = (x.to(logits.device) * (color_level - 1)).floor().long()
        target = target.reshape(target_shape)
        rec_loss = F.cross_entropy(logits, target, reduction='sum')

        # Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed.
        kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return rec_loss + kld, rec_loss, kld

    def forward(self, z, img, generate_mode, mu, log_var):
        """Run the PixelCNN conditioned on ``z`` and score it against ``img``.

        Returns:
            ``(sample, total_loss, rec_loss, kld)``.
        """
        # One 1x28x28 image per latent row, on the same device as z.
        img_out = img.reshape(z.shape[0], 1, 28, 28).to(z.device)
        if generate_mode:
            sample = self.pixelcnn.sample(img_out, [1, 28, 28], z.shape[0], z)
            # NOTE(review): presumably sample[0] holds log-probabilities; exp
            # turns them into probabilities, which cross_entropy then treats
            # as logits — confirm this is intended.
            sample = torch.exp(sample[0])
        else:
            sample = self.pixelcnn(img_out, z)
        # Score against the device-aligned image, not the raw input.
        total_loss, rec_loss, kld = self.loss_function(sample, img_out, mu, log_var)
        return (sample, total_loss, rec_loss, kld)