In [None]:
import numpy as np
import pandas as pd

In [None]:
data = pd.read_csv("./data/A_Z Handwritten Data.csv")
# Column "0" is taken as the target; every remaining column is treated as a feature.
# NOTE(review): labels are cast to float32, but nn.CrossEntropyLoss expects integer
# class indices -- convert (e.g. to int64) before using them for classification.
label_column = "0"
y = data[label_column].astype("float32")
X = data.drop(columns=label_column)

In [None]:
from sklearn.model_selection import train_test_split

# Fixed random_state so the split is identical on every Restart & Run All.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [None]:
import pandas as pd
from torch.utils.data import Dataset


class HandwritingDataset(Dataset):
    """Map-style dataset over images stored as flattened 28x28 rows of a DataFrame.

    Args:
        X: one image per row; each row's values are reshaped (row-major) to 28x28.
        y: one label per image, matched to ``X`` by position (not by index label).
        transform: optional callable applied to each 28x28 image array.
        target_transform: optional callable applied to each label.

    Raises:
        ValueError: if ``X`` and ``y`` have different lengths.
    """

    def __init__(self, X: pd.DataFrame, y: pd.Series, transform=None, target_transform=None):
        if len(X) != len(y):
            raise ValueError(f"X has {len(X)} rows but y has {len(y)} labels")
        self.img_labels = y
        # A pandas Series has no .reshape; convert the whole frame once and reshape all
        # rows in a single vectorised call instead of iterating with iterrows().
        self.images = X.to_numpy().reshape(-1, 28, 28)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        # y is one-dimensional, so positional lookup takes a single indexer.
        label = self.img_labels.iloc[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


train_dataset = HandwritingDataset(X_train, y_train)
# Test images must come from X_test so that they line up with y_test.
test_dataset = HandwritingDataset(X_test, y_test)

In [None]:
import torchvision

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# MNIST handwritten digits (train split first, then test split).
# NOTE(review): this rebinds train_dataset / test_dataset, so the handwriting datasets
# built in the previous cell are discarded and never used below.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root='data/mnist', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

In [None]:
import torch
import torch.nn as nn

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Wrap both datasets in mini-batch loaders; both reshuffle each time they are iterated.
batch_size = 64
loader_options = {"batch_size": batch_size, "shuffle": True}
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=ds, **loader_options) for ds in (train_dataset, test_dataset)
)

In [None]:
class VAE(nn.Module):
    """Convolutional VAE for 1x28x28 images with an auxiliary 10-way classifier.

    The encoder maps an image to a 1x11x11 feature map, from which it predicts class
    probabilities and the mean / log-variance of a Gaussian latent with
    ``num_latent_var ** 2`` dimensions. The decoder rebuilds a 1x28x28 image from a
    latent sample plus a (batch, 10) class-probability vector.
    """

    _FEATURE_SIDE = 11  # encoder feature map side: 28 -> 26 -> 12 -> 11
    _NUM_CLASSES = 10

    def __init__(self, num_latent_var: int):
        super().__init__()
        self.num_latent_var = num_latent_var
        latent_dim = num_latent_var * num_latent_var
        feat_dim = self._FEATURE_SIDE * self._FEATURE_SIDE

        # 1x28x28 -> 4x26x26 -> 4x12x12 -> 1x11x11. The ReLUs keep the conv stack
        # from collapsing into a single linear map.
        self.input_encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=4, out_channels=4, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=4, out_channels=1, kernel_size=2, stride=1),
        )
        # Produces raw logits: softmax is applied in encode(), and cross_entropy in
        # loss() applies its own log-softmax, so no Softmax layer belongs here.
        self.y_encoder = nn.Sequential(
            nn.Linear(feat_dim, 128), nn.ReLU(), nn.Linear(128, self._NUM_CLASSES)
        )
        self.z_mean = nn.Linear(feat_dim, latent_dim)
        self.log_z_var = nn.Linear(feat_dim, latent_dim)

        self.y_decoder = nn.Sequential(
            nn.Linear(self._NUM_CLASSES, 128), nn.ReLU(), nn.Linear(128, feat_dim)
        )
        self.z_decoder = nn.Linear(latent_dim, feat_dim)

        # 1x11x11 -> 4x7x7 -> 4x14x14 -> 4x14x14 -> 4x28x28 -> 1x28x28, so the
        # reconstruction has the same size as the input image fed to BCE.
        self.output_decoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels=4, out_channels=4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels=4, out_channels=1, kernel_size=1, stride=1),
            nn.Sigmoid(),
        )

    def _encode_logits(self, x):
        """Return (class logits, z_mean, log_z_var) for a batch of 1x28x28 images."""
        h = self.input_encoder(x).flatten(start_dim=1)
        return self.y_encoder(h), self.z_mean(h), self.log_z_var(h)

    @staticmethod
    def _reparameterise(z_mean, log_z_var):
        """Sample z = mean + std * eps on the same device/dtype as the inputs."""
        std = torch.exp(0.5 * log_z_var)
        return z_mean + std * torch.randn_like(std)

    def encode(self, x):
        """Return (class probabilities, z_mean, log_z_var) for a batch of images."""
        y_logits, z_mean, log_z_var = self._encode_logits(x)
        return torch.softmax(y_logits, dim=1), z_mean, log_z_var

    def forward(self, x):
        """Encode, sample a latent, and decode.

        Returns:
            (x_hat, z_mean, log_z_var, y_pred), where y_pred holds class probabilities.
        """
        y_pred, z_mean, log_z_var = self.encode(x)
        z = self._reparameterise(z_mean, log_z_var)
        x_hat = self.decode(z, y_pred)
        return x_hat, z_mean, log_z_var, y_pred

    def decode(self, z, y):
        """Decode latents ``z`` (batch, num_latent_var**2) and class vectors ``y`` (batch, 10)."""
        side = self._FEATURE_SIDE
        y_decoder_input = self.y_decoder(y).reshape(-1, 1, side, side)
        z_decoder_input = self.z_decoder(z).reshape(-1, 1, side, side)
        return self.output_decoder(y_decoder_input + z_decoder_input)

    def loss(self, x, y):
        """Return (categorisation, reconstruction, KL) losses for images ``x`` and class indices ``y``.

        The categorisation and reconstruction losses are batch means; the KL term is summed
        over latent dimensions and divided by the batch size.
        """
        y_logits, z_mean, log_z_var = self._encode_logits(x)
        z = self._reparameterise(z_mean, log_z_var)
        x_hat = self.decode(z, torch.softmax(y_logits, dim=1))
        # cross_entropy applies log-softmax itself, so it must receive raw logits.
        categorisation_loss = nn.functional.cross_entropy(y_logits, y)
        reconstruction_loss = nn.functional.binary_cross_entropy(x_hat, x)
        kl_div_loss = -0.5 * torch.sum(1 + log_z_var - z_mean.pow(2) - log_z_var.exp()) / x.shape[0]
        return categorisation_loss, reconstruction_loss, kl_div_loss

