In [7]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import MNIST

import pytorch_lightning as pl

In [8]:
bs = 256  # mini-batch size shared by the train and test loaders

# Training split. shuffle=True draws a fresh sample order every epoch; without
# it the optimizer sees the exact same batches in the exact same order on each
# of the epochs, which hurts SGD-style training.
train_dl = DataLoader(
    MNIST("data", train=True, download=True, transform=T.ToTensor()),
    num_workers=4,
    batch_size=bs,
    shuffle=True,
)

# Held-out split. Order is left fixed so evaluation is repeatable.
test_dl = DataLoader(
    MNIST("data", train=False, download=True, transform=T.ToTensor()),
    num_workers=4,
    batch_size=bs,
)

In [39]:
def kl_divergence(log_var, mu):
    """KL divergence between N(mu, exp(log_var)) and a standard normal.

    The per-element term ``exp(log_var) - log_var - 1 + mu^2`` is summed over
    dimension 1, halved, and the result is averaged over what remains.

    Args:
        log_var: tensor of log-variances; must have a dimension 1 to sum over.
        mu: tensor of means, broadcast-compatible with ``log_var``.

    Returns:
        A scalar (0-dim) tensor.
    """
    per_element = log_var.exp() - log_var - 1 + mu.pow(2)
    per_sample = 0.5 * per_element.sum(dim=1)
    return per_sample.mean()


class VAE(pl.LightningModule):
    """Small fully-connected variational autoencoder for flattened 28x28 inputs.

    The encoder maps a 784-vector to ``2 * d`` numbers that are interpreted as
    the mean and log-variance of a diagonal Gaussian over a ``d``-dimensional
    latent code; the decoder maps a latent code back to 784 values squashed
    into (0, 1) by a sigmoid.

    Args:
        d: size of the latent code.
        lr: learning rate handed to Adam.
    """

    def __init__(self, d=16, lr=1e-2):
        super().__init__()
        self.save_hyperparameters()

        self.d = d
        self.lr = lr

        # 784 -> 128 -> 2*d  (first d outputs: mean, last d outputs: log-variance)
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, d * 2),
        )
        # d -> 128 -> 784, sigmoid keeps outputs in (0, 1) as BCE requires
        self.decoder = nn.Sequential(
            nn.Linear(d, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),
        )

    def configure_optimizers(self):
        """Return a single Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x, train=True):
        """Encode ``x``, pick a latent code, and decode it.

        Args:
            x: tensor whose last dimension is 784 (flattened image batch).
            train: when True, sample the latent code with the
                reparameterisation trick; when False, use the posterior mean.
                This flag is independent of the module's train/eval mode.

        Returns:
            Tuple ``(x_hat, mu, log_var)``: the reconstruction and the two
            halves of the encoder output.
        """
        enc_out = self.encoder(x).view(-1, 2, self.d)
        mu, log_var = enc_out[:, 0, :], enc_out[:, 1, :]
        if train:
            # z = mu + sigma * eps keeps the sample differentiable w.r.t.
            # mu and log_var.
            std = torch.exp(log_var * 0.5)
            eps = torch.randn_like(std)
            z = (eps * std) + mu
        else:
            # no need to random sample if not training
            z = mu
        return self.decoder(z), mu, log_var

    def training_step(self, batch, batch_idx):
        """Compute the negative ELBO for one batch.

        ``batch[0]`` is flattened to ``(-1, 784)`` and is assumed to hold
        values in [0, 1], as binary cross-entropy requires.

        Returns:
            ``{"loss": loss}`` where ``loss`` is a scalar tensor.
        """
        x = batch[0].view(-1, 28 * 28)
        x_hat, mu, log_var = self(x)

        # Sum the reconstruction error over pixels and average over the batch
        # so it is on the same per-sample scale as the KL term (which is summed
        # over latent dims, then batch-averaged). The default "mean" reduction
        # would divide by an extra factor of 784 and let the KL term dominate.
        recon = F.binary_cross_entropy(x_hat, x, reduction="sum") / x.size(0)
        loss = recon + kl_divergence(log_var, mu)

        # A dict with a "loss" key is accepted by Lightning both before and
        # after the short-lived TrainResult API.
        return {"loss": loss}

In [40]:
# Model with its default hyperparameters (d=16, lr=1e-2).
model = VAE()

# Train for at most 20 passes over the training loader.
trainer = pl.Trainer(max_epochs=20)
trainer.fit(model, train_dl)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 104 K 
1 | decoder | Sequential | 103 K 


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

60000