여기서 활용하는 데이터셋은 CelebA(CelebFaces Attributes Dataset)로, 유명 인사들의 얼굴 사진으로 구성되어 있으며
각 사진에는 여러 가지 얼굴 속성(attribute) 레이블이 지정되어 있음
자세한 내용은 Kaggle의 CelebFaces Attributes Dataset 페이지를 참고할 것

In [None]:
!mkdir data_faces && wget https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip

In [None]:
import zipfile

# Unpack the downloaded archive into the data_faces/ folder.
with zipfile.ZipFile('celeba.zip', mode='r') as archive:
  archive.extractall(path='data_faces/')

In [None]:
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Variable

In [None]:
def show(img, renorm=False, nrow=8, interpolation='bicubic'):
  """Display a batch of image tensors as one grid using matplotlib.

  Args:
    img: Batch of image tensors accepted by torchvision.utils.make_grid
      (presumably (N, C, H, W)). It may live on any device or require grad;
      it is detached and copied to the CPU before plotting.
    renorm: If True, rescale values with x * 0.5 + 0.5 (maps [-1, 1] to [0, 1]).
    nrow: Number of images per row in the grid.
    interpolation: Interpolation mode forwarded to plt.imshow.
  """
  # Tensor.numpy() only works on CPU tensors that do not require grad.
  img = img.detach().cpu()
  if renorm:
    img = img * 0.5 + 0.5
  img_grid = torchvision.utils.make_grid(img, nrow=nrow).numpy()
  plt.figure()
  # make_grid yields (C, H, W); imshow expects channels last (H, W, C).
  plt.imshow(np.transpose(img_grid, (1, 2, 0)), interpolation=interpolation)
  plt.axis('off')
  plt.show()

In [None]:
# Folder holding the face images (presumably created by extracting celeba.zip).
root = 'data_faces/img_align_celeba'
img_list = os.listdir(path=root)
print(len(img_list))  # number of entries found in the folder

