In [25]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [26]:
# Hyper param
image_size = 784 # 28x28 - MNIST
hidden_dim = 400
latent_dim = 20 # size of the latent code; the encoder predicts a mean AND a log-variance of this size
batch_size = 128
epochs = 10
seed = 0 # torch's global RNG drives weight init, loader shuffling and the VAE's sampling noise

# Seed once, before the model is built, so Restart & Run All reproduces the same results.
torch.manual_seed(seed)

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root="./mnist_dataset",
                               train = True,
                               transform=transforms.ToTensor(), # convert each PIL image to a float tensor scaled to [0, 1]
                               download=True)

test_dataset = torchvision.datasets.MNIST(root="./mnist_dataset",
                              train=False,
                              transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset= train_dataset,
                                           batch_size = batch_size,
                                           shuffle=True)

# Evaluation does not need shuffling; a fixed order also makes the first test batch
# (used for the per-epoch reconstruction grid) the same digits every epoch.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


In [27]:
# Directory to save reconstructed and generated images
sample_dir = './14.results'
# exist_ok=True makes this a no-op when the directory is already there and avoids
# the race between an exists() check and the creation call.
os.makedirs(sample_dir, exist_ok=True)

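Network
Encoder: 784 -> 400 -> {20, 20} — two heads predicting the latent mean $\mu$ and log-variance $\log\sigma^2$
Reparameterization: sample $\epsilon \sim \mathcal{N}(0, I)$ from a unit Gaussian and multiply it element-wise by the std. dev. $\sigma = \exp(\tfrac{1}{2}\log\sigma^2)$,
then add the mean: $z = \mu + \sigma \odot \epsilon$
Resulting latent dim -> 20
Decoder: 20 -> 400 -> 784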

In [28]:
#VAE Model
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened images.

    Encoder: input -> hidden (ReLU) -> two linear heads giving the latent mean
    ``mu`` and log-variance ``logvar``.
    Decoder: latent -> hidden (ReLU) -> input (sigmoid, so outputs lie in [0, 1]).

    Args:
        input_size: features per flattened input sample. ``None`` uses the
            notebook-level ``image_size`` (784 = 28x28 for MNIST).
        hidden_size: width of both hidden layers. ``None`` uses ``hidden_dim``.
        latent_size: size of the latent code ``z``. ``None`` uses ``latent_dim``.
    """

    def __init__(self, input_size=None, hidden_size=None, latent_size=None):
        super().__init__()
        # Fall back to the hyper-parameter cell so `VAE()` keeps its old meaning.
        # Each global is only read when the corresponding argument is omitted.
        input_size = image_size if input_size is None else input_size
        hidden_size = hidden_dim if hidden_size is None else hidden_size
        latent_size = latent_dim if latent_size is None else latent_size

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2_mean = nn.Linear(hidden_size, latent_size)
        self.fc2_logvar = nn.Linear(hidden_size, latent_size)
        self.fc3 = nn.Linear(latent_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        """Map flattened inputs ``x`` (batch, input_size) to ``(mu, logvar)``, each (batch, latent_size)."""
        h = F.relu(self.fc1(x))
        mu = self.fc2_mean(h)
        logvar = self.fc2_logvar(h)
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick).

        ``logvar`` is log(sigma^2), so the std. dev. is
        sigma = sqrt(exp(logvar)) = exp(logvar / 2) — hence the division by 2.
        Predicting a log-variance lets the network output any real number while
        ``exp`` keeps sigma strictly positive. Because the randomness lives in
        ``eps``, ``z`` stays differentiable w.r.t. ``mu`` and ``logvar``.
        """
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)  # same shape/dtype/device as std for the element-wise product
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` (batch, latent_size) to pixel values (batch, input_size)."""
        h = F.relu(self.fc3(z))
        # Sigmoid bounds the output to [0, 1]: the range of ToTensor-scaled pixels and
        # the input range F.binary_cross_entropy requires.
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        """Encode, sample and decode a batch.

        ``x`` may be any tensor whose per-sample elements flatten to ``input_size``,
        e.g. an MNIST batch (128, 1, 28, 28) is viewed as (128, 784).

        Returns:
            ``(reconstructed, mu, logvar)`` with shapes (batch, input_size),
            (batch, latent_size) and (batch, latent_size).
        """
        # Use the layer's own input width instead of the global `image_size`.
        mu, logvar = self.encode(x.view(-1, self.fc1.in_features))
        z = self.reparametrize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar
    


In [29]:
# Model and Optimizer: the VAE's layer sizes come from the hyper-parameter cell,
# and Adam updates every model parameter with a fixed learning rate of 1e-3.
model = VAE()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [30]:
# Define loss
def loss_function(reconstructed_image, original_image, mu, logvar, beta=1.0):
    """Negative ELBO of a VAE with Bernoulli pixels and a Gaussian latent, summed over the batch.

    Args:
        reconstructed_image: decoder output in [0, 1], shape (batch, n_pixels).
        original_image: target images with the same number of elements as
            ``reconstructed_image`` (e.g. (batch, 1, 28, 28)); reshaped to match it.
        mu: latent means, shape (batch, latent).
        logvar: latent log-variances log(sigma^2), shape (batch, latent).
        beta: weight on the KL term. 1.0 (default) gives the standard VAE loss.

    Returns:
        Scalar tensor ``BCE + beta * KL``, with both terms summed over every
        element of the batch. Divide by the batch size for a per-sample loss.
    """
    # Reconstruction term: pixel-wise BCE. reduction='sum' keeps it on the same
    # scale as the KL term below, which is also summed over batch and latent dims.
    bce = F.binary_cross_entropy(
        reconstructed_image,
        original_image.reshape(reconstructed_image.shape),
        reduction='sum',
    )
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2),
    # summed over both the latent dimensions and the batch.
    kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    return bce + beta * kld

def train(epoch, log_interval=100):
    """Run one optimization epoch over ``train_loader`` and return the mean per-sample loss.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``loss_function``.

    Args:
        epoch: epoch number, used only in the log lines.
        log_interval: print the current per-sample loss every this many batches;
            0 or ``None`` disables the per-batch log.

    Returns:
        Average per-sample training loss over the whole dataset for this epoch.
    """
    model.train()  # puts dropout/batch-norm layers (if any) into training mode

    train_loss = 0.0
    for i, (images, _) in enumerate(train_loader):  # labels are unused: the VAE is unsupervised
        reconstructed, mu, logvar = model(images)
        # loss_function reshapes `images` itself, so no flattening is needed here
        loss = loss_function(reconstructed, images, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        if log_interval and i % log_interval == 0:
            # loss is summed over the batch; divide by the batch size for a per-sample value
            print(f"{epoch = }, [Batch {i}/{len(train_loader)}] \t Loss: {loss.item()/len(images)}")

    avg_loss = train_loss / len(train_loader.dataset)
    print(f"======> {epoch = }, Average loss: {avg_loss}")
    return avg_loss

def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and return the mean per-sample test loss.

    For the first batch, saves a grid (5 per row) of up to 5 originals followed by
    their reconstructions to ``<sample_dir>/reconstruction_<epoch>.png``.
    Leaves the model in eval mode.

    Args:
        epoch: epoch number, used in the output file name and the log line.

    Returns:
        Average per-sample test loss over the whole test set.
    """
    model.eval()  # puts dropout/batch-norm layers (if any) into inference mode

    test_loss = 0.0
    with torch.no_grad():  # evaluation only: skip building the autograd graph
        for batch_idx, (images, _) in enumerate(test_loader):
            reconstructed, mu, logvar = model(images)
            # loss_function reshapes `images` itself, so no flattening is needed here
            loss = loss_function(reconstructed, images, mu, logvar)
            test_loss += loss.item()
            if batch_idx == 0:
                # view_as(images) follows the actual batch size (and image shape) instead of
                # the fixed `batch_size`, so a short first batch cannot break the reshape.
                comparison = torch.cat([images[:5], reconstructed.view_as(images)[:5]])
                save_image(comparison, os.path.join(sample_dir, f'reconstruction_{epoch}.png'), nrow=5)

    avg_loss = test_loss / len(test_loader.dataset)
    print(f"======> {epoch = }, Test average loss: {avg_loss}")
    return avg_loss


In [31]:
# Main function
n_generated = 64  # images sampled from the prior each epoch (save_image lays them out 8 per row)
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)  # also switches the model to eval mode for the sampling below

    with torch.no_grad():
        # Generate new images: skip the encoder, draw z ~ N(0, I) from the prior and
        # decode it. Use latent_dim (not a literal) so z always matches the decoder input.
        sample = torch.randn(n_generated, latent_dim)
        generated = model.decode(sample)
        save_image(generated.view(-1, 1, 28, 28), os.path.join(sample_dir, f'sample_{epoch}.png'))

epoch = 1, [Batch 0/469] 	 Loss: 549.6871337890625
epoch = 1, [Batch 100/469] 	 Loss: 185.02951049804688
epoch = 1, [Batch 200/469] 	 Loss: 161.3351287841797
epoch = 1, [Batch 300/469] 	 Loss: 131.08970642089844
epoch = 1, [Batch 400/469] 	 Loss: 133.47451782226562
epoch = 2, [Batch 0/469] 	 Loss: 129.60353088378906
epoch = 2, [Batch 100/469] 	 Loss: 122.0581283569336
epoch = 2, [Batch 200/469] 	 Loss: 122.07220458984375
epoch = 2, [Batch 300/469] 	 Loss: 123.73389434814453
epoch = 2, [Batch 400/469] 	 Loss: 120.58039855957031
epoch = 3, [Batch 0/469] 	 Loss: 112.46702575683594
epoch = 3, [Batch 100/469] 	 Loss: 112.49610900878906
epoch = 3, [Batch 200/469] 	 Loss: 111.21714782714844
epoch = 3, [Batch 300/469] 	 Loss: 117.95367431640625
epoch = 3, [Batch 400/469] 	 Loss: 117.32677459716797
epoch = 4, [Batch 0/469] 	 Loss: 117.24092102050781
epoch = 4, [Batch 100/469] 	 Loss: 112.57077026367188
epoch = 4, [Batch 200/469] 	 Loss: 113.23818969726562
epoch = 4, [Batch 300/469] 	 Loss: 112.