In [1]:
import torch
import torch.nn.functional as F

from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from torchvision.utils import save_image

from tqdm.notebook import tqdm

In [2]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder is one 400-unit hidden layer followed by two parallel linear
    heads that produce the mean and the log-variance of the approximate
    posterior. The decoder is two 400-unit hidden layers followed by a sigmoid
    output layer, so every reconstructed value lies in [0, 1].

    Args:
        input_size: Number of features in one flattened input sample.
        output_size: Number of features in one reconstructed sample.
        latent_size: Dimensionality of the latent code ``z``.
    """

    def __init__(self, input_size, output_size, latent_size):
        super().__init__()

        # Remembered so forward() can flatten its input to the width fc1
        # expects instead of assuming a fixed image size.
        self.input_size = input_size

        self.fc1 = nn.Linear(input_size, 400)
        self.fc21 = nn.Linear(400, latent_size)  # Mean μ
        self.fc22 = nn.Linear(400, latent_size)  # Log variance log(σ^2)
        self.fc3 = nn.Linear(latent_size, 400)
        self.fc4 = nn.Linear(400, 400)
        self.fc5 = nn.Linear(400, output_size)

    def encode(self, x):
        """Map flattened inputs to the posterior parameters.

        Args:
            x: Tensor whose last dimension equals ``input_size``.

        Returns:
            Tuple ``(mu, log_var)``; each has ``latent_size`` as last dimension.
        """
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    @staticmethod
    def re_parameterize(mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Sampling the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)  # log(σ^2) -> σ
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to reconstructions with values in [0, 1].

        Args:
            z: Tensor whose last dimension equals ``latent_size``.

        Returns:
            Tensor whose last dimension equals ``output_size``.
        """
        h3 = F.relu(self.fc3(z))
        h4 = F.relu(self.fc4(h3))
        return torch.sigmoid(self.fc5(h4))

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Args:
            x: Input batch of any shape whose element count is a multiple of
                ``input_size``; it is flattened to ``(-1, input_size)``.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        # Flatten to the configured width rather than a hard-coded 784 so the
        # model also works for inputs that are not 28x28 images.
        mu, log_var = self.encode(x.reshape(-1, self.input_size))
        z = self.re_parameterize(mu, log_var)
        return self.decode(z), mu, log_var


In [3]:
def vae_loss(recon_x, x, mu, log_var):
    """Negative ELBO of a VAE: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: Reconstruction with values in [0, 1].
        x: Target batch with the same number of elements as ``recon_x``; it is
            reshaped to ``recon_x``'s shape before comparison.
        mu: Mean of the approximate posterior.
        log_var: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor holding the loss summed over every element of the batch
        (not averaged).
    """
    # Match the target to the reconstruction's shape instead of assuming
    # 784 features, so the loss works for any flattened input size.
    target = x.reshape(recon_x.shape)
    bce = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I).
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return bce + kld

In [4]:
# ToTensor converts each image to a float tensor (torchvision documents the
# pixel values as being scaled into [0, 1]).
transform = transforms.ToTensor()

# MNIST training split; downloaded into ./data if it is not already there.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Mini-batches of 32 images, reshuffled on every pass over the data.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

In [5]:
# Choose the compute device by availability instead of assuming Apple MPS, so
# the notebook also runs on CUDA machines and on plain CPUs.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# 784 input/output features (flattened 28x28 images) and a 40-dimensional
# latent space.
model = VAE(784, 784, 40).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [6]:
def train(epoch: int):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``vae_loss``. A progress bar shows the per-sample loss of
    the latest batch while training and, once the pass is finished, the
    per-sample loss averaged over the whole dataset.

    Args:
        epoch: Epoch number; only used to label the progress bar.
    """
    model.train()
    train_loss = 0.0

    # The context manager closes the bar even when the pass is interrupted,
    # so aborted runs do not leave stale bars behind.
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}') as p_bar:
        for data, _ in train_loader:
            data = data.to(device)

            optimizer.zero_grad()
            recon_batch, mu, log_var = model(data)
            loss = vae_loss(recon_batch, data, mu, log_var)
            loss.backward()
            optimizer.step()

            # Read the scalar once: every .item() call forces a device sync.
            batch_loss = loss.item()
            train_loss += batch_loss

            p_bar.update(1)
            p_bar.set_postfix({'loss': batch_loss / len(data)})

        # One update per batch already brings the bar to its total, so only
        # the summary is refreshed here (no extra update past the total).
        p_bar.set_postfix({'loss': train_loss / len(train_loader.dataset)})

In [8]:
# Number of passes over the training set; epochs are numbered from 1.
NUM_EPOCHS = 9

for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)

  0%|          | 0/1875 [00:00<?, ?it/s]

  0%|          | 0/1875 [00:00<?, ?it/s]

  0%|          | 0/1875 [00:00<?, ?it/s]

  0%|          | 0/1875 [00:00<?, ?it/s]

  0%|          | 0/1875 [00:00<?, ?it/s]

  0%|          | 0/1875 [00:00<?, ?it/s]

  0%|          | 0/1875 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [10]:
# Decode 64 latent codes drawn from the N(0, I) prior and save them as a grid
# of 28x28 single-channel images.
with torch.no_grad():
    # Take the latent width from the decoder's first layer so the noise always
    # matches the model; a hard-coded width breaks as soon as the model is
    # built with a different latent_size.
    sample = torch.randn(64, model.fc3.in_features).to(device)
    sample = model.decode(sample).cpu()
    save_image(sample.view(64, 1, 28, 28), 'sample2.png')