In [6]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision

kwargs = {'batch_size': 256, 'shuffle': True}
my_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_loader = DataLoader(
    MNIST('./DATA',train=True, transform=my_transform, download=True), **kwargs)
test_loader = DataLoader(
    MNIST('./DATA', train=False, transform=my_transform, download=True), **kwargs)


Using downloaded and verified file: ./DATA/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./DATA/MNIST/raw/train-images-idx3-ubyte.gz to ./DATA/MNIST/raw
Using downloaded and verified file: ./DATA/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./DATA/MNIST/raw/train-labels-idx1-ubyte.gz to ./DATA/MNIST/raw
Using downloaded and verified file: ./DATA/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./DATA/MNIST/raw/t10k-images-idx3-ubyte.gz to ./DATA/MNIST/raw
Using downloaded and verified file: ./DATA/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./DATA/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./DATA/MNIST/raw
Processing...
Done!


In [19]:
import torch.nn as nn
import torch.utils.data.datasets

class VanillaVAE(nn.Module):

    """
    Input:
    inout_feature: (int) H*W
    hidden_feature: (int)
    latent_feature: (int)
    x: tensor (N, H*W) batch_size, image_size

    Output:
    mu_logvar: tensor (N, 2 * latent_feature)
    mu: tensor (N, latent_feature)
    log_var: tensor (N, latent_feature)
    x_hat: tensor (N, H*W)
    """

    def __init__(self, inout_feature, hidden_feature, latent_feature):
        super(VanillaVAE, self).__init__()
        self.latent_feature = latent_feature
        self.encoder = nn.Sequential(
                nn.Linear(inout_feature, hidden_feature),
                nn.ReLU(),
                nn.Linear(hidden_feature, 2 * self.latent_feature)
                )
        self.decoder = nn.Sequential(
                nn.Linear(self.latent_feature, hidden_feature),
                nn.ReLU(),
                nn.Linear(hidden_feature, inout_feature)
                )


    def Reparam(mu, log_var):
        var = log_var.exp()
        eps = torch.normal(mu, var)
        # z = var.mul(eps).add(mu)
        z = var * eps + mu
        return z

    def forward(self, x):
        mu_logvar = self.encoder(x)
        new_mulogvar = mu_logvar.reshape(-1, self.latent_feature, 2)
        mu = new_mulogvar[:, :, 0]
        log_var = new_mulogvar[:, :, 1]
        z = Reparam(mu, log_var)
        x_hat = self.decoder(z)
        return x_hat, mu, log_var

    def BernoulliDecoder():
        pass

    def GaussianDecoder():
        pass


def VanillaVAELoss(x, x_hat, mu, log_var, beta=1):
    # encode loss: logP(x|z)
    # BCEWithLogitsLoss=Sigmoid + BECLoss
    bcelog_loss = nn.BCEWithLogitsLoss()
    recon_los = bcelog_loss(x, x_hat)
    # KL divergence: -1/2(var**2 + mu**2 - log(var))
    kl_div = -0.5 * (var**2 + mu**2 - torch.logit(var))
    return recon_loss + beta * kl_div

In [20]:
# model
model = VanillaVAE(784, 400, 20)

# optimizer
import torch.optim

learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [21]:
# Training
epoch = 10

for i_epoch in range(epoch):
    model.train()
    
    for x, _ in train_loader:
        #=============forward==============
        x_hat, mu, log_var = model(x)
        loss = VanillaVAELoss(x, x_hat, mu, log_var)
        tot_train_loss += loss
        #=============backward=============
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    #-----accuracy
    print(f'=============>Epoch: {i_epoch}, average Loss is {tot_train_loss/len(train_loader.dataset):.4f}')    





# Testing

RuntimeError: mat1 and mat2 shapes cannot be multiplied (7168x28 and 784x400)