In [7]:
from modules import VariationalAutoencoder
from torchvision import transforms, datasets
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch import optim
import matplotlib.pyplot as plt

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [8]:
def plot_reconstructed(autoencoder, r0=(-5, 10), r1=(-10, 5), n=12):
    w = 28
    img = np.zeros((n * w, n * w))
    for i, y in enumerate(np.linspace(*r1, n)):
        for j, x in enumerate(np.linspace(*r0, n)):
            z = torch.Tensor([[x, y]]).to(DEVICE)
            x_hat = autoencoder.decoder(z)
            x_hat = x_hat.reshape(28, 28).to(DEVICE).detach().numpy()
            img[(n - 1 - i) * w : (n - 1 - i + 1) * w, j * w : (j + 1) * w] = x_hat
    plt.imshow(img, extent=[*r0, *r1])
    plt.show()


def train(autoencoder, optimizer, data_loader, num_epochs=10):
    outputs = []

    for epoch in range(num_epochs):
        for img, _ in data_loader:
            optimizer.zero_grad()

            recon = autoencoder(img)
            loss = ((img - recon) ** 2).sum() + autoencoder.encoder.kl

            loss.backward()
            optimizer.step()

        print(f"Epoch:{epoch+1}, Loss:{loss.item():.4f}")
        outputs.append((epoch, img, recon))
    return outputs, autoencoder

In [9]:
# config
latent_dims = 2

# data
transform = transforms.ToTensor()
data = datasets.MNIST(root="./data", download=True, train=True, transform=transform)
data_loader = DataLoader(dataset=data, batch_size=64, shuffle=True)

vae = VariationalAutoencoder(latent_dims)

optimizer = optim.Adam(vae.parameters(), lr=1e-4)

In [10]:
outputs, vae = train(
        vae, optimizer=optimizer, data_loader=data_loader, num_epochs=20
    )

Epoch:1, Loss:1611.7825
Epoch:2, Loss:1515.8401
Epoch:3, Loss:1380.7217
Epoch:4, Loss:1362.4689
Epoch:5, Loss:1380.8900
Epoch:6, Loss:1377.2250
Epoch:7, Loss:1399.5312
Epoch:8, Loss:1307.5063
Epoch:9, Loss:1398.9788
Epoch:10, Loss:1320.8795
Epoch:11, Loss:1421.5096
Epoch:12, Loss:1285.0626
Epoch:13, Loss:1303.9221
Epoch:14, Loss:1514.8007
Epoch:15, Loss:1148.7078
Epoch:16, Loss:1323.2950
Epoch:17, Loss:1285.6498
Epoch:18, Loss:1242.9644
Epoch:19, Loss:1379.0179
Epoch:20, Loss:1254.9299


In [None]:
plot_reconstructed(vae, (-3, 3), (-3, 3))