# Modules

In [None]:
import torch

# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print("using", device)

# Seed the random number generators so that runs are repeatable.
# The CUDA generator is additionally seeded when the GPU is in use.
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed(777)

In [None]:
import argparse

def parse_args(argv=None):
    """Build the hyperparameter configuration for the VAE experiment.

    The options are declared on a real ``ArgumentParser`` so the same
    function serves both a notebook and a command-line script.

    Args:
        argv: Optional list of command-line tokens to parse, e.g.
            ``['--epoch', '10']``.  When omitted, an empty list is parsed so
            that every option takes its default; this keeps the function safe
            inside notebooks, where ``sys.argv`` holds kernel arguments.

    Returns:
        An ``argparse.Namespace`` with the attributes ``batch_size``,
        ``epoch``, ``lr``, ``device`` and ``eval_per_epoch``.
    """
    parser = argparse.ArgumentParser(description='VAE')

    # Hyperparameters (defaults are the values used for the Colab runs)
    parser.add_argument('--batch_size', type=int, default=2048)
    parser.add_argument('--epoch', type=int, default=500)
    parser.add_argument('--lr', type=float, default=0.0001)
    # `device` is the module-level device string chosen at start-up
    parser.add_argument('--device', type=str, default=device)
    parser.add_argument('--eval_per_epoch', type=int, default=10)

    return parser.parse_args([] if argv is None else argv)


In [None]:
class VAEEncoder(torch.nn.Module):
    """Fully connected encoder that outputs the parameters of q(z|x).

    The input is flattened and passed through two ReLU hidden layers
    (``size // 2`` and 256 units); two separate linear heads then produce
    the mean and the log-variance of the latent Gaussian.
    """

    def __init__(self, size = (28, 28), latent_size = 32):
        # size : shape of one input sample (tuple)
        # latent_size : dimension of the latent vector
        super().__init__()

        # Number of scalar features in one flattened sample
        flat_dim = 1
        for extent in size:
            flat_dim = flat_dim * extent
        half_dim = flat_dim // 2

        self.size = flat_dim
        self.latent_size = latent_size

        self.fc1 = torch.nn.Linear(flat_dim, half_dim)
        self.fc2 = torch.nn.Linear(half_dim, 256)
        self.fc3_mu = torch.nn.Linear(256, latent_size)
        self.fc3_var = torch.nn.Linear(256, latent_size)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        # x : tensor that is flattened to (batch_size, self.size)
        # q(z|x) ~ N(mu, var)
        # returns (mu, log_var), each of size (batch_size, latent_size)
        # log_var is log(sigma^2)
        flat = x.view(-1, self.size)
        hidden = self.activation(self.fc1(flat))
        hidden = self.activation(self.fc2(hidden))
        return self.fc3_mu(hidden), self.fc3_var(hidden)


