In [30]:
import torch
import torchvision
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [31]:
transform = torchvision.transforms.ToTensor()

train_dataset = torchvision.datasets.MNIST(root="./data",
                                           train=True,
                                           download=True, transform=transform)

test_dataset = torchvision.datasets.MNIST(root="./data",
                                           train=False,
                                           download=True, transform=transform)

In [32]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=32,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=32,
                                          shuffle=False)

In [39]:
class SimpleAE(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.hidden = 8

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(784, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(256, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, self.hidden),
            torch.nn.ReLU(inplace=True)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(self.hidden, 64),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(64, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(256, 784),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        encode = self.encoder(x)
        decode = self.decoder(x)
        return encode, decode

In [40]:
# We will create the derived class where we will add our mu, sigma

class VariationalAE(SimpleAE):
    def __init__(self):
        super().__init__()
        # For our reparametrization trick
        self.mu = torch.nn.Linear(self.hidden, self.hidden)
        self.sigma = torch.nn.Linear(self.hidden, self.hidden)


    def reparametrize(self, mu, sigma):
        std = torch.exp(0.5*sigma)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        encoded = self.encoder(x)
        mu = self.mu(encoded)
        sigma = self.sigma(encoded)

        eta = self.reparametrize(mu, sigma)
        decoded = self.decoder(eta)
        return encoded, decoded, mu, sigma

In [34]:
def loss(x_p, x, mu, sigma):
    bce = torch.nn.functional.binary_cross_entropy(x_p, x.view(-1, 784), reduction="sum")
    # KLD is the Kullback-Leibler divergence between the latent variables and the standard Gaussian
    kld = -0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())
    return bce + kld

In [45]:
var_ae = VariationalAE().to(device)
optimizer = torch.optim.Adam(var_ae.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = torch.nn.MSELoss(reduction="sum")
num_epochs = 10

# Because of the extra added KLD term, the error is higher than usual.
for epoch in range(num_epochs):
    total_loss = 0.0
    for i, data in enumerate(train_loader):
        images, _ = data
        images = images.to(device)
        images = images.view(images.size(0), -1)
        encoded, decoded, mu, sigma = var_ae(images)
        kld = -0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())
        loss = criterion(decoded, images) + kld
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
    e_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch + 1} out of {num_epochs}, Loss: {e_loss}")

Epoch 1 out of 10, Loss: 1466.954576171875
Epoch 2 out of 10, Loss: 1161.0882458007814
Epoch 3 out of 10, Loss: 1101.2987404622395
Epoch 4 out of 10, Loss: 1070.334257747396
Epoch 5 out of 10, Loss: 1051.8157483072916
Epoch 6 out of 10, Loss: 1039.3870731770833
Epoch 7 out of 10, Loss: 1028.581184407552
Epoch 8 out of 10, Loss: 1021.4821490885416
Epoch 9 out of 10, Loss: 1015.2931165039063
Epoch 10 out of 10, Loss: 1008.6276