In [None]:
# Seed PyTorch's global RNG (all devices) so weight init, DataLoader shuffling and the
# VAE's sampling noise repeat across Restart & Run All (up to nondeterministic CUDA kernels).
torch.manual_seed(0)
model = VAE(num_latent_var=2).to(device)
optimiser = torch.optim.Adam(model.parameters())

In [None]:
num_epochs = 10

model.train()
for epoch in range(1, num_epochs + 1):
    running_cat_loss = 0.0
    running_recons_loss = 0.0
    running_kl_loss = 0.0
    num_images = 0
    for img, label in train_loader:
        img = img.to(device)
        label = label.to(device)

        optimiser.zero_grad()
        cat_loss, recons_loss, kl_loss = model.loss(img, label)
        # Classifier gets a fixed 0.1 weight; the KL weight grows linearly (0.001 per epoch).
        loss = recons_loss + 0.1 * cat_loss + epoch * 0.001 * kl_loss
        # Gradients are computed from the loss tensor; optimisers have no backward().
        loss.backward()
        optimiser.step()

        # Per-batch losses are batch means, so weight by batch size for a per-image epoch mean.
        batch_len = len(img)
        running_cat_loss += cat_loss.item() * batch_len
        running_recons_loss += recons_loss.item() * batch_len
        running_kl_loss += kl_loss.item() * batch_len
        num_images += batch_len

    print(
        f'epoch: {epoch}'
        f' cat_loss: {running_cat_loss / num_images}'
        f' recons_loss: {running_recons_loss / num_images}'
        f' kl_loss: {running_kl_loss / num_images}'
    )

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

In [None]:
# Run the trained model on one (shuffled) test batch; inference only, so skip autograd.
model.eval()
test_images, test_labels = next(iter(test_loader))
with torch.no_grad():
    recons, z_mean, log_z_var, ysoft = model(test_images.to(device))

In [None]:
# Class indices 0..9, each repeated 6 times (digit-major), as an int64 tensor on `device`.
labels = torch.arange(10).repeat_interleave(6).to(device)

In [None]:
# For each of the 10 digits, draw a fresh reparameterised sample from the posterior of
# the first 6 test images. The result is a (60, latent_dim) batch in digit-major order,
# which lines up with `labels`.
latent_dim = z_mean.shape[1]  # num_latent_var ** 2 -- not 11 * 11 or 49
with torch.no_grad():
    std = torch.exp(0.5 * log_z_var[:6])
    zl = torch.stack([z_mean[:6] + torch.randn_like(std) * std for _ in range(10)])
zl = zl.reshape(-1, latent_dim).to(device)

In [None]:
# decode(z, y) takes the latent batch first and a (batch, 10) class vector second, so
# one-hot encode the integer labels to match the encoder's softmax output format.
with torch.no_grad():
    label_onehot = nn.functional.one_hot(labels, num_classes=10).float()
    imgs = model.decode(zl, label_onehot).cpu().reshape(60, 28, 28)

plt.gray()
fig = plt.figure(figsize=(10.0, 6.0))
grid = ImageGrid(
    fig,
    111,  # similar to subplot(111)
    nrows_ncols=(10, 6),  # one row per digit, one column per latent sample
    axes_pad=0.05,  # pad between axes in inch.
)

for ax, im in zip(grid, imgs):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

plt.show()