In [4]:
!pip install colorama

Collecting colorama
  Downloading https://files.pythonhosted.org/packages/44/98/5b86278fbbf250d239ae0ecb724f8572af1c91f4a11edf4d36a206189440/colorama-0.4.4-py2.py3-none-any.whl
Installing collected packages: colorama
Successfully installed colorama-0.4.4


# Variational Autoencoders

In [5]:
import os
import sys
from tqdm import tqdm
import numpy as np
import torch
from PIL import Image
from skimage import io
from torch import nn
from torch import optim
from torch.nn import functional as F
from torchvision.utils import save_image
from colorama import Fore
import math

### Model

In [6]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 128x128 images.

    The encoder applies five Conv2d(kernel=4, stride=2, padding=1) stages
    (128 -> 64 -> 32 -> 16 -> 8 -> 4), so the flattened feature size
    ``2048 * 4 * 4`` used by ``fc1``/``fc2``/``d1`` only matches 128x128 inputs.
    The decoder mirrors the encoder and ends in ``tanh``, so outputs lie in [-1, 1].

    Args:
        zsize: dimensionality of the latent code ``z``.
    """

    # Shape of the deepest encoder feature map (= first decoder feature map).
    _FEAT_CHANNELS = 2048
    _FEAT_HW = 4
    _FEAT_DIM = _FEAT_CHANNELS * _FEAT_HW * _FEAT_HW

    def __init__(self, zsize):
        super().__init__()

        self.zsize = zsize

        # Encoder: every conv halves the spatial size.
        self.conv1 = nn.Conv2d(3, 128, 4, 2, 1)
        self.conv1_bn = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 256, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 512, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 1024, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(1024)
        self.conv5 = nn.Conv2d(1024, 2048, 4, 2, 1)
        self.conv5_bn = nn.BatchNorm2d(2048)

        # Two heads on the shared features: fc1 -> mu, fc2 -> logvar.
        self.fc1 = nn.Linear(self._FEAT_DIM, zsize)
        self.fc2 = nn.Linear(self._FEAT_DIM, zsize)

        # Decoder: project z back to a 2048x4x4 map, then upsample x2 five times.
        self.d1 = nn.Linear(zsize, self._FEAT_DIM)
        self.deconv1 = nn.ConvTranspose2d(2048, 1024, 4, 2, 1)
        self.deconv1_bn = nn.BatchNorm2d(1024)
        self.deconv2 = nn.ConvTranspose2d(1024, 512, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(512)
        self.deconv3 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(256)
        self.deconv4 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(128)
        self.deconv5 = nn.ConvTranspose2d(128, 3, 4, 2, 1)

    def encode(self, x):
        """Map images ``x`` of shape (B, 3, 128, 128) to ``(mu, logvar)``, each (B, zsize)."""
        x = F.relu(self.conv1_bn(self.conv1(x)))
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = F.relu(self.conv4_bn(self.conv4(x)))
        x = F.relu(self.conv5_bn(self.conv5(x)))
        x = x.view(x.shape[0], self._FEAT_DIM)
        return self.fc1(x), self.fc2(x)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)`` in training mode.

        In eval mode the mean ``mu`` is returned unchanged (deterministic encoding).
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, x):
        """Decode latent codes of shape (B, zsize) or (B, zsize, 1, 1) to images in [-1, 1]."""
        x = x.view(x.shape[0], self.zsize)
        x = self.d1(x)
        x = x.view(x.shape[0], self._FEAT_CHANNELS, self._FEAT_HW, self._FEAT_HW)
        x = F.leaky_relu(x, 0.2)
        x = F.leaky_relu(self.deconv1_bn(self.deconv1(x)), 0.2)
        x = F.leaky_relu(self.deconv2_bn(self.deconv2(x)), 0.2)
        x = F.leaky_relu(self.deconv3_bn(self.deconv3(x)), 0.2)
        x = F.leaky_relu(self.deconv4_bn(self.deconv4(x)), 0.2)
        return torch.tanh(self.deconv5(x))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            ``(reconstruction, mu, logvar)``; ``mu`` and ``logvar`` have shape (B, zsize).
        """
        mu, logvar = self.encode(x)
        # No .squeeze(): the fc heads already return (B, zsize). Squeezing would drop
        # the batch dimension when B == 1 and break the mean over dim 1 in the KL term.
        z = self.reparameterize(mu, logvar)
        return self.decode(z.view(-1, self.zsize, 1, 1)), mu, logvar

    def weight_init(self, mean, std):
        """Re-initialise every Conv2d/ConvTranspose2d: weights ~ N(mean, std), biases = 0."""
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                nn.init.normal_(m.weight, mean, std)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

### Loss Function

