In [80]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

In [81]:
torch.manual_seed(123)

<torch._C.Generator at 0x7fbff5809e50>

In [82]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=8)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=testset.data.shape[0], shuffle=True, num_workers=8)

In [83]:
train_data = trainset.data
train_labels = trainset.targets
test_data = testset.data
test_labels = testset.targets
print(train_data.shape)
print(train_labels.shape)
print(test_data.shape)
print(test_labels.shape)

torch.Size([60000, 28, 28])
torch.Size([60000])
torch.Size([10000, 28, 28])
torch.Size([10000])


$\mathbb{KL}\left( \mathcal{N}_{x}(\mu_{x}, \sigma_{x}) \parallel \mathcal{N}(0, 1) \right) = \sum_{x \in X} \left( \sigma_{x}^2 + \mu_{x}^2 - \log \sigma_{x} - \frac{1}{2} \right)$

In [84]:
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_size, latent_size):
        super(VariationalAutoEncoder, self).__init__()

        self.latent_size = latent_size
        self.input_size = input_size

        sub_hidden_sizes = [256, 64]
        
        self.encoder = nn.Sequential(
            nn.Flatten(start_dim=1, end_dim=3),
            nn.Linear(input_size, sub_hidden_sizes[0], bias=True),
            nn.Sigmoid(),
            nn.Linear(sub_hidden_sizes[0], sub_hidden_sizes[1], bias=True),
            nn.Sigmoid(),
            nn.Linear(sub_hidden_sizes[1], 2*latent_size, bias=True), #####
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_size, sub_hidden_sizes[1], bias=True),
            nn.Sigmoid(),
            nn.Linear(sub_hidden_sizes[1], sub_hidden_sizes[0], bias=True),
            nn.Sigmoid(),
            nn.Linear(sub_hidden_sizes[0], input_size, bias=True),
            nn.Sigmoid(),
            nn.Unflatten(dim=1, unflattened_size=(1, 28, 28))
        )

    def forward(self, input):
        ''' 
        input: tensor of size (batch_size, input_size)
        output: tensor of size (batch_size, input_size)
        '''
        encoded = self.encoder(input) # (batch_size, input_size) => (batch_size, 2*latent_size)
        assert torch.equal(torch.tensor(encoded.shape), torch.tensor([input.shape[0], 2*self.latent_size]))

        mean = encoded[:, :self.latent_size]
        std_dev = torch.exp(encoded[:, self.latent_size:])
        assert torch.equal(torch.tensor(mean.shape), torch.tensor([input.shape[0], self.latent_size]))
        assert torch.equal(torch.tensor(mean.shape), torch.tensor(std_dev.shape))

        encoded_sampled = mean + std_dev * torch.empty_like(mean).normal_(mean=0,std=1)
        assert torch.equal(torch.tensor(mean.shape), torch.tensor(encoded_sampled.shape))

        decoded = self.decoder(encoded_sampled)

        self.kl_div = torch.sum(std_dev**2 + mean**2 - torch.log(std_dev) - 1/2)

        return decoded

In [85]:
model = VariationalAutoEncoder(input_size=784, latent_size=32)
n_epochs = 10

criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

training_losses = []
test_losses = []

for epoch in range(n_epochs):
        training_loss = 0
        model.train()
        for step, batch in enumerate(tqdm(trainloader, desc=f"Epoch: {epoch+1}/{n_epochs}", leave=True, unit="batch")):
                input = batch[0]
                # print(input.shape, input)
                input /= input.max()
                assert torch.equal(torch.tensor(input.shape[1:]), torch.tensor([1, 28, 28]))
                # assert input.max().item() == 255
                optimizer.zero_grad()
                output = model(input)
                assert torch.equal(torch.tensor(output.shape), torch.tensor([input.shape[0], 1, 28, 28]))
                loss = criterion(output, input) + model.kl_div
                loss.backward()
                optimizer.step()
                training_loss += loss.item()

        model.eval()
        with torch.no_grad():
                testing_loss = 0
                for batch in testloader:
                        input /= input.max()
                        assert torch.equal(torch.tensor(input.shape[1:]), torch.tensor([1, 28, 28]))
                        output = model(input)
                        assert torch.equal(torch.tensor(output.shape), torch.tensor([input.shape[0], 1, 28, 28]))
                        loss = criterion(output, input) + model.kl_div
                        testing_loss += loss.item()

        training_losses.append(training_loss)
        test_losses.append(testing_loss)
                
        print(f'=====> Average Training, Testing Loss at epoch:{epoch+1} = {training_loss}, {testing_loss}\n')

Epoch: 1/10: 100%|██████████| 938/938 [00:05<00:00, 160.61batch/s]


=====> Average Training, Testing Loss at epoch:1 = 667335.6408081055, 355.1553649902344



Epoch: 2/10: 100%|██████████| 938/938 [00:05<00:00, 156.48batch/s]


=====> Average Training, Testing Loss at epoch:2 = 665668.6543884277, 355.160888671875



Epoch: 3/10: 100%|██████████| 938/938 [00:05<00:00, 163.86batch/s]


=====> Average Training, Testing Loss at epoch:3 = 665668.6097106934, 355.1541442871094



Epoch: 4/10: 100%|██████████| 938/938 [00:05<00:00, 162.05batch/s]


=====> Average Training, Testing Loss at epoch:4 = 665669.1684570312, 355.1418151855469



Epoch: 5/10:   9%|▉         | 86/938 [00:00<00:07, 118.57batch/s]


KeyboardInterrupt: 