In [1]:
import torch

Encoder and decoders are both MLPs.

#### Decoder

\begin{align*}
    \log p_{\theta}\left(x\,\middle|\, z\right) & = \log N\left(x; \mu_z, \sigma_z^2I\right) \\
        & \text{where} \\
        h & = \tanh\left(W_1 z + b_1\right) \\
        \mu_z & = W_2h + b_2 \\
        \log \sigma_z^2 & = W_3 h + b_3 \\
\end{align*}

Thus $\theta = \left\{W_{1:3}, b_{1:3}\right\}$.

#### Encoder

\begin{align*}
    \log q_{\phi}\left(z\,\middle|\, x\right) & = \log N\left(z; \mu_x, \sigma_x^2I\right) \\
        & \text{where} \\
        k & = \tanh\left(W_4 z + b_4\right) \\
        \mu_x & = W_5k + b_5 \\
        \log \sigma_x^2 & = W_6 h + b_6 \\
\end{align*}

Thus $\phi = \left\{W_{4:6}, b_{4:6}\right\}$.

In [13]:
class EncoderNetwork(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features=in_features, out_features=20)
        self.linear2 = torch.nn.Linear(in_features=20, out_features=out_features) # mean
        self.linear3 = torch.nn.Linear(in_features=20, out_features=out_features) # sigma2

    def forward(self, x):
        k = torch.tanh(self.linear1(x))
        mu = self.linear2(k)
        log_sigma2 = self.linear3(k)

        return torch.tensor([mu, log_sigma2])


class DecoderNetwork(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features=in_features, out_features=20)
        self.linear2 = torch.nn.Linear(in_features=20, out_features=out_features) # mean
        self.linear3 = torch.nn.Linear(in_features=20, out_features=out_features) # sigma2

    def forward(self, x):
        k = torch.tanh(self.linear1(x))
        mu = self.linear2(k)
        log_sigma2 = self.linear3(k)

        return torch.tensor([mu, log_sigma2])

In [14]:
enc = EncoderNetwork(in_features=1, out_features=1)

x = torch.tensor([[3.0]])
enc(x)

tensor([-0.2988,  0.0646])

Fit this model to the MNIST data.

In [19]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np