In [7]:
def loss_function(recon_x, x, mu, logvar, kl_weight=None):
    """Return the two VAE loss terms ``(reconstruction, weighted KL)``.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target images.
        mu, logvar: encoder outputs of shape (B, zsize).
        kl_weight: multiplier on the KL term. When omitted, the notebook-level
            ``kl_weight`` hyperparameter is used (the original behaviour).

    Returns:
        The mean squared error between ``recon_x`` and ``x``, and
        ``kl_weight * KL``, where KL is averaged over latent dims and batch.
    """
    if kl_weight is None:
        # Backward compatibility: existing 4-argument calls use the global hyperparameter.
        kl_weight = globals()["kl_weight"]
    mse = torch.mean((recon_x - x) ** 2)
    kld = -0.5 * torch.mean(torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), 1))
    return mse, kld * kl_weight

### Image Preprocessing

In [8]:
def process_images(image_collection, size=None):
    """Resize images and stack them into a float32 array scaled to [-1, 1].

    Args:
        image_collection: iterable of image arrays accepted by ``Image.fromarray``
            (e.g. a scikit-image ``ImageCollection``).
        size: output side length; defaults to the notebook-level ``im_size``.

    Returns:
        ``np.ndarray`` of shape (N, 3, size, size), dtype float32, channels-first.
    """
    if size is None:
        size = im_size
    # convert("RGB") guarantees exactly 3 channels: grayscale images would otherwise
    # break the (2, 0, 1) transpose, and RGBA images would be silently mis-reshaped.
    data = [
        np.array(Image.fromarray(img).convert("RGB").resize((size, size))).transpose((2, 0, 1))
        for img in image_collection
    ]
    x = np.asarray(data, dtype=np.float32) / 127.5 - 1.
    return x.reshape(-1, 3, size, size)

### Hyperparameters

In [9]:
# Hyperparameters, read as globals by the cells below.
im_size = 128                   # image side length; the VAE layers assume 128x128 (5 stride-2 stages -> 4x4)
batch_size = 256
z_size = 512                    # latent dimensionality
kl_weight = 1.7                 # multiplier applied to the KL term in loss_function
train_epoch = 200
lr = 0.0008                     # Adam learning rate
gradient_clipping_value = 0.1   # max total gradient norm passed to clip_grad_norm_
seed = 0                        # fixes weight init, DataLoader shuffling and latent sampling

torch.manual_seed(seed)
np.random.seed(seed)

### Preparing Dataset

In [10]:
# Load every .jpg in the training folder, convert to [-1, 1] channels-first arrays,
# wrap them in a shuffled DataLoader and create the per-epoch output folders.
data_root = '/content/drive/MyDrive/CSYE7370'
image_pattern = os.path.join(data_root, 'data', 'image', '*.jpg')
im_collection = io.imread_collection(image_pattern)
if len(im_collection) == 0:
    # Otherwise the DataLoader is empty and the training loop runs no batches at all.
    raise FileNotFoundError("No training images match " + image_pattern)
data_train = images = process_images(im_collection)
print("Training dataset size:", len(data_train))
batches_per_epoch = math.ceil(len(data_train) / batch_size)
print("Batches per epoch: ", batches_per_epoch)
os.makedirs(os.path.join(data_root, 'VAEresults_reconstructed'), exist_ok=True)
os.makedirs(os.path.join(data_root, 'VAEresults_generated'), exist_ok=True)
train_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True)

Training dataset size: 5000
Batches per epoch:  20


### Training Model

In [12]:
# Build the model, initialise conv/deconv weights ~ N(0, 0.02) and train with Adam.
# After each epoch: log mean losses, save a reconstruction grid and a grid decoded
# from random latent codes, then write a checkpoint.
vae = VAE(zsize=z_size)
vae.cuda()
vae.train()
vae.weight_init(mean=0, std=0.02)
vae_optimizer = optim.Adam(vae.parameters(), lr=lr)

results_root = '/content/drive/MyDrive/CSYE7370'
reconstructed_dir = os.path.join(results_root, 'VAEresults_reconstructed')
generated_dir = os.path.join(results_root, 'VAEresults_generated')
n_generated = 128              # random latent codes decoded per epoch for the sample grid
keep_all_checkpoints = False   # one large checkpoint per epoch previously exhausted the disk
previous_checkpoint = None