In [None]:
class VAE(nn.Module):
  """Convolutional variational autoencoder for square RGB images.

  The encoder applies four stride-2 convolutions (3 -> 32 -> 64 -> 128 -> 256
  channels), shrinking the spatial size from image_size to image_size // 16,
  then projects the flattened features to the mean and log-variance of the
  latent distribution. The decoder mirrors that path with transposed
  convolutions and ends in a Sigmoid, so reconstructions lie in [0, 1].

  The convolution stacks are stored as the submodules ``encoder`` and
  ``decoder`` (so state_dict keys are ``encoder.*`` / ``decoder.*``). The
  methods running the full encode/decode steps are named ``encode`` and
  ``decode``: a method with the same name as a registered submodule would
  shadow it, because nn.Module keeps submodules in ``_modules``.

  Args:
    image_size: Height and width of the square input images; must be
      divisible by 16.
    latent_dim: Size of the latent vector z.

  Raises:
    ValueError: If image_size is not divisible by 16.
  """

  def __init__(self, image_size=128, latent_dim=512):
    super().__init__()
    if image_size % 16 != 0:
      raise ValueError(f'image_size must be divisible by 16, got {image_size}')
    feat = image_size // 16        # spatial size after four stride-2 convs
    flat_dim = 256 * feat * feat   # length of the flattened encoder output

    # Encoder: assumes input (N, 3, image_size, image_size) -> (N, flat_dim)
    self.encoder = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Flatten()
    )
    # Linear heads producing the mean and log-variance of q(z | x).
    self.fc_mu = nn.Linear(flat_dim, latent_dim)
    self.fc_logvar = nn.Linear(flat_dim, latent_dim)

    # Decoder: z -> (N, 256, feat, feat) -> (N, 3, image_size, image_size)
    self.decoder_input = nn.Linear(latent_dim, flat_dim)
    self.decoder = nn.Sequential(
        nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
        nn.Sigmoid()
    )
    self.image_size = image_size

  def encode(self, x):
    """Run the encoder and return the latent parameters (mu, logvar)."""
    h = self.encoder(x)
    return self.fc_mu(h), self.fc_logvar(h)

  def reparameterize(self, mu, logvar):
    """Draw z = mu + std * eps with eps ~ N(0, I) and std = exp(logvar / 2).

    Writing the sample this way keeps z differentiable w.r.t. mu and logvar.
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

  def decode(self, z):
    """Map latent vectors z back to images with values in [0, 1]."""
    h = self.decoder_input(z)
    h = h.view(-1, 256, self.image_size // 16, self.image_size // 16)
    return self.decoder(h)

  def forward(self, x):
    """Encode, sample and decode x.

    Returns:
      A tuple (reconstructed_x, mu, logvar).
    """
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    reconstructed_x = self.decode(z)
    return reconstructed_x, mu, logvar


In [None]:
re_size = 128

# Resize each image to re_size x re_size with bicubic interpolation, then
# convert it to a float tensor (ToTensor scales 8-bit images to [0, 1]).
# InterpolationMode is torchvision's supported way to pick the filter;
# passing PIL integer constants is a legacy/deprecated form.
transform = transforms.Compose([
    transforms.Resize(size=(re_size, re_size),
                      interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])

batch_size = 64
# ImageFolder treats each sub-directory of ./data_faces as one class; the
# class labels are discarded by the cells that consume this loader.
celeba_data = datasets.ImageFolder('./data_faces', transform=transform)
celeba_loader = DataLoader(celeba_data, batch_size=batch_size, shuffle=True)

In [None]:
# Take one shuffled batch (labels ignored) and preview its first 16 images.
(batch, _) = next(iter(celeba_loader))
show(batch[:16], nrow=4)

In [None]:
def vae_loss(recon_x, x, mu, logvar):
  """Compute the VAE training objective summed over the whole batch.

  Args:
    recon_x: Reconstructions, expected in [0, 1] (BCE requirement).
    x: Target images, same shape as recon_x.
    mu: Latent means.
    logvar: Latent log-variances, same shape as mu.

  Returns:
    Scalar tensor: the summed binary cross-entropy between recon_x and x,
    plus the closed-form KL divergence of N(mu, exp(logvar)) from N(0, I).
  """
  reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
  kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  return reconstruction + kl_divergence



In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')


In [None]:
# Build the VAE with its default sizes and move it to the selected device.
model = VAE()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:

num_epochs = 50

best_loss = np.inf
model.train()
for epoch in range(num_epochs):
  # Running sum of per-batch losses and the number of batches processed.
  total_loss, cpt = 0, 0
  for batch_idx, data in enumerate(celeba_loader):
    optimizer.zero_grad()  # clear gradients from the previous step
    img, _ = data
    img = img.float().to(device)
    recon_img, mu, logvar = model(img)
    loss = vae_loss(recon_img, img, mu, logvar)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    cpt += 1

    # Progress report every 1000 batches (including the very first one).
    if not batch_idx % 1000:
      print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(celeba_loader)}], Loss: {total_loss/cpt:.4f}')
  print(f'Epoch [{epoch+1}/{num_epochs}], Total Loss: {total_loss/len(celeba_loader):.4f}')
  # Save the parameters whenever this epoch's summed loss beats the best so far.
  if total_loss < best_loss:
    best_loss = total_loss
    torch.save(model.state_dict(), 'celeba_vae.pth')


In [None]:
!gdown --id liUkepsRdaADfCnczvx2-jkcT2bHLUyPd
# 학습된 모델을 바로 받아서 확인해볼 수도..

In [None]:
# Load the saved (or downloaded) weights onto the active device.
# torch.load unpickles the file; weights_only=True restricts it to tensors and
# plain containers, so a tampered checkpoint cannot run arbitrary code.
model.load_state_dict(torch.load('celeba_vae.pth', map_location=device, weights_only=True))
model.eval()

In [None]:
# Original input images from the sample batch.
show(batch[:16], nrow=4)

In [None]:
# Reconstruct the sample batch without recording an autograd graph.
with torch.no_grad():
  recon_img, _, _ = model(batch.to(device=device))


In [None]:
# Reconstructed images, clamped to [0, 1] and copied to the CPU for plotting.
show(recon_img[:16].clamp(0, 1).detach().cpu(), nrow=4)