# VAE

To produce color images that can later be compared with the DCGAN results, the previous VAE code was modified to generate color portraits at a higher resolution. The main change is the switch from grayscale to color: the input goes from a single channel to three channels (RGB), and the encoder and decoder are resized accordingly.
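The effect of the channel change on the input size of a fully connected encoder can be checked with a small sketch. The tensors below are random stand-ins with illustrative names, not data from this notebook.

```python
import torch

# Hypothetical batches of four 28x28 images: one grayscale, one RGB
gray = torch.rand(4, 1, 28, 28)
rgb = torch.rand(4, 3, 28, 28)

# Flatten each image into a vector, as a fully connected encoder expects
gray_flat = gray.view(gray.size(0), -1)
rgb_flat = rgb.view(rgb.size(0), -1)

print(gray_flat.shape)  # torch.Size([4, 784])
print(rgb_flat.shape)   # torch.Size([4, 2352])
```

Going from one channel to three triples the flattened size, which is why the first encoder layer and the last decoder layer both have to be resized.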

#### LLM disclaimer

ChatGPT suggested several ways to convert the VAE model from grayscale to color.

The code below follows one of those suggestions.

## Import some libraries

In [1]:
import os 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision 
from torchvision import transforms 
from torchvision.utils import save_image

## Parameter configuration

#### Change the input to 28x28 color images with 3 channels (RGB)

In [2]:
# Flattened size of a 28x28 color image with 3 channels (RGB)
image_size = 3 * 28 * 28
h_dim = 400   # hidden layer width
z_dim = 60    # latent dimension
num_epochs = 300
batch_size = 128
learning_rate = 0.001

# Create the output folder for generated images
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

# Load the paintings from disk and resize them to 28x28.
# ToTensor keeps all three color channels and scales pixels to [0, 1].
dataset = torchvision.datasets.ImageFolder(
    root='painting/',
    transform=transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ]))

## Defining the VAE model

This section defines the variational autoencoder (VAE) network, which consists of an encoder and a decoder.

The encoder maps the input image to the mean and log-variance of a Gaussian distribution over the latent space, and the decoder maps a latent vector sampled from that distribution back to an image.
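The sampling step between encoder and decoder can be sketched in isolation. This is a minimal illustration of the reparameterization trick, assuming the encoder outputs a log-variance (as the `reparameterize` method in this notebook does). The `mu` and `log_var` tensors are hand-made stand-ins rather than real encoder outputs.

```python
import torch

torch.manual_seed(0)

# Stand-ins for encoder outputs: batch of 2, latent size 60
mu = torch.zeros(2, 60)
log_var = torch.zeros(2, 60)  # log-variance 0 means variance 1

# z = mu + eps * std keeps the randomness in eps, so gradients
# can still flow back to mu and log_var
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
z = mu + eps * std

print(z.shape)  # torch.Size([2, 60])
```

With a zero mean and unit variance, `z` is simply the noise `eps`; a trained encoder shifts and scales that noise per image.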

#### Modify the code to set the color channels of generated images to RGB.

In [3]:
class VAE(nn.Module):
    def __init__(self, image_size=3 * 28 * 28, h_dim=400, z_dim=60):
        super(VAE, self).__init__()
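        # Encoder: flattened image -> hidden layer -> (mu, log_var)
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)  # mean of q(z|x)
        self.fc3 = nn.Linear(h_dim, z_dim)  # log-variance of q(z|x)

        # Decoder: latent vector -> hidden layer -> flattened image
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)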

    def encoder(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std 

    def decoder(self, z):
        h = F.relu(self.fc4(z))
        # The output is a color image and the number of channels is 3
        return torch.sigmoid(self.fc5(h)).view(-1, 3, 28, 28) 

    
    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decoder(z)
        return x_reconst, mu, log_var 
    

# Select the device (GPU if available) and create the optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Training model

Training uses two nested loops: the outer loop runs over the epochs and the inner loop runs over the dataset.

At each step, the data is passed through the model (forward pass), and the reconstruction loss and the KL divergence are computed. Their sum is then backpropagated and the optimizer updates the model parameters.
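The KL term used in the training code is the closed-form divergence between a diagonal Gaussian and the standard normal prior. As a sanity check, the sketch below compares that expression with `torch.distributions.kl_divergence` on random stand-in values for `mu` and `log_var`; it is an illustration only and not part of the original notebook.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Random stand-ins for encoder outputs: batch of 4, latent size 60
mu = torch.randn(4, 60)
log_var = torch.randn(4, 60)

# Closed-form expression, written the same way as in the training loop
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Reference value: KL(q || p) with q = N(mu, sigma^2) and p = N(0, 1)
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum()

print(kl_closed.item(), kl_ref.item())
```

The two values should agree up to floating-point error, since both sum the per-dimension KL divergence over the batch and the latent dimensions.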

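Every 100 steps, the current epoch, the step, and the values of the reconstruction loss and KL divergence are printed. Note that the loop iterates over `dataset` directly rather than over a batched loader, so each step processes a single image and `batch_size` is not used during training.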
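If mini-batch training is intended, one common option is to wrap the dataset in `torch.utils.data.DataLoader`. The sketch below shows the idea on a random `TensorDataset` that stands in for the `painting/` folder. The names `fake_images`, `fake_labels`, `fake_dataset`, and `data_loader` are hypothetical and do not appear in the original code.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the ImageFolder dataset: 300 random RGB images with dummy labels
fake_images = torch.rand(300, 3, 28, 28)
fake_labels = torch.zeros(300, dtype=torch.long)
fake_dataset = TensorDataset(fake_images, fake_labels)

batch_size = 128
data_loader = DataLoader(fake_dataset, batch_size=batch_size, shuffle=True)

for i, (x, _) in enumerate(data_loader):
    # Flatten each batch to (batch, 3 * 28 * 28) before feeding the model
    x = x.view(-1, 3 * 28 * 28)
    print(i, x.shape)
```

With the 2042 images reported in the training log and a batch size of 128, such a loader would yield 16 steps per epoch, so a logging condition like `(i+1) % 100 == 0` would need to be adjusted.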

In [4]:
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(dataset):
        # Flatten the image to shape (1, image_size) and run the forward pass
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Reconstruction loss (binary cross-entropy summed over pixels) and KL divergence,
        # which measures how far the encoder's distribution is from the standard normal prior
        reconst_loss = F.binary_cross_entropy(x_reconst.view(-1, image_size), x, reduction='sum')
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backpropagation and optimization
        loss = reconst_loss + kl_div 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1)%100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div:{:.4f}'.
                  format(epoch+1, num_epochs, i+1, len(dataset), reconst_loss.item(), kl_div.item()))



Epoch [1/300], Step [100/2042], Reconst Loss: 1250.6622, KL Div:2.9523
Epoch [1/300], Step [200/2042], Reconst Loss: 1099.3325, KL Div:2.8520
Epoch [1/300], Step [300/2042], Reconst Loss: 1413.2888, KL Div:16.4427
Epoch [1/300], Step [400/2042], Reconst Loss: 1186.1074, KL Div:3.6154
Epoch [1/300], Step [500/2042], Reconst Loss: 1683.8351, KL Div:63.3989
Epoch [1/300], Step [600/2042], Reconst Loss: 1356.7648, KL Div:8.9825
Epoch [1/300], Step [700/2042], Reconst Loss: 1324.6372, KL Div:22.0458
Epoch [1/300], Step [800/2042], Reconst Loss: 1163.2579, KL Div:11.8785
Epoch [1/300], Step [900/2042], Reconst Loss: 1195.2729, KL Div:15.4491
Epoch [1/300], Step [1000/2042], Reconst Loss: 1321.5194, KL Div:11.5000
Epoch [1/300], Step [1100/2042], Reconst Loss: 763.2000, KL Div:0.9630
Epoch [1/300], Step [1200/2042], Reconst Loss: 1299.4104, KL Div:18.8847
Epoch [1/300], Step [1300/2042], Reconst Loss: 1588.4285, KL Div:59.3868
Epoch [1/300], Step [1400/2042], Reconst Loss: 751.7977, KL Div:2.

## Generate image

The trained model is used to produce two kinds of images:

1. Sampled images: new images generated by the decoder from a random latent vector z.

2. Reconstructed images: images obtained by passing an original image through the encoder and then the decoder.
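The reconstruction output places each original next to its reconstruction by concatenating along the width axis (`dim=3`). The following shape check uses random stand-in tensors, not model outputs, to show what that concatenation produces.

```python
import torch

# Stand-ins for a batch of originals and their reconstructions
x = torch.rand(8, 3, 28, 28)
out = torch.rand(8, 3, 28, 28)

# dim=3 is the width axis of an (N, C, H, W) tensor,
# so each result is 28 pixels high and 56 pixels wide
pairs = torch.cat([x, out], dim=3)

print(pairs.shape)  # torch.Size([8, 3, 28, 56])
```

The left half of each concatenated image is the original and the right half is the reconstruction.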

#### Modify the code to set the color channels of generated images to RGB.

In [None]:
# Test with the trained model (reuses `x` and `epoch` left over from the training loop)
with torch.no_grad():
    # Sampled images: decode latent vectors z drawn from the standard normal prior
    z = torch.randn(batch_size, z_dim).to(device)
    out = model.decoder(z).view(-1, 3, 28, 28)
    save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

    # Reconstructed image: original (left) and reconstruction (right), side by side
    out, _, _ = model(x)
    x_concat = torch.cat([x.view(-1, 3, 28, 28), out.view(-1, 3, 28, 28)], dim=3)
    save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))