In [1]:
import pandas as pd
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms

In [2]:
mnist_trainset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ]))
mnist_testset = torchvision.datasets.MNIST(root="./data",train=False,download=True,transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ]))

In [34]:
# number of image, always 0, Image
mnist_trainset[50000][0][0].shape

torch.Size([28, 28])

### Encoder

In [67]:
class Encoder(nn.Module):
    def __init__(self,encoder_net):
        super(Encoder,self).__init__()
        self.encoder = encoder_net

    def encode(self,x):
        h_e = self.encoder(x)
        mu_e, log_var_e = torch.chunk(h_e,chunks=2)
        return mu_e, log_var_e

    def reparameterization(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu+std*eps
    
    def sample(self, mu_e, log_var_e):
        z = self.reparameterization(mu_e,log_var_e)
        return z

### Decoder

In [134]:
class Decoder(nn.Module):
    def __init__(self,decoder_net):
        super(Decoder,self).__init__()
        self.num_vals = 10
        self.decoder = decoder_net
    
    def decode(self,z):
        h_d = self.decoder(z)
        h_d = h_d.view(28,28,10)
        #h_d = h_d.view(h_d.shape[0],h_d.shape[1]//self.num_vals,self.num_vals)
        print(h_d)
        mu_d = torch.softmax(h_d,dim=2)
        return mu_d

    def sample(self,z):
        mu_d = self.decode(z)
        


### Prior

In [110]:
class Prior(nn.Module):
    def __init__(self,L):
        super(Prior, self).__init__()
        self.L = L
    
    def sample(self, batchsize=1):
        z = torch.randn((batchsize,self.L))
        return z

### VAE

In [111]:
class VAE(nn.Module):
    def __init__(self,encoder_net,decoder_net):
        super(VAE,self)
        self.encoder = Encoder(encoder_net)
        self.decoder = Decoder(decoder_net)
        self.Prior = Prior()
    
    def forward(self, x):
        mu_e,log_var_e = self.encoder.encode(x)
        z = self.encoder.sample(mu_e,log_var_e)

    def sample(self, batchsize = 1):
        z = self.Prior.sample(batchsize)
        return self.decoder.sample(z)

In [135]:
D = 28**2
L = 10
num_values = 10
encoder = nn.Sequential(nn.Linear(D,2*L))
decoder = nn.Sequential(nn.Linear(L,num_values*D))
prior = Prior(L)
enc = Encoder(encoder)
dec = Decoder(decoder)

In [136]:
mu, log = enc.encode(torch.flatten(mnist_trainset[0][0][0]))
z = enc.sample(mu,log)
o = dec.decode(z)

tensor([[[ 1.1334,  2.3207,  2.3796,  ..., -0.8393, -1.5502, -1.9580],
         [ 0.3444,  0.5484, -0.0282,  ..., -1.0623, -0.0676,  1.0569],
         [ 0.3858, -1.0046, -0.0151,  ..., -0.3151, -0.2835,  0.1694],
         ...,
         [-0.4716, -0.4144, -0.6936,  ..., -0.4585, -0.1470,  1.4116],
         [-2.2032, -1.0438,  1.3237,  ...,  0.4377, -0.2266,  1.4836],
         [ 1.2631, -0.7472,  0.8377,  ...,  0.0689, -0.9747, -1.3372]],

        [[-0.7560, -1.4160,  1.9319,  ...,  0.2219, -2.1717, -0.6164],
         [-1.3742, -0.5852, -0.2784,  ..., -0.9649,  1.4528,  0.4080],
         [-0.7109,  2.0445,  0.7940,  ...,  0.0040, -0.9433, -1.3810],
         ...,
         [-0.0633, -0.3249, -1.1690,  ...,  2.1233, -1.4329, -1.8994],
         [ 2.1626,  1.4148,  0.0403,  ...,  0.7065,  0.6346, -0.1113],
         [-0.7569,  0.3757,  0.2186,  ...,  0.6510,  0.2630, -1.3657]],

        [[-0.7712,  0.2651,  0.7440,  ..., -1.9657, -1.3228,  2.0920],
         [-0.3894, -0.9201, -1.2061,  ..., -0

In [137]:
o

[tensor([[[0.0667, 0.2186, 0.2318,  ..., 0.0093, 0.0046, 0.0030],
          [0.0382, 0.0469, 0.0263,  ..., 0.0094, 0.0253, 0.0780],
          [0.1126, 0.0280, 0.0754,  ..., 0.0559, 0.0577, 0.0907],
          ...,
          [0.0482, 0.0511, 0.0386,  ..., 0.0489, 0.0667, 0.3171],
          [0.0070, 0.0223, 0.2376,  ..., 0.0980, 0.0504, 0.2788],
          [0.2858, 0.0383, 0.1868,  ..., 0.0866, 0.0305, 0.0212]],
 
         [[0.0415, 0.0214, 0.6098,  ..., 0.1103, 0.0101, 0.0477],
          [0.0068, 0.0149, 0.0202,  ..., 0.0102, 0.1142, 0.0402],
          [0.0277, 0.4364, 0.1250,  ..., 0.0567, 0.0220, 0.0142],
          ...,
          [0.0639, 0.0492, 0.0211,  ..., 0.5688, 0.0162, 0.0102],
          [0.3459, 0.1638, 0.0414,  ..., 0.0806, 0.0750, 0.0356],
          [0.0387, 0.1200, 0.1026,  ..., 0.1581, 0.1072, 0.0210]],
 
         [[0.0276, 0.0779, 0.1258,  ..., 0.0084, 0.0159, 0.4841],
          [0.0673, 0.0396, 0.0297,  ..., 0.0699, 0.1563, 0.0441],
          [0.1006, 0.0437, 0.0160,  ...,