In [None]:
import torch
import numpy as np

import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = "retina"


from typing import Any, Tuple

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST


from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

import lightning.pytorch as L
from einops import rearrange

# Seed the torch and NumPy global RNGs with a fixed value so repeated runs
# draw the same random numbers.
torch.manual_seed(2023)
np.random.seed(2023)
# Select PyTorch's "medium" float32 matmul precision setting
# (NOTE(review): trades some matmul precision for speed on supporting hardware — confirm this is intended).
torch.set_float32_matmul_precision("medium")

In [None]:
# MNIST training split; ToTensor converts each image to a float tensor.
#
# `root` must be supplied: it is the directory torchvision downloads into and
# reads from, and the original call omitted it. The cache path below is the one
# the FashionMNIST variant of this cell already used ("~" is expanded by
# torchvision).
#
# FashionMNIST alternative (requires `from torchvision.datasets import FashionMNIST`):
# trainset = FashionMNIST(root='~/.cache/torchvision_cache', train=True, download=True,
#                         transform=transforms.Compose([
#                         transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]))
trainset = MNIST(
    root="~/.cache/torchvision_cache",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
# Preview the first training image in grayscale; squeeze() removes the
# size-1 dimensions so imshow receives a 2-D array.
plt.imshow(trainset[0][0].squeeze(), cmap="gray")
# Print the tensor size of that same sample (before squeezing).
print(trainset[0][0].size())

In [None]:
# Shuffled mini-batches of 128 samples drawn from `trainset`, used for training.
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)

In [None]:
class VAELoss(nn.Module):
    """VAE objective: summed binary cross-entropy plus a KL-divergence term.

    Both terms are summed over every element of their inputs (no averaging),
    so the returned scalar grows with batch size.
    """

    def __init__(self, *args, **kwargs) -> None:
        # No state of its own; arguments are passed straight to nn.Module.
        super().__init__(*args, **kwargs)

    def forward(
        self,
        x: torch.Tensor,
        x_hat: torch.Tensor,
        mean: torch.Tensor,
        log_var: torch.Tensor,
    ) -> torch.Tensor:
        """Return the scalar loss for one batch.

        Args:
            x: Target values for the reconstruction term.
            x_hat: Reconstruction; used as the probability input of the
                binary cross-entropy.
            mean: Latent mean.
            log_var: Latent log-variance.

        Returns:
            Zero-dimensional tensor: reconstruction term + KL term.
        """
        # Reconstruction term, summed over all elements.
        bce_term = F.binary_cross_entropy(x_hat, x, reduction="sum")

        # KL term: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)).
        variance = torch.exp(log_var)
        kl_inner = 1 + log_var - mean.pow(2) - variance
        kl_term = -0.5 * kl_inner.sum()

        return bce_term + kl_term