for epoch in range(train_epoch):
    vae.train()
    reconstruction_loss = 0
    kullback_leibler_loss = 0
    n_batches = 0

    # TRAINING
    with tqdm(total=len(data_train), position=0, leave=True, file=sys.stdout,
              bar_format="{l_bar}%s{bar:70}%s{r_bar}" % (Fore.BLUE, Fore.RESET)) as training_pbar:
        for x in train_loader:
            vae.zero_grad()
            x = x.cuda()
            rec, mu, logvar = vae(x)
            loss_re, loss_kl = loss_function(rec, x, mu, logvar)
            (loss_re + loss_kl).backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), gradient_clipping_value)
            vae_optimizer.step()
            reconstruction_loss += loss_re.item()
            kullback_leibler_loss += loss_kl.item()
            n_batches += 1
            training_pbar.update(x.shape[0])

    if n_batches == 0:
        raise RuntimeError("train_loader yielded no batches; check the training data")

    # VALIDATION: once per epoch, on the last training batch `x`.
    print('\nEpoch [%d/%d] - reconstruction loss: %.9f, Kullback-Leibler loss: %.9f' % (
        (epoch + 1), train_epoch, reconstruction_loss / n_batches,
        kullback_leibler_loss / n_batches))
    with torch.no_grad():
        vae.eval()
        x_rec, _, _ = vae(x)
        # Inputs and tanh outputs are in [-1, 1]; map to [0, 1] for save_image.
        result_sampled = (torch.cat([x, x_rec]) * 0.5 + 0.5).cpu()
        save_image(result_sampled.view(-1, 3, im_size, im_size),
                   os.path.join(reconstructed_dir, 'sample_' + str(epoch) + '.png'))
        sample = torch.randn(n_generated, z_size).view(-1, z_size, 1, 1).cuda()
        x_rec = vae.decode(sample)
        result_sampled = (x_rec * 0.5 + 0.5).cpu()
        save_image(result_sampled.view(-1, 3, im_size, im_size),
                   os.path.join(generated_dir, 'sample_' + str(epoch) + '.png'))

    checkpoint_path = "./weights_" + str(epoch) + ".pth"
    torch.save(vae.state_dict(), checkpoint_path)
    # The new checkpoint is written before the old one is removed, so a valid file always exists.
    if not keep_all_checkpoints and previous_checkpoint is not None and os.path.exists(previous_checkpoint):
        os.remove(previous_checkpoint)
    previous_checkpoint = checkpoint_path


100%|[34m██████████████████████████████████████████████████████████████████████[39m| 5000/5000 [00:28<00:00, 172.91it/s]

Epoch [1/200] - reconstruction loss: 0.451126985, Kullback-Leibler loss: 4899.585527064
100%|[34m██████████████████████████████████████████████████████████████████████[39m| 5000/5000 [00:28<00:00, 172.72it/s]

Epoch [2/200] - reconstruction loss: 0.337058076, Kullback-Leibler loss: 4.407654607
100%|[34m██████████████████████████████████████████████████████████████████████[39m| 5000/5000 [00:29<00:00, 171.93it/s]

Epoch [3/200] - reconstruction loss: 0.299824366, Kullback-Leibler loss: 0.369089157
100%|[34m██████████████████████████████████████████████████████████████████████[39m| 5000/5000 [00:29<00:00, 171.96it/s]

Epoch [4/200] - reconstruction loss: 0.253118189, Kullback-Leibler loss: 0.117526798
100%|[34m██████████████████████████████████████████████████████████████████████[39m| 5000/5000 [00:29<00:00, 172.01it/s]

Epoch [5/200] - reconstruction loss: 

KeyboardInterrupt: ignored

### Testing

In [13]:
# Reload a saved checkpoint and write 200 grids of 128 images, each decoded from random latent codes.
checkpoint_path = "/content/weights_51.pth"
final_samples_dir = '/content/drive/MyDrive/CSYE7370/VAEgenerated_results'
n_grids = 200
n_per_grid = 128

vae.load_state_dict(torch.load(checkpoint_path))
# eval(): BatchNorm must use its running statistics. If training was interrupted
# mid-epoch the model is still in train mode, and decoding would use (and update)
# per-batch statistics of the random samples instead.
vae.eval()
os.makedirs(final_samples_dir, exist_ok=True)
with torch.no_grad():  # inference only: no autograd graph needed
    for i in range(n_grids):
        sample = torch.randn(n_per_grid, z_size).view(-1, z_size, 1, 1).cuda()
        x_rec = vae.decode(sample)
        # Decoder output is in [-1, 1]; map to [0, 1] for save_image.
        result_sampled = (x_rec * 0.5 + 0.5).cpu()
        save_image(result_sampled.view(-1, 3, im_size, im_size),
                   os.path.join(final_samples_dir, 'sample_' + str(i) + '.png'))

### Conclusion

Each output grid contains 128 fake images generated by the VAE. The folders "VAEgenerated_results", "VAEresults_generated" and "VAEresults_reconstructed" store all the images produced by this model. The results show that it can generate clear animated faces with a hazy tone; the outlines are not very sharp, but the images are clearly recognizable as animated faces.

Compared with the GAN and AE models, the VAE achieved good generation results. I trained it on Google Colab with a GPU. However, the training loop saved a new, very large weights_<epoch>.pth checkpoint after every epoch (almost 60 GB in total, which ran out of my disk space), so training had to stop at epoch 52. This also shows that the code has problems and is not optimal. Even so, it still generates very good animated faces. With better hardware, and if all 200 epochs were completed, we could expect even clearer generated images.

After adjusting the size of the latent space and changing the network architecture, the current combination gives a reasonable trade-off between compression and quality.

Copyright <2021>

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.