In [5]:
import os
from typing import NamedTuple, override

import numpy as np
import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

## Model

In [119]:
class VAE(nn.Module):
    """Variational Autoencoder."""

    def __init__(self, latent_dim:int, use_affin:bool, use_bce: bool):
        """Initialize module."""
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),  # [batch, 28 x 28 x 1] = [batch, 784]
            nn.Linear(28 * 28, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # -6: affine parameters (translation and rotation in 2-d Euclidean space)
        # +2: number of coordinates
        self.decoder_fc = nn.Linear(latent_dim - 6 + 2, 256)
        self.decoder = nn.Sequential(
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid() if use_bce else nn.Tanh(),
        )

        coord = torch.cartesian_prod(
            torch.linspace(-1, 1, 28), torch.linspace(-1, 1, 28)
        )
        coord = torch.reshape(coord, (28, 28, 2)).unsqueeze(0)  # [1, 28, 28, 2]
        self.register_buffer("coord", coord)
        self.use_affine = use_affin

    def encode(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
        """Encode inputs.
        Args:
            inputs (torch.Tensor): input image
        Returns:
            mu (torch.Tensor): mean vector of posterior dist.
            logvar (torch.Tensor): log-starndard deviation vector of posterior dist.
        """
        hidden = self.encoder(inputs)
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Perform reparameterization trick.
        Args:
            mu (torch.Tensor): mean vector
            logvar (torch.Tensor): log-starndard deviation vector
        Returns:
            latent (torch.Tensor): latent vector
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        latent = mu + eps * std
        return latent

    def augment_latent(
        self, latent: Tensor, use_affine: bool, scale: float = 0.1
    ) -> Tensor:
        """Augment latent vector.
        Args:
            latent (torch.Tensor): latent vector
            use_affine (bool): flag to apply affine transform
            scale (float): scaling factor for affine transform
        Returns:
            outputs (torch.Tensor): augmented latent vector
        """
        batch_size = latent.shape[0]  # batch
        h_size = self.coord.shape[1]  # 28
        w_size = self.coord.shape[2]  # 28

        coord = self.coord.repeat(batch_size, 1, 1, 1)  # [batch, 28, 28, 2]
        if use_affine:
            affine = torch.reshape(latent[:, -6:], (-1, 2, 3))  # [batch, 2, 3]
            zeros = torch.zeros_like(affine[:, 0:1, :])  # [batch, 1, 3]
            affine = torch.concat([affine, zeros], dim=-2)  # [batch, 3, 3]
            affine = scale * affine + torch.eye(3).to(latent.device)  # [batch, 3, 3]
            ones = torch.ones_like(coord[:, :, :, 0:1])  # [batch, 28, 28, 1]
            coord = torch.concat([coord, ones], dim=-1)  # [batch, 28, 28, 3]
            # apply affin to coord
            coord = torch.einsum("bhwj, bji -> bhwi", coord, affine)
            coord = coord[:, :, :, 0:2]  # [batch, 28, 28, 2]

        latent_ = latent[:, :-6]  # [batch, 20]
        latent_ = latent_[:, :, None, None]  # [batch, 20, 1, 1]
        latent_ = torch.permute(latent_, (0, 2, 3, 1))  # [batch, 1, 1, 20]
        latent_ = latent_.repeat(1, h_size, w_size, 1)  # [batch, 28, 28, 20]

        outputs = torch.concat([coord, latent_], dim=-1)  # [batch, 28, 28, 22]
        outputs = torch.reshape(outputs, (-1, outputs.shape[-1]))
        return outputs  # [batch * 28 * 28, 22] = [100352, 22]

    def decode(self, latent: Tensor, batch_size: int, use_affine: bool) -> Tensor:
        """Decode latent vector.
        Args:
            latent (torch.Tensor): latent vector
            batch_size (int): batch size
            use_affine (bool): flag to apply affine transform
        Returns:
            reconst (torch.Tensor): reconstructed image
        """
        latent = self.augment_latent(latent, use_affine)
        hidden = self.decoder_fc(latent)
        hidden = self.decoder(hidden)
        hidden = torch.reshape(
            hidden, (batch_size, self.coord.shape[1], self.coord.shape[2], 1)
        )
        reconst: Tensor = torch.permute(hidden, (0, 3, 1, 2))
        return reconst

    @override
    def forward(self, inputs: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Forward propagation.
        Args:
            inputs (torch.Tensor): input image
        Returns:
            reconst (torch.Tensor): reconstructed image
            mu (torch.Tensor): mean vector of posterior dist.
            logvar (torch.Tensor): log-starndard deviation vector of posterior dist.
        """
        mu, logvar = self.encode(inputs)
        latent = self.reparameterize(mu, logvar)
        reconst = self.decode(latent, inputs.shape[0], self.use_affine)
        return reconst, mu, logvar

In [7]:
def get_dataloader(
    is_train: bool, transform: transforms.Compose, batch_size: int
) -> DataLoader[tuple[Tensor, Tensor]]:
    """Get a dataloader for training or validation.
    Args:
        is_train (bool): a flag to determine training mode
        transform (transforms.Compose): a chain of transforms to be applied
        batch_size (int): batch size of data loader
    Returns:
        dataloader (Dataloader): a dataloader for training.
    """
    if is_train is True:
        dataset = datasets.MNIST(
            root="./data", train=True, transform=transform, download=True
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
    else:
        dataset = datasets.MNIST(
            root="./data", train=False, transform=transform, download=True
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )
    return dataloader


def loss_function(model: VAE, inputs: Tensor, use_bce: bool) -> Tensor:
    """Compute loss function (ELBO).
    Args:
        model (VAE): VAE module
        inputs (torch.Tensor): input image
        use_bce (bool): flag to apply BCE loss or MSE loss
    Returns:
        loss (torch.Tensor): Evidence Lower Bound (ELBO)
    """
    reconst, mu, logvar = model(inputs)
    if use_bce:
        reconst_loss = nn.BCELoss(reduction="sum")(reconst, inputs)
    else:
        reconst_loss = nn.MSELoss(reduction="sum")(reconst, inputs)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss: Tensor = reconst_loss + kl_divergence
    return loss

In [21]:
def generate_sample(
    model: VAE, latent_dim: int, val_data: Tensor, epoch: int, device:str
) -> None:
    """Generate samples from trained model.
    Args:
        model (VAE): VAE module
        latent_dim (int): configuration for model
        val_data (torch.Tensor): validation data
        epoch (int): current epoch
    """
    os.makedirs("VAE", exist_ok=True)
    batch_size = val_data.shape[0]
    with torch.no_grad():
        latent = torch.randn(batch_size, latent_dim).to(device)
        generated_images = model.decode(latent, batch_size, False)
        images = generated_images.cpu().view(val_data.size())
        save_image(images[:batch_size], f"VAE/generated_image_{epoch+1}.png")

        # save reconstructed images of validation data for comparison
        mu, logvar = model.encode(val_data)
        latent = model.reparameterize(mu, logvar)
        val_reconstructed = model.decode(latent, batch_size, False)
        val_reconstructed = val_reconstructed.view(val_data.size())
        comparison = torch.cat([val_data.cpu(), val_reconstructed.cpu()], dim=3)
        save_image(comparison, f"VAE/reconstructed_image_{epoch+1}.png")

In [117]:
USE_BCE = False
BATCH_SIZE = 8
LR = 1e-2

LATENT_DIM = 256

if USE_BCE:
    transform = transforms.Compose([transforms.ToTensor()])
else:
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

train_loader = get_dataloader(True, transform, BATCH_SIZE)
test_loader = get_dataloader(False, transform, BATCH_SIZE)
model = VAE(LATENT_DIM, True, USE_BCE).to("cpu")
optimizer = optim.Adam(model.parameters(), lr=LR)

In [118]:
# prepare validation data
val_data, _ = next(iter(test_loader))
val_data = val_data.to("cpu")

for epoch in range(10):
    model.train()
    epoch_loss = []
    for data, _ in train_loader:
        data = data.to("cpu")
        optimizer.zero_grad()
        loss = loss_function(model, data, USE_BCE)
        epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    print(f"Epoch: {epoch+1}, Average Loss: {np.average(epoch_loss):.12f}")
    # visualise training progress by generating samples from current model
    model.eval()
    generate_sample(model, LATENT_DIM, val_data, epoch, "cpu")

print("Training finished.")

Epoch: 1, Average Loss: 5768.258027343750
Epoch: 2, Average Loss: 5735.025183398438
Epoch: 3, Average Loss: 5733.266077278646
Epoch: 4, Average Loss: 5732.236746256511
Epoch: 5, Average Loss: 5731.455840071614
Epoch: 6, Average Loss: 5730.011067513021
Epoch: 7, Average Loss: 5730.554882975261
Epoch: 8, Average Loss: 5729.782995638021
Epoch: 9, Average Loss: 5729.337324641927
Epoch: 10, Average Loss: 5729.182792903646
Training finished.


## Check

In [27]:
flt = nn.Flatten()
test = flt(data)

In [69]:
mu, logvar = model.encode(data)

In [71]:
latent = model.reparameterize(mu=mu, logvar=logvar)

In [72]:
batch_size = latent.shape[0]  # batch
h_size = 28
w_size = 28
scale = 0.1

In [89]:
coord = model.coord.repeat(batch_size, 1, 1, 1)  # [batch, 28, 28, 2]

In [74]:
affine = torch.reshape(latent[:, -6:], (-1, 2, 3))  # [batch, 2, 3]
zeros = torch.zeros_like(affine[:, 0:1, :])  # [batch, 1, 3]
affine = torch.concat([affine, zeros], dim=-2)  # [batch, 3, 3]

In [75]:
affine = scale * affine + torch.eye(3).to(latent.device)  # [batch, 3, 3]

In [76]:
ones = torch.ones_like(coord[:, :, :, 0:1])  # [batch, 28, 28, 1]
_coord = torch.concat([coord, ones], dim=-1)  # [batch, 28, 28, 3]

In [77]:
new_coord = torch.einsum("bhwj, bji -> bhwi", coord, affine)

In [79]:
_coord.shape, affine.shape

(torch.Size([8, 28, 28, 3]), torch.Size([8, 3, 3]))

In [87]:
_coord[0, :3, :3]

tensor([[[-1.0000, -1.0000,  1.0000],
         [-1.0000, -0.9259,  1.0000],
         [-1.0000, -0.8519,  1.0000]],

        [[-0.9259, -1.0000,  1.0000],
         [-0.9259, -0.9259,  1.0000],
         [-0.9259, -0.8519,  1.0000]],

        [[-0.8519, -1.0000,  1.0000],
         [-0.8519, -0.9259,  1.0000],
         [-0.8519, -0.8519,  1.0000]]])

In [84]:
affine[0, :3, :3]

tensor([[ 1.0110,  0.0181, -0.0665],
        [-0.0701,  1.1150, -0.0434],
        [ 0.0000,  0.0000,  1.0000]], grad_fn=<SelectBackward0>)

In [86]:
e_coord[0, :3, :3]

tensor([[[-0.9410, -1.1331,  1.1098],
         [-0.9461, -1.0505,  1.1066],
         [-0.9513, -0.9679,  1.1034]],

        [[-0.8661, -1.1317,  1.1049],
         [-0.8713, -1.0491,  1.1017],
         [-0.8764, -0.9666,  1.0985]],

        [[-0.7912, -1.1304,  1.1000],
         [-0.7964, -1.0478,  1.0968],
         [-0.8016, -0.9652,  1.0936]]], grad_fn=<SliceBackward0>)

In [90]:
affine = torch.reshape(latent[:, -6:], (-1, 2, 3))  # [batch, 2, 3]
zeros = torch.zeros_like(affine[:, 0:1, :])  # [batch, 1, 3]
affine = torch.concat([affine, zeros], dim=-2)  # [batch, 3, 3]
affine = scale * affine + torch.eye(3).to(latent.device)  # [batch, 3, 3]
ones = torch.ones_like(coord[:, :, :, 0:1])  # [batch, 28, 28, 1]
coord = torch.concat([coord, ones], dim=-1)  # [batch, 28, 28, 3]
coord = torch.einsum("bhwj, bji -> bhwi", coord, affine)
coord = coord[:, :, :, 0:2]  # [batch, 28, 28, 2]

latent_ = latent[:, :-6]  # [batch, dim-6]
latent_ = latent_[:, :, None, None]  # [batch, dim-6, 1, 1]
latent_ = torch.permute(latent_, (0, 2, 3, 1))  # [batch, 1, 1, dim-6]
latent_ = latent_.repeat(1, h_size, w_size, 1)  # [batch, 28, 28, dim-6]

outputs = torch.concat([coord, latent_], dim=-1)  # [batch, 28, 28, dim-4]

In [96]:
outputs = torch.reshape(outputs, (-1, outputs.shape[-1]))

In [97]:
outputs.shape  # [batch * 28 * 28, dim-4] = [100352, dim-4]

torch.Size([6272, 20])