In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.utils import save_image

from torchsummary import summary

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import math
from PIL import Image
from PIL import Image
from IPython.display import display
import glob

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
transformObj = transforms.Compose([
    transforms.Resize(120),
    transforms.CenterCrop(120),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [4]:
dataroot = "./celeba/"

dataset = datasets.ImageFolder(root=dataroot, transform=transformObj)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)

In [5]:
class VAE(nn.Module):
    def __init__(self, latent_size=20):
        super(VAE, self).__init__()
        
        self.latent_size = latent_size
        
        self.l1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=2, padding=1)
        self.l2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=4, stride=2, padding=1)        
        
        self.l21 = nn.Linear(6*2*30*30, self.latent_size)
        self.l22 = nn.Linear(6*2*30*30, self.latent_size)
        
        self.f = nn.Linear(self.latent_size, 6*2*30*30)
        
        self.l3 = nn.ConvTranspose2d(in_channels=12, out_channels=6, kernel_size=4, stride=2, padding=1)
        self.l4 = nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=4, stride=2, padding=1)
        
    def encoder(self, x_in):
        h = F.relu(self.l1(x_in))
        h = F.relu(self.l2(h))
        
        h = h.view(h.size(0), -1)
        
        return self.l21(h), self.l22(h)
    
    def decoder(self, z):
        z = self.f(z)
        z = z.view(z.size(0), 6*2, 30, 30)
        
        z = F.relu(self.l3(z))
        z = torch.sigmoid(self.l4(z))
        
        return z
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return torch.add(eps.mul(std), mu)
    
    def forward(self, x_in):
        mu, log_var = self.encoder(x_in)
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

In [6]:
vae = VAE()
    
vae.to(device)

VAE(
  (l1): Conv2d(3, 6, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (l2): Conv2d(6, 12, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (l21): Linear(in_features=10800, out_features=20, bias=True)
  (l22): Linear(in_features=10800, out_features=20, bias=True)
  (f): Linear(in_features=20, out_features=10800, bias=True)
  (l3): ConvTranspose2d(12, 6, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (l4): ConvTranspose2d(6, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)

In [7]:
summary(vae, (3, 120, 120))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 60, 60]             294
            Conv2d-2           [-1, 12, 30, 30]           1,164
            Linear-3                   [-1, 20]         216,020
            Linear-4                   [-1, 20]         216,020
            Linear-5                [-1, 10800]         226,800
   ConvTranspose2d-6            [-1, 6, 60, 60]           1,158
   ConvTranspose2d-7          [-1, 3, 120, 120]             291
Total params: 661,747
Trainable params: 661,747
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.16
Forward/backward pass size (MB): 0.82
Params size (MB): 2.52
Estimated Total Size (MB): 3.51
----------------------------------------------------------------


In [8]:
optimizer = optim.Adam(vae.parameters())

def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL Divergence from MIT 6.S191
    return (BCE + KLD)

In [9]:
def train(epoch):
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.to(device)
        optimizer.zero_grad()
        
        r_batch, mu, log_var = vae(data)

        loss = loss_function(r_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        #if batch_idx%50==0:
        #    print("Batch finished in Epoch: ", batch_idx)
    print('Epoch: {} Train mean loss: {:.8f}'.format(epoch, train_loss / len(dataloader.dataset)))

In [13]:
n_epoches = 100

for epoch in range(1, n_epoches+1):
    train(epoch)

Epoch: 1 Train mean loss: -1019435.44637161
Epoch: 2 Train mean loss: -1028609.07177033
Epoch: 3 Train mean loss: -1037917.89274322
Epoch: 4 Train mean loss: -1045636.24381978
Epoch: 5 Train mean loss: -1053661.66905901
Epoch: 6 Train mean loss: -1062967.43440989
Epoch: 7 Train mean loss: -1067107.52113238
Epoch: 8 Train mean loss: -1074994.69338118
Epoch: 9 Train mean loss: -1082275.46152313
Epoch: 10 Train mean loss: -1094249.95773525
Epoch: 11 Train mean loss: -1101453.58572568
Epoch: 12 Train mean loss: -1112642.60007974
Epoch: 13 Train mean loss: -1121924.90988836
Epoch: 14 Train mean loss: -1132490.45135566
Epoch: 15 Train mean loss: -1141355.20334928
Epoch: 16 Train mean loss: -1151193.50917065
Epoch: 17 Train mean loss: -1161476.62639553
Epoch: 18 Train mean loss: -1169279.51315789
Epoch: 19 Train mean loss: -1177864.07216906
Epoch: 20 Train mean loss: -1183102.15629984
Epoch: 21 Train mean loss: -1189510.81060606
Epoch: 22 Train mean loss: -1193607.01594896
Epoch: 23 Train mea

In [14]:
with torch.no_grad():
    z = torch.randn(1, 20).to(device)
    for i in range(100):
        z = torch.add(z, 0.05)
        
        sample = vae.decoder(z).to(device)
        save_image(sample.view(3, 120, 120), './samplesFACES/sample' + str(i) + '.png')