class VAEDecoder(torch.nn.Module):
    """Fully connected decoder mapping a latent vector back to input space.

    The latent vector passes through two ReLU hidden layers (256 and
    ``size // 2`` units) and a final linear layer, and the result is
    reshaped to the original sample shape given at construction.
    """

    def __init__(self, size = (28, 28), latent_size = 32):
        # size : shape of one output sample (tuple of any rank)
        # latent_size : dimension of the latent vector
        super().__init__()
        self.orig_size = size
        # Number of scalar features in one flattened sample
        self.size = 1
        for i in size:
            self.size *= i
        self.latent_size = latent_size
        self.fc1 = torch.nn.Linear(self.latent_size, 256)
        self.fc2 = torch.nn.Linear(256, self.size//2)
        self.fc3 = torch.nn.Linear(self.size//2, self.size)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        # x : tensor size of (batch_size, latent_size(at the init))
        # p(x|z) ~ Normal distribution
        # return : tensor size of (batch_size, *size(at the init))
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        x = self.activation(x)
        x = self.fc3(x)
        # Unpack the whole stored shape so that sizes of any rank are
        # restored, not only 2-D ones.
        x = x.view(-1, *self.orig_size)
        return x


class VAE(torch.nn.Module):
    """Variational autoencoder built from a VAEEncoder and a VAEDecoder.

    The encoder yields the mean and log-variance of q(z|x); a latent vector
    is drawn with the reparameterization trick and decoded back to input
    space.
    """

    def __init__(self, size = (28, 28), latent_size = 32, device = 'cuda'):
        # size : size of input (tuple)
        # latent_size : dimension of the latent vector
        # device : stored on the instance for callers that read it; sampling
        #          itself follows the device of the encoder output instead
        super().__init__()
        self.size = size
        self.latent_size = latent_size
        self.device = device
        self.encoder = VAEEncoder(size = self.size, latent_size = self.latent_size)
        self.decoder = VAEDecoder(size = self.size, latent_size = self.latent_size)

    def forward(self, x):
        # returns (mu, log_var, reconstruction)
        mu, log_var = self.encoder(x)
        latent = self.sample(mu, log_var)
        out = self.decoder(latent)
        return mu, log_var, out

    def sample(self, mu, log_var):
        # Reparameterization trick: z = mu + eps * sigma, with eps ~ N(0, I)
        # and sigma = exp(log_var / 2).
        # The noise is created with the same shape, dtype and device as mu,
        # so sampling works wherever the model currently lives, even when
        # self.device names a device that is unavailable.
        epsilon = torch.randn_like(mu)
        return mu + epsilon * torch.exp(log_var / 2)

    def make_latent(self, x):
        # Encode x and draw one latent sample from q(z|x)
        mu, log_var = self.encoder(x)
        return self.sample(mu, log_var)

    def forward_from_latent(self, latent):
        # Decode an already sampled latent vector
        out = self.decoder(latent)
        return out

# Train

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as img
import torch

In [None]:
def VAELoss(ground_truth, mu_pred, logvar_pred, out_pred, reconst_func = torch.nn.MSELoss(reduction='sum')):
    """Per-sample VAE loss: reconstruction error plus KL(q(z|x) || N(0, I)).

    Args:
        ground_truth: target batch.
        mu_pred: predicted latent means, first dimension is the batch.
        logvar_pred: predicted latent log-variances, same shape as mu_pred.
        out_pred: reconstruction; must hold as many elements as ground_truth.
        reconst_func: summed reconstruction loss called as
            ``reconst_func(prediction, target)``.
            p(x|z) ~ Normal distribution    -> torch.nn.MSELoss(reduction='sum')
            p(x|z) ~ Bernoulli distribution -> torch.nn.BCELoss(reduction='sum')

    Returns:
        Scalar tensor: both terms are summed over features and averaged over
        the batch.
    """
    # Give the prediction the target's layout (e.g. (B, H, W) -> (B, 1, H, W)).
    # Reshaping to the exact target shape avoids silent broadcasting when the
    # two tensors differ in rank.
    if out_pred.shape != ground_truth.shape:
        out_pred = out_pred.reshape(ground_truth.shape)
    batch_size = out_pred.size(dim = 0)

    # Loss modules take (input, target); the order matters for BCELoss.
    reconstruction_loss = reconst_func(out_pred, ground_truth) / batch_size

    # Closed-form KL divergence, summed over the latent dimensions and
    # averaged over the batch so that it is on the same per-sample scale as
    # the reconstruction term.
    kld_loss = 0.5 * torch.sum(mu_pred.pow(2) + logvar_pred.exp() - logvar_pred - 1) / mu_pred.size(dim = 0)
    return reconstruction_loss + kld_loss

In [None]:
def test(args, model, test_loader):
    """Evaluate the model on test_loader and print the mean per-image loss.

    Args:
        args: configuration object; ``args.device`` is where batches are moved.
        model: callable returning ``(mu, log_var, reconstruction)`` for a batch.
        test_loader: iterable of ``(image_batch, label_batch)`` pairs.

    The model is switched to eval mode and gradients are disabled for the
    duration of the evaluation; the previous train/eval mode is restored
    afterwards.
    """
    loss_sum = 0.0
    total_img = 0

    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for evaluation
        with torch.no_grad():
            for img, _ in test_loader:
                img = img.to(args.device)
                mu_pred, logvar_pred, out_pred = model(img)
                loss = VAELoss(img, mu_pred, logvar_pred, out_pred)

                # The loss is a per-image mean, so weight it by the batch size
                loss_sum += loss.item() * img.size(dim = 0)
                total_img += img.size(dim = 0)
    finally:
        model.train(was_training)

    if total_img == 0:
        # An empty loader would otherwise cause a division by zero
        print('[EVAL]\t\t no evaluation data (test_loader yielded no batches)')
        return
    print('[EVAL]\t\t loss :', loss_sum / total_img)

In [None]:
def train(args, model, train_loader, test_loader):
    """Train the model with Adam and evaluate it periodically.

    Args:
        args: configuration object providing ``lr``, ``epoch``, ``device``
            and ``eval_per_epoch``.
        model: module returning ``(mu, log_var, reconstruction)`` for a batch.
        train_loader: iterable of ``(image_batch, label_batch)`` pairs.
        test_loader: loader passed to ``test`` every ``eval_per_epoch`` epochs.

    Raises:
        ValueError: if train_loader yields no batches during an epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr = args.lr)
    for epoch in range(1, args.epoch + 1):
        # Make sure the model is in training mode at the start of each epoch
        model.train()
        loss_sum = 0.0
        total_img = 0
        for img, _ in train_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            mu_pred, logvar_pred, out_pred = model(img)
            loss = VAELoss(img, mu_pred, logvar_pred, out_pred)
            loss.backward()
            optimizer.step()

            # The loss is a per-image mean, so weight it by the batch size
            loss_sum += loss.item() * img.size(dim = 0)
            total_img += img.size(dim = 0)

        if total_img == 0:
            # Fail with a clear message instead of a bare division by zero
            raise ValueError(
                'train_loader yielded no batches; check the dataset size, '
                'batch_size and drop_last settings'
            )
        print('[EPOCH]', epoch, '\t\t loss :', loss_sum / total_img)

        if epoch % args.eval_per_epoch == 0:
            test(args, model, test_loader)


In [None]:
def visualize(args, model, latent):
    """Decode latent vectors and show the resulting images in a 2-column grid.

    Args:
        args: configuration object; ``args.device`` is where latent is moved.
        model: object exposing ``forward_from_latent(latent)``.
        latent: tensor size of (batch, latent_size); only the first 10
            latents are visualized.
    """
    latent = latent.to(args.device)
    # Decoding for display needs no gradients
    with torch.no_grad():
        out = model.forward_from_latent(latent).detach().cpu().numpy()

    count = min(latent.size(dim = 0), 10)
    # Round the row count up so that odd counts (and a single image) still
    # fit into the two-column grid.
    rows = (count + 1) // 2
    for i in range(count):
        plt.subplot(rows, 2, i + 1)
        plt.imshow(out[i] * 256, cmap='gray')

    plt.show()


# Main

In [None]:
import torchvision
import torch

In [None]:
# Build the run configuration (batch size, epochs, learning rate, device, evaluation interval)
args = parse_args()

In [None]:
# MNIST train and test splits, stored under the same root and downloaded if
# missing; ToTensor turns every image into a tensor.
mnist_root = '../MNIST_data'

mnist_train, mnist_test = (
    torchvision.datasets.MNIST(
        root = mnist_root,
        train = is_train_split,
        transform = torchvision.transforms.ToTensor(),
        download = True,
    )
    for is_train_split in (True, False)
)


In [None]:
# Training batches: shuffled every epoch; the last incomplete batch is dropped
# so that every optimisation step sees a full batch.
train_data_loader = torch.utils.data.DataLoader(
    dataset = mnist_train,
    shuffle = True,
    batch_size = args.batch_size,
    drop_last = True,
)

# Evaluation batches: fixed order, and the last incomplete batch is kept so
# that the reported test loss covers every test image.
test_data_loader = torch.utils.data.DataLoader(
    dataset = mnist_test,
    shuffle = False,
    batch_size = args.batch_size,
    drop_last = False,
)

In [None]:
# Build the VAE, move its parameters to the configured device, then train it
model = VAE(device = args.device)
model = model.to(args.device)
train(
    args = args,
    model = model,
    train_loader = train_data_loader,
    test_loader = test_data_loader,
)

In [None]:
# Inference from test data
# Only the first batch is used.  The loop variable is deliberately not called
# `img`, which would overwrite the module-level `matplotlib.image as img` alias.
for test_images, _ in test_data_loader:
    # Encoding for display needs no gradients
    with torch.no_grad():
        latent = model.make_latent(test_images[:10].to(args.device))
    visualize(args, model, latent)
    break



In [None]:
# Inference from random noise
# Draw the noise with the model's own latent width rather than a hard-coded 32
visualize(args, model, torch.randn(10, model.latent_size))