In [4]:
'''
Reference: https://github.com/pytorch/examples/blob/master/vae/main.py, 
           https://github.com/hwalsuklee/tensorflow-mnist-VAE
'''

import os
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image



BATCH_SIZE = 128
EPOCHS = 10
LOG_INTERVAL = 100
Z_DIM = 2
LEARNING_RATE = 1e-3
PRR = True
Z1_RANGE = 2
Z2_RANGE = 2
Z1_INTERVAL = 0.2
Z2_INTERVAL = 0.2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


train_data = datasets.FashionMNIST('./data', train=True, download=True,
                            transform=transforms.ToTensor())


# pin memory provides improved transfer speed
kwargs = {'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {}

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=BATCH_SIZE, shuffle=True, **kwargs)



# --- defines the model and the optimizer --- #
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 521)
        self.fc21 = nn.Linear(512, 2)  # fc21 for mean of Z
        self.fc22 = nn.Linear(512, 2)  # fc22 for log variance of Z
        self.fc3 = nn.Linear(2, 512)
        self.fc4 = nn.Linear(512, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        mu = self.fc21(h1)
        # I guess the reason for using logvar instead of std or var is that
        # the output of fc22 can be negative value (std and var should be positive)
        logvar = self.fc22(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.rand_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        # x: [batch size, 1, 28,28] -> x: [batch size, 784]
        x = x.view(-1, 784)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


# --- defines the loss function --- #
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)

    return BCE + KLD


# --- train and test --- #
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, label) in enumerate(train_loader):
        # data: [batch size, 1, 28, 28]
        # label: [batch size] -> we don't use
        optimizer.zero_grad()
        data = data.to(device)
        recon_data, mu, logvar = model(data)
        loss = loss_function(recon_data, data, mu, logvar)
        loss.backward()
        cur_loss = loss.item()
        train_loss += cur_loss
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100.*batch_idx / len(train_loader),
                cur_loss/len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)
    ))


# def test(epoch):
#     model.eval()
#     test_loss = 0
#     with torch.no_grad():
#         for batch_idx, (data, label) in enumerate(test_loader):
#             data = data.to(device)
#             recon_data, mu, logvar = model(data)
#             cur_loss = loss_function(recon_data, data, mu, logvar).item()
#             test_loss += cur_loss
#             if batch_idx == 0:
#                 # saves 8 samples of the first batch as an image file to compare input images and reconstructed images
#                 num_samples = min(BATCH_SIZE, 8)
#                 comparison = torch.cat(
#                     [data[:num_samples], recon_data.view(BATCH_SIZE, 1, 28, 28)[:num_samples]]).cpu()
#                 save_generated_img(
#                     comparison, 'reconstruction', epoch, num_samples)

#     test_loss /= len(test_loader.dataset)
#     print('====> Test set loss: {:.4f}'.format(test_loss))


# --- etc. funtions --- #
def save_generated_img(image, name, epoch, nrow=8):
    if not os.path.exists('results'):
        os.makedirs('results')

    if epoch % 5 == 0:
        save_path = 'results/'+name+'_'+str(epoch)+'.png'
        save_image(image, save_path, nrow=nrow)


def sample_from_model(epoch):
    with torch.no_grad():
        # p(z) = N(0,I), this distribution is used when calculating KLD. So we can sample z from N(0,I)
        sample = torch.randn(64, Z_DIM).to(device)
        sample = model.decode(sample).cpu().view(64, 1, 28, 28)
        save_generated_img(sample, 'sample', epoch)


def plot_along_axis(epoch):
    z1 = torch.arange(-Z1_RANGE, Z1_RANGE, Z1_INTERVAL).to(device)
    z2 = torch.arange(-Z2_RANGE, Z2_RANGE, Z2_INTERVAL).to(device)
    num_z1 = z1.shape[0]
    num_z2 = z2.shape[0]
    num_z = num_z1 * num_z2

    sample = torch.zeros(num_z, 2).to(device)

    for i in range(num_z1):
        for j in range(num_z2):
            idx = i * num_z2 + j
            sample[idx][0] = z1[i]
            sample[idx][1] = z2[j]

    sample = model.decode(sample).cpu().view(num_z, 1, 28, 28)
    save_generated_img(sample, 'plot_along_z1_and_z2_axis', epoch, num_z1)


# --- main function --- #
if __name__ == '__main__':
    for epoch in range(1, EPOCHS + 1):
        train(epoch)
        # test(epoch)
        sample_from_model(epoch)

        if PRR:
            plot_along_axis(epoch)

====> Epoch: 1 Average loss: 284.5368
====> Epoch: 2 Average loss: 269.4005
====> Epoch: 3 Average loss: 266.4905
====> Epoch: 4 Average loss: 264.6746
====> Epoch: 5 Average loss: 263.5057
====> Epoch: 6 Average loss: 262.6984
====> Epoch: 7 Average loss: 262.0534
====> Epoch: 8 Average loss: 261.5813
====> Epoch: 9 Average loss: 261.1659
====> Epoch: 10 Average loss: 260.7662
