In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
# Define hyperparameters:
image_size = 784  # Flattened size of MNIST images (28x28)
hidden_dim = 400  # Dimension of hidden layer in VAE
latent_dim = 20   # Dimension of latent space
batch_size = 128  # Batch size for training and testing
epochs = 10       # Number of training epochs
seed = 0          # Random seed: weight init, shuffling and latent sampling become reproducible
data_root = '/data'  # MNIST download/cache directory (absolute path; consider a project-relative dir)

# Seed torch's RNGs so Restart & Run All reproduces the same losses and images.
torch.manual_seed(seed)

# Set device to GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load and transform MNIST dataset for training (ToTensor scales pixels to [0, 1],
# which the sigmoid decoder + binary cross-entropy loss require):
train_dataset = torchvision.datasets.MNIST(root=data_root,
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

# DataLoader for the training dataset (shuffled each epoch for SGD):
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# Load and transform MNIST dataset for testing. download=True so this cell does not
# depend on the training download having already fetched the test files.
test_dataset = torchvision.datasets.MNIST(root=data_root,
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# DataLoader for the test dataset. No shuffling: the summed test loss is order-independent,
# and a fixed order makes reconstruction_<epoch>.png show the same digits every epoch,
# so epochs can be compared visually.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Directory for reconstructed and sampled images (created if missing; no-op if present):
sample_dir = 'results'
os.makedirs(sample_dir, exist_ok=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████████████████████████████████████████████████████████████| 9912422/9912422 [00:00<00:00, 19547397.33it/s]


Extracting /data\MNIST\raw\train-images-idx3-ubyte.gz to /data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 28924473.21it/s]


Extracting /data\MNIST\raw\train-labels-idx1-ubyte.gz to /data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 14363609.81it/s]


Extracting /data\MNIST\raw\t10k-images-idx3-ubyte.gz to /data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 4506867.46it/s]

Extracting /data\MNIST\raw\t10k-labels-idx1-ubyte.gz to /data\MNIST\raw






![vae](https://user-images.githubusercontent.com/30661597/78418103-a2047200-766b-11ea-8205-c7e5712715f4.png)

In [3]:
# Define the VAE class:
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened images.

    Encoder: input -> hidden (ReLU) -> (mu, log_var) of a diagonal Gaussian over
    the latent space. Decoder: latent -> hidden (ReLU) -> sigmoid, so every
    reconstructed value lies in [0, 1].

    Args:
        input_dim: Number of elements in one flattened input image. Defaults to
            the config-cell ``image_size`` when omitted.
        hidden_size: Width of the hidden layer. Defaults to ``hidden_dim``.
        latent_size: Dimension of the latent space. Defaults to ``latent_dim``.
    """

    def __init__(self, input_dim=None, hidden_size=None, latent_size=None):
        super().__init__()
        # Fall back to the notebook's config values so `VAE()` behaves exactly as before.
        self.input_dim = image_size if input_dim is None else input_dim
        hidden_size = hidden_dim if hidden_size is None else hidden_size
        self.latent_size = latent_dim if latent_size is None else latent_size

        # Fully connected layers for the encoder:
        self.fc1 = nn.Linear(self.input_dim, hidden_size)          # Input to hidden layer
        self.fc2_mean = nn.Linear(hidden_size, self.latent_size)    # Hidden to latent mean
        self.fc2_logvar = nn.Linear(hidden_size, self.latent_size)  # Hidden to latent log variance

        # Fully connected layers for the decoder:
        self.fc3 = nn.Linear(self.latent_size, hidden_size)  # Latent to hidden layer
        self.fc4 = nn.Linear(hidden_size, self.input_dim)    # Hidden to output layer

    def encode(self, x):
        """Map flattened inputs ``(batch, input_dim)`` to latent ``(mu, log_var)``."""
        h = F.relu(self.fc1(x))
        mu = self.fc2_mean(h)
        log_var = self.fc2_logvar(h)
        return mu, log_var

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * sigma with eps ~ N(0, I), keeping z differentiable in mu/logvar."""
        std = torch.exp(logvar / 2)  # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``(batch, latent_size)`` to pixel values in [0, 1]."""
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: Image batch with ``input_dim`` elements per sample, e.g.
               ``(batch, 1, 28, 28)`` for MNIST; flattened to ``(batch, input_dim)``.

        Returns:
            Tuple ``(reconstructed, mu, logvar)`` where ``reconstructed`` is
            ``(batch, input_dim)`` and ``mu``/``logvar`` are ``(batch, latent_size)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed images, also work.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar

# Build the VAE from the config-cell dimensions and place its parameters on `device`:
model = VAE().to(device=device)

# Adam over every trainable parameter of the model, learning rate 0.001:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

$\text{Loss} = -\mathbb{E}_{z \sim q(z \mid X)}\left[\log P(X \mid z)\right] + D_{KL}\left[\mathcal{N}(\mu(X), \Sigma(X)) \,\|\, \mathcal{N}(0, I)\right]$

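#### $D_{KL}\left[\mathcal{N}(\mu(X), \Sigma(X)) \,\|\, \mathcal{N}(0, I)\right]=\frac{1}{2} \sum_{k}\left(\exp\left(\log\sigma_k^{2}(X)\right)+\mu_k^{2}(X)-1-\log\sigma_k^{2}(X)\right)$, where $\Sigma(X)=\operatorname{diag}\left(\sigma^{2}(X)\right)$ and the encoder outputs $\log\sigma^{2}(X)$ (`logvar`).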

In [4]:
# Define the VAE loss function:
def loss_function(reconstructed_image, original_image, mu, logvar):
    """Negative ELBO for a VAE with a Bernoulli decoder, summed over the batch.

    Args:
        reconstructed_image: Decoder output of shape ``(batch, D)`` with values in [0, 1].
        original_image: Target batch with ``batch * D`` elements (e.g. ``(batch, 1, 28, 28)``),
            values in [0, 1]; it is reshaped to match ``reconstructed_image``.
        mu: Latent means, ``(batch, latent)``.
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: summed binary cross-entropy plus summed KL divergence.
    """
    # Match the target to the decoder's layout instead of hard-coding 784. reshape (not view)
    # also accepts non-contiguous targets.
    target = original_image.reshape(reconstructed_image.shape)

    # Reconstruction term: binary cross-entropy summed over all pixels and the batch.
    bce = F.binary_cross_entropy(reconstructed_image, target, reduction='sum')

    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims and the batch.
    kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)

    return bce + kld
    

# Train the VAE
def train(epoch):
    """Run one optimization epoch over ``train_loader``.

    Every 100th batch prints that batch's loss divided by its sample count; at the
    end, prints the epoch's summed loss divided by the number of training samples.

    Args:
        epoch: Epoch number; used only in the printed messages.
    """
    model.train()  # Training mode for the model
    running_total = 0

    for batch_idx, (batch, _labels) in enumerate(train_loader):
        batch = batch.to(device)

        # Standard step: clear old gradients, forward, loss, backprop, update.
        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        batch_loss = loss_function(recon, batch, mu, logvar)
        batch_loss.backward()
        optimizer.step()

        running_total += batch_loss.item()

        if batch_idx % 100 == 0:
            per_sample = batch_loss.item() / len(batch)
            print("Train Epoch {} [Batch {}/{}]\tLoss: {:.3f}".format(epoch, batch_idx, len(train_loader), per_sample))

    average = running_total / len(train_loader.dataset)
    print('=====> Epoch {}, Average Loss: {:.3f}'.format(epoch, average))


# Test the VAE:
def test(epoch):
    """Evaluate the model on ``test_loader`` and save a reconstruction grid.

    Accumulates the summed loss over the test set with gradients disabled and
    prints it divided by the number of test samples. For the first batch, saves
    up to 5 original images (top row) above their reconstructions (bottom row)
    to ``<sample_dir>/reconstruction_<epoch>.png``.

    Args:
        epoch: Epoch number; used in the output filename.
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0

    # Disable gradient calculations
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(test_loader):
            images = images.to(device)  # Move images to the device
            reconstructed, mu, logvar = model(images)  # Forward pass
            test_loss += loss_function(reconstructed, images, mu, logvar).item()  # Accumulate loss

            if batch_idx == 0:
                # Reshape with the actual batch (view_as) rather than the global batch_size,
                # so a first batch shorter than batch_size does not crash.
                n_show = min(5, images.size(0))
                comparison = torch.cat([images[:n_show], reconstructed.view_as(images)[:n_show]])
                save_image(comparison.cpu(),
                           os.path.join(sample_dir, 'reconstruction_{}.png'.format(epoch)),
                           nrow=n_show)

    # Print average loss for the test set:
    print('=====> Average Test Loss: {:.3f}'.format(test_loss / len(test_loader.dataset)))

In [5]:
# Main training and testing loop
num_samples = 64  # Number of images generated from the prior after each epoch

for epoch in range(1, epochs + 1):
    train(epoch)  # One pass over the training set
    test(epoch)   # Evaluate and save reconstructions for this epoch

    # Generate and save new images after each epoch by decoding z ~ N(0, I).
    with torch.no_grad():
        # Use latent_dim from the config cell instead of a hard-coded 20, so changing the
        # latent size does not break sampling.
        sample = torch.randn(num_samples, latent_dim, device=device)

        # Decode the latent vectors to pixel values in [0, 1].
        generated = model.decode(sample).cpu()

        # Save the generated images as a grid into sample_dir.
        save_image(generated.view(num_samples, 1, 28, 28),
                   os.path.join(sample_dir, 'sample_{}.png'.format(epoch)))

Train Epoch 1 [Batch 0/469]	Loss: 548.362
Train Epoch 1 [Batch 100/469]	Loss: 184.418
Train Epoch 1 [Batch 200/469]	Loss: 150.321
Train Epoch 1 [Batch 300/469]	Loss: 143.018
Train Epoch 1 [Batch 400/469]	Loss: 129.501
=====> Epoch 1, Average Loss: 165.459
=====> Average Test Loss: 128.130
Train Epoch 2 [Batch 0/469]	Loss: 128.493
Train Epoch 2 [Batch 100/469]	Loss: 126.065
Train Epoch 2 [Batch 200/469]	Loss: 115.037
Train Epoch 2 [Batch 300/469]	Loss: 119.969
Train Epoch 2 [Batch 400/469]	Loss: 115.301
=====> Epoch 2, Average Loss: 122.249
=====> Average Test Loss: 116.300
Train Epoch 3 [Batch 0/469]	Loss: 116.082
Train Epoch 3 [Batch 100/469]	Loss: 112.912
Train Epoch 3 [Batch 200/469]	Loss: 117.561
Train Epoch 3 [Batch 300/469]	Loss: 110.648
Train Epoch 3 [Batch 400/469]	Loss: 114.934
=====> Epoch 3, Average Loss: 114.982
=====> Average Test Loss: 112.051
Train Epoch 4 [Batch 0/469]	Loss: 116.126
Train Epoch 4 [Batch 100/469]	Loss: 115.840
Train Epoch 4 [Batch 200/469]	Loss: 115.666