In [None]:
class VAE(L.LightningModule):
    """Fully connected variational autoencoder over flattened 784-value inputs.

    The encoder maps 784 -> 512 -> 512 features, two linear heads produce a
    256-dimensional latent mean and log-variance, and the decoder maps a
    256-dimensional latent sample back to 784 values in (0, 1).
    """

    def __init__(self) -> None:
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.LeakyReLU(0.2),
        )

        # latent space
        self.latent_mean = nn.Linear(512, 256)
        self.latent_log_var = nn.Linear(512, 256)

        self.decoder = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 784),
            # NOTE(review): a LeakyReLU directly in front of the Sigmoid is
            # unusual (it scales negative logits by 0.2); kept to preserve the
            # trained behaviour — confirm it is intentional.
            nn.LeakyReLU(0.2),
            nn.Sigmoid(),
        )

    def reparameterisation(
        self, mean: torch.Tensor, log_var: torch.Tensor
    ) -> torch.Tensor:
        """Sample ``mean + eps * std`` with ``eps`` drawn from a standard normal.

        Args:
            mean: Latent mean.
            log_var: Latent log-variance (same shape as ``mean``).

        Returns:
            A latent sample with the same shape as ``mean``.
        """
        # exp(0.5 * log_var) undoes the log and takes the square root in one step.
        std = torch.exp(0.5 * log_var)
        # randn_like already matches std's device and dtype, so no extra move
        # to self.device is needed.
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Flattened input whose last dimension is 784.

        Returns:
            ``(x_hat, mean, log_var)``: the reconstruction and the latent
            statistics used to sample it.
        """
        out = self.encoder(x)

        # NOTE(review): both latent heads are passed through a leaky ReLU,
        # which is not the usual plain-linear parameterisation; kept as-is.
        mean = F.leaky_relu(self.latent_mean(out), 0.2)
        log_var = F.leaky_relu(self.latent_log_var(out), 0.2)

        z = self.reparameterisation(mean, log_var)
        x_hat = self.decoder(z)

        return x_hat, mean, log_var

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Optimise all parameters with AdamW at a learning rate of 1e-3."""
        return optim.AdamW(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Compute and log the VAE loss for one ``(inputs, labels)`` batch.

        Labels are ignored. The loss is the summed binary cross-entropy
        between reconstruction and input plus the KL term; it is logged as
        ``train_loss`` and returned under the ``"loss"`` key.
        """
        x, _ = batch

        # Keep the batch dimension, flatten everything else.
        x = torch.flatten(x, start_dim=1)

        x_hat, mean, log_var = self(x)

        reconstruction = F.binary_cross_entropy(x_hat, x, reduction="sum")
        kl = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        loss = reconstruction + kl

        self.log("train_loss", loss, prog_bar=True)

        # Logging is handled by self.log above; the legacy "log" entry that
        # used to be returned alongside the loss has been dropped.
        return {"loss": loss}


# model = VAE()
# x = torch.randn(8, 784)
# y = torch.zeros(8, )

# x_hat, mean, log_var = model(x)
# print(x_hat)

# loss = F.binary_cross_entropy(x_hat, x, reduction="sum")
# kl = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
# print(loss + kl)

In [None]:
def train(max_epochs: int = 5, accelerator: str = "auto") -> Any:
    """Fit a fresh ``VAE`` on ``trainloader`` and return the trained module.

    Args:
        max_epochs: Number of training epochs (default 5).
        accelerator: Lightning accelerator name. ``"auto"`` lets Lightning
            pick an available accelerator, so the notebook also runs on
            machines without a GPU; pass ``"gpu"`` or ``"cpu"`` to force one.

    Returns:
        The trained ``VAE`` instance.
    """
    # Run name is kept unchanged so existing TensorBoard runs stay under the
    # same directory.
    # NOTE(review): log_graph=True presumably needs `example_input_array` on
    # the model to have any effect — confirm against the Lightning docs.
    logger = L.loggers.TensorBoardLogger(
        "tb_logs", name="vae_fashion_mnist", log_graph=True
    )

    model = VAE()
    trainer = L.Trainer(
        max_epochs=max_epochs,
        devices=1,
        accelerator=accelerator,
        logger=logger,
    )
    trainer.fit(model, trainloader)

    return model


model = train()

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir tb_logs/


In [None]:
with torch.no_grad():
    # Draw one latent vector from a standard normal. Its length is read from
    # the model's latent head instead of repeating the literal 256.
    x = torch.randn(
        model.latent_mean.out_features,
    ).to(model.device)
    out = model.decoder(x)

    # Reshape the 784 decoded values into a 28x28 image and move it to the
    # CPU, because matplotlib cannot plot tensors that live on an accelerator.
    out = torch.unflatten(out, -1, (28, 28)).cpu()
    print(out.size())

    plt.imshow(out, cmap="gray")

In [None]:
with torch.no_grad():
    # Flatten the first training image and place it on the model's device so
    # the forward pass works wherever the model currently lives.
    image = torch.flatten(trainset[0][0]).to(model.device)
    print(trainset[0][1])
    out, _, _ = model(image)

    # Back to a 28x28 image on the CPU for plotting.
    out = torch.unflatten(out, -1, (28, 28)).cpu()

plt.imshow(out, cmap="gray")

In [None]:
@torch.no_grad()
def inference(image: torch.Tensor) -> torch.Tensor:
    """Reconstruct a single image with the notebook-level ``model``.

    Args:
        image: One image tensor; all of its dimensions are flattened into a
            single vector before it is passed to the model.

    Returns:
        The reconstruction reshaped to ``(28, 28)``, on the same device as
        ``image``.
    """
    # Move the input to wherever the model lives so a device mismatch between
    # dataset tensors and the model does not break the forward pass.
    flattened_image = torch.flatten(image).to(model.device)
    out, _, _ = model(flattened_image)
    out = torch.unflatten(out, -1, (28, 28))

    # Hand the result back on the caller's device.
    return out.to(image.device)

In [None]:
# Sanity check: display the shape of the reconstruction of the first training image.
inference(trainset[0][0]).shape

In [None]:
def randomly_pick_n(n: int = 8, max=len(trainset)) -> Tuple:
    """Pick ``n`` random training images and reconstruct each of them.

    Args:
        n: Number of images to draw (indices are sampled with replacement).
        max: Exclusive upper bound for the sampled dataset indices; defaults
            to the dataset length at the time this function was defined.

    Returns:
        ``(stacked_generated, stacked_actuals)``: the stacked reconstructions
        and the stacked source images with their size-1 channel axis removed.
    """
    picked = torch.randint(0, max, size=(n,)).tolist()

    # Fetch only the image part of each (image, label) sample.
    images = []
    for position in picked:
        images.append(trainset[position][0])

    # Float32 copies of the source images, made via NumPy.
    originals = list(
        map(lambda im: torch.from_numpy(np.array(im, dtype=np.float32)), images)
    )
    # Reconstructions, produced in the same order as the source images.
    reconstructions = list(map(inference, images))

    stacked_generated = torch.stack(reconstructions)
    # drop the extra dim in actuals
    stacked_actuals = rearrange(torch.stack(originals), "b 1 h w -> b h w")

    return stacked_generated, stacked_actuals

In [None]:
# Draw a random set of images (default n=8) together with their reconstructions.
generated, actuals = randomly_pick_n()
# Print the sizes of both stacked tensors.
print(generated.size())
print(actuals.size())

In [None]:
def comparison_plot(generated: list, actuals: list) -> None:
    """Plot source images above their reconstructions on a 2-row grid.

    Row 1 shows ``actuals`` and row 2 shows ``generated``, one column per
    image, matching the figure title. If the two sequences differ in length,
    the shorter row is padded with blank axes.

    Args:
        generated: Reconstructed images, each a 2-D array-like.
        actuals: Source images, each a 2-D array-like.
    """
    # At least one column so an empty input still yields a valid (blank) figure.
    n_cols = max(len(generated), len(actuals), 1)
    # squeeze=False keeps `axes` two-dimensional even when there is one column.
    fig, axes = plt.subplots(2, n_cols, figsize=(12, 6), squeeze=False)

    rows = (("actual", actuals), ("generated", generated))
    for row_axes, (label, images) in zip(axes, rows):
        for idx, ax in enumerate(row_axes):
            ax.axis("off")
            if idx < len(images):
                ax.set_title(f"{label}_{idx}")
                ax.imshow(images[idx])

    fig.suptitle("Actual (upper) vs Generated(lower)")
    # Lay out after every axes and the title exist; calling it on an empty
    # figure has nothing to arrange.
    fig.tight_layout()

In [None]:
# Convert the stacked tensors to nested Python lists and plot them.
comparison_plot(generated.tolist(), actuals.tolist())