In [18]:
# download the data
training_data = torchvision.datasets.MNIST(
    root='./data', # where to save it
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

test_data = torchvision.datasets.MNIST(
    root='./data',
    train=False, # create test data
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

train_dl = torch.utils.data.DataLoader(training_data, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)

In [42]:
class EncoderNetwork(torch.nn.Module):
    def __init__(self, data_dim: int, latent_dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features=data_dim, out_features=20)
        self.linear2 = torch.nn.Linear(in_features=20, out_features=latent_dim) # mean
        self.linear3 = torch.nn.Linear(in_features=20, out_features=latent_dim) # sigma2

    def forward(self, x):
        x = torch.flatten(x, 1)
        k = torch.tanh(self.linear1(x))
        mu = self.linear2(k)
        log_sigma2 = self.linear3(k)

        out = torch.concat([mu, log_sigma2], axis=1)

        print(out)
        return out


class DecoderNetwork(torch.nn.Module):
    def __init__(self, latent_dim: int, data_dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features=latent_dim, out_features=20)
        self.linear2 = torch.nn.Linear(in_features=20, out_features=data_dim) # mean
        self.linear3 = torch.nn.Linear(in_features=20, out_features=data_dim) # sigma2

    def forward(self, x):
        k = torch.tanh(self.linear1(x))
        mu = self.linear2(k)
        log_sigma2 = self.linear3(k)
        out = torch.concat([mu, log_sigma2], axis=1)

        return out


class VAE(torch.nn.Module):
    def __init__(self, data_dim: int, latent_dim: int):
        super().__init__()
        self.encoder = EncoderNetwork(data_dim=data_dim, latent_dim=latent_dim)
        self.decoder = DecoderNetwork(data_dim=data_dim, latent_dim=latent_dim)

    def forward(self, x):
        phi = self.encoder(x) # mu, sigma
        mu = phi[:, :self.latent_dim]
        sigma = torch.exp(phi[:, self.latent_dim:])

        # generate random noise variable
        n_eps = phi.shape[1] // 2
        eps = torch.randn(self.latent_dim)
        z = mu + eps * sigma

        return z



vae = VAE(data_dim=28*28, latent_dim=2)

In [41]:
enc = EncoderNetwork(28*28, 1)

for batch in test_dl:
    # print(batch[0].shape)
    enc(batch[0])

tensor([[-0.0648, -0.1620],
        [-0.0646, -0.1727],
        [-0.1493, -0.0748],
        [-0.1605,  0.0052],
        [-0.1472,  0.1409],
        [-0.2330,  0.1807],
        [-0.1487,  0.1362],
        [-0.1266,  0.0065],
        [-0.0931,  0.0740],
        [-0.0215, -0.0582],
        [-0.2338,  0.0517],
        [-0.2458,  0.2126],
        [-0.2033,  0.0763],
        [-0.1341,  0.0424],
        [-0.1394,  0.1522],
        [-0.2388,  0.0768],
        [-0.1911,  0.1474],
        [-0.1092,  0.0641],
        [-0.1832,  0.1572],
        [-0.1382,  0.0862],
        [-0.0339, -0.1871],
        [-0.1110,  0.0141],
        [-0.1480,  0.3610],
        [-0.0407, -0.0481],
        [-0.1741,  0.2677],
        [-0.2559,  0.1464],
        [-0.1891,  0.1998],
        [-0.1611,  0.1766],
        [-0.0936, -0.1710],
        [-0.1442,  0.1974],
        [-0.1955,  0.1206],
        [-0.1525,  0.2327],
        [-0.1218,  0.0917],
        [-0.1052,  0.0766],
        [-0.1263,  0.2857],
        [-0.2289,  0

In [61]:
class VAEEncoder(torch.nn.Module):
    def __init__(self, data_shape: int, latent_dim: int, model):
        super().__init__()
        self.latent_dim = latent_dim
        self.data_shape = data_shape
        self.model = model
        self.output_size = self.get_output_size()

        # Output layers: output distribution params
        self.linear_mu = torch.nn.Linear(
            in_features=self.output_size,
            out_features=latent_dim,
            )
        self.linear_var = torch.nn.Linear(
            in_features=self.output_size,
            out_features=latent_dim,
            )

    def get_output_size(self):
        device = next(self.model.parameters()).device.type
        size = self.model(torch.zeros(1, *self.data_shape, device=device)).size(1)
        return size

    def forward(self, x):
        base_out = self.model(x)

        # This was standard AE so far. Next is the variational part.
        # Push outputs of base model through the MLPs to get
        # the mean and log_variance of the distribution.
        # (record params for use in loss calculation)
        self.mu = self.linear_mu(base_out)
        self.log_var = self.linear_var(base_out)
        std = torch.exp(self.log_var/2)

        # Generate random noise inputs
        eps = torch.randn_like(self.mu)

        # Calculate latent vars
        z = self.mu + eps * std

        return z


    def kl_loss(self):
        kl_loss = -0.5 * (1 + self.log_var - self.mu**2 - torch.exp(self.log_var))
        return kl_loss


In [62]:
latent_dim = 1
input_shape = (1, 28, 28)

base_model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(np.prod(input_shape), 2048),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(2048, 2048),
    torch.nn.LeakyReLU(),
)

var_enc = VAEEncoder(input_shape, 1, base_model)

var_dec = torch.nn.Sequential(
    torch.nn.Linear(latent_dim, 2048),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(2048, 2048),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(2048, np.prod(input_shape)),
    torch.nn.Unflatten(1, input_shape),
)


class AutoEncoder(torch.nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Encode and then reconstruct the input."""
        z = self.encoder(x)
        return self.decoder(z)


model_vae = AutoEncoder(var_enc, var_dec)

In [63]:
for batch in test_dl:
    x = batch[0]
    # z = var_enc(x)
    # x_hat = var_dec(z)
    x_hat = model_vae(x)

In [64]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_vae.to(device)
loss_fn = torch.nn.MSELoss(reduction='none')
optim = torch.optim.Adam(model_vae.parameters(), 0.0003)

num_epochs = 30

train_losses = []

reconstruction_loss_factor = 1

for epoch in range(1, num_epochs+1):
    batch_losses = []
    for i, (x, _) in enumerate(train_dl):
        model_vae.train()
        x = x.to(device)

        # Step 1 - Computes our model's predicted output - forward pass
        yhat = model_vae(x)

        # Step 2 - Computes the loss
        # reduce (sum) over pixels (dim=[1, 2, 3])
        # and then reduce (sum) over batch (dim=0)
        loss = loss_fn(yhat, x).sum(dim=[1, 2, 3]).sum(dim=0)
        # reduce (sum) over z (dim=1)
        # and then reduce (sum) over batch (dim=0)
        kl_loss = model_vae.encoder.kl_loss().sum(dim=1).sum(dim=0)
        # we're adding the KL loss to the original MSE loss
        total_loss = reconstruction_loss_factor * loss + kl_loss

        # Step 3 - Computes gradients
        total_loss.backward()
        # Step 4 - Updates parameters using gradients and the learning rate
        optim.step()
        optim.zero_grad()
        
        batch_losses.append(np.array([total_loss.data.item(), 
                                      loss.data.item(), 
                                      kl_loss.data.item()]))

    # Average over batches
    train_losses.append(np.array(batch_losses).mean(axis=0))

    print(f'Epoch {epoch:03d} | Loss >> {train_losses[-1][0]:.4f}/ \
            {train_losses[-1][1]:.4f}/{train_losses[-1][2]:.4f}')

Epoch 001 | Loss >> 2957.8439/             2822.8459/134.9980
Epoch 002 | Loss >> 2774.3949/             2599.2895/175.1054
Epoch 003 | Loss >> 2710.8907/             2524.6728/186.2179
Epoch 004 | Loss >> 2677.4004/             2484.7674/192.6331
Epoch 005 | Loss >> 2646.7884/             2448.2837/198.5047
Epoch 006 | Loss >> 2624.1446/             2419.9611/204.1835
Epoch 007 | Loss >> 2606.4339/             2397.8940/208.5399
Epoch 008 | Loss >> 2599.0817/             2387.5249/211.5568
Epoch 009 | Loss >> 2580.4330/             2366.1935/214.2395
Epoch 010 | Loss >> 2565.7926/             2347.9003/217.8923
Epoch 011 | Loss >> 2556.3262/             2336.7809/219.5453
Epoch 012 | Loss >> 2548.8412/             2326.6746/222.1666
Epoch 013 | Loss >> 2536.5919/             2312.9841/223.6079
Epoch 014 | Loss >> 2540.7751/             2316.2083/224.5668
Epoch 015 | Loss >> 2531.5051/             2306.0886/225.4166
Epoch 016 | Loss >> 2519.2260/             2292.3914/226.8346
Epoch 01