In [None]:
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset

from torch.autograd import Variable

import pandas as pd

In [None]:
class CelebADataset(Dataset):
  """Image dataset that pairs each face image with a subset of its attributes.

  Args:
    data: Indexable collection of image file names, relative to ``root``.
    labels: 2-D NumPy array of per-image attribute values; row ``i``
      describes ``data[i]``.
    classes: Attribute names. Stored as-is; the dataset itself never reads them.
    transform: Optional callable applied to the PIL image before it is returned.
    root: Directory that contains the image files. Defaults to
      ``data_faces/img_align_celeba``.
    attr_indices: Column indices of ``labels`` returned for every sample.
      Defaults to ``(39, 20, 22, 2, 31, 15)``.
  """

  DEFAULT_ROOT = os.path.join('data_faces', 'img_align_celeba')
  DEFAULT_ATTR_INDICES = (39, 20, 22, 2, 31, 15)

  def __init__(self, data, labels, classes, transform=None, root=None, attr_indices=None):
    self.data = data
    self.labels = labels
    self.classes = classes
    self.transform = transform
    self.root = self.DEFAULT_ROOT if root is None else root
    # Stored as a list so NumPy treats it as an index array, not as extra axes.
    self.attr_indices = list(self.DEFAULT_ATTR_INDICES if attr_indices is None else attr_indices)

  def __len__(self):
    """Return the number of images in the dataset."""
    return len(self.data)

  def _get_label(self, idx):
    """Return the selected attributes of sample ``idx`` as a float32 tensor."""
    selected = self.labels[idx, self.attr_indices].astype('uint8')
    return torch.as_tensor(selected, dtype=torch.float32)

  def __getitem__(self, idx):
    """Load sample ``idx`` and return ``{'images': image, 'labels': tensor}``."""
    # convert() returns a new image, so the file can be closed right away
    # instead of leaving the handle to the garbage collector.
    with Image.open(os.path.join(self.root, self.data[idx])) as handle:
      img = handle.convert('RGB')
    label = self._get_label(idx)

    if self.transform:
      img = self.transform(img)
    sample = {'images': img, 'labels': label}
    return sample



In [None]:
import torchvision.utils as vutils

def display_grid(images, nrow=8, figsize=(12, 12)):
  """Draw a batch of images as a single tiled grid on a new matplotlib figure.

  Args:
    images: Batch of image tensors accepted by ``vutils.make_grid``.
    nrow: Forwarded to ``make_grid`` as its ``nrow`` argument.
    figsize: Figure size handed to ``plt.figure``.
  """
  plt.figure(figsize=figsize)
  grid = vutils.make_grid(images, nrow=nrow, padding=2, normalize=True)
  # Move the first axis last (axes 1, 2, 0) before handing the grid to imshow.
  grid_last_axis = grid.cpu().permute(1, 2, 0).numpy()
  plt.imshow(grid_last_axis)
  plt.axis('off')


In [None]:
# Attribute table: later cells use column 0 as the image file name and the
# remaining columns as per-image attribute values.
attr_csv_path = 'list_attr_celeba.csv'
df = pd.read_csv(attr_csv_path)


In [None]:
# Preprocessing applied to every image: bicubic resize to 128x128, then
# conversion to a tensor.
resize_128 = transforms.Resize(size=(128, 128), interpolation=Image.BICUBIC)
to_tensor = transforms.ToTensor()
transform = transforms.Compose([resize_128, to_tensor])

In [None]:
# Column 0 holds the file names; every other column is an attribute.
data = df.values[:, 0]
# (v + 1) // 2 maps -1 -> 0 and +1 -> 1.
# NOTE(review): presumably the CSV encodes attributes as -1/+1 - confirm.
labels = (df.values[:, 1:] + 1) // 2
classes = df.columns[1:]

In [None]:
# Wrap the file names / attribute table in a dataset and serve it in
# shuffled mini-batches of 64.
celeba_data = CelebADataset(data=data, labels=labels, classes=classes, transform=transform)
celeba_loader = DataLoader(dataset=celeba_data, batch_size=64, shuffle=True)

In [None]:
class VAE(nn.Module):
  """Convolutional VAE whose decoder is conditioned on binary attributes.

  The encoder sees only the image. Each attribute value (0 or 1) is looked up
  in a shared embedding table, and the concatenated attribute embeddings are
  appended to the sampled latent vector before decoding.

  Args:
    image_size: Side length of the square input images. The encoder halves the
      spatial size four times, so the flattened feature map has
      ``256 * (image_size // 16) ** 2`` entries.
    latent_dim: Size of the latent vector produced by the encoder.
    num_attributes: Number of binary attributes supplied per sample.
    embedding_dim: Size of the embedding vector for one attribute value.
  """

  def __init__(self, image_size=128, latent_dim=512, num_attributes=6, embedding_dim=10):
    super(VAE, self).__init__()

    self.encoder = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Flatten()
    )
    # Flattened size of the encoder output / decoder input feature map.
    feature_dim = 256 * (image_size // 16) * (image_size // 16)

    # Attribute embedding: an attribute takes one of two values (0 or 1), and
    # each value is represented by an embedding_dim-dimensional vector.
    self.Embeddings = nn.Embedding(2, embedding_dim)
    # Layers producing the mean and the log-variance of the latent distribution.
    self.fc_mu = nn.Linear(feature_dim, latent_dim)
    self.fc_var = nn.Linear(feature_dim, latent_dim)

    # The decoder input is wider than latent_dim because the embeddings of all
    # attributes are concatenated to the latent vector:
    # num_attributes * embedding_dim extra features.
    self.decoder_input = nn.Linear(latent_dim + num_attributes * embedding_dim, feature_dim)
    self.decoder = nn.Sequential(
        nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
        nn.Sigmoid()
    )
    self.image_size = image_size
    self.latent_dim = latent_dim
    self.num_attributes = num_attributes
    self.embedding_dim = embedding_dim

  def encode(self, x):
    """Return ``(mu, logvar)`` of the latent distribution for image batch ``x``."""
    x = self.encoder(x)
    mu, logvar = self.fc_mu(x), self.fc_var(x)
    return mu, logvar

  def reparameterize(self, mu, logvar):
    """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5 * logvar)``."""
    std = torch.exp(0.5*logvar)
    eps = torch.randn_like(std)
    return mu + eps*std

  def decode(self, z):
    """Decode ``z`` (latent vector with attribute embeddings appended) into images."""
    x = self.decoder_input(z)
    x = x.view(-1, 256, (self.image_size // 16), (self.image_size // 16))
    x = self.decoder(x)
    return x

  def forward(self, x, labels):
    """Reconstruct ``x`` conditioned on ``labels``.

    Args:
      x: Batch of input images.
      labels: Integer tensor of 0/1 attribute values, one row per image with
        ``num_attributes`` entries (it is used to index the embedding table).

    Returns:
      Tuple ``(recon_x, mu, logvar)``.
    """
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    B, _ = z.shape
    # One embedding per attribute, flattened into a single vector per sample.
    z_labels = self.Embeddings(labels).reshape(B, -1)
    z = torch.cat((z, z_labels), 1)
    recon_x = self.decode(z)
    return recon_x, mu, logvar


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = VAE().to(device)
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved from a GPU cannot be loaded on a CPU-only machine.
model.load_state_dict(torch.load('vae_celeba.pt', map_location=device))
# The cells below only run inference, so switch to evaluation mode.
model.eval();



In [None]:
# Pull one mini-batch and move it to the model's device. The embedding lookup
# in the model needs integer (long) labels.
data = next(iter(celeba_loader))
imgs = data['images'].float().to(device)
labels = data['labels'].long().to(device)

In [None]:
# Reconstruct the whole batch without tracking gradients.
with torch.no_grad():
  recon_batch = model(imgs, labels)[0]
reconstructions = recon_batch.detach().cpu().clip(0, 1)

# Show and save the input batch first, then its reconstruction.
for grid_images, out_file in (
    (imgs.detach().cpu(), 'vae_training_data.png'),
    (reconstructions, 'vae_reconstruction.png'),
):
  display_grid(grid_images, figsize=(8,8))
  plt.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=156)
  plt.show()

In [None]:
# Names of the six attribute columns the dataset returns, in the same order.
selected_attr_columns = [39, 20, 22, 2, 31, 15]
att_names = classes[selected_attr_columns]
att_names

In [None]:
# Pick one random image from the current batch and the attribute to toggle.
sample_idx = np.random.choice(len(imgs))
att_idx = 1

# List indexing keeps the leading batch dimension (size 1) and returns a copy.
sample_img = imgs[[sample_idx]]
sample_labels = labels[[sample_idx]]

# Same labels with attribute `att_idx` flipped (0 <-> 1); clone() makes it
# explicit that `labels` itself is left untouched.
new_labels = sample_labels.clone()
new_labels[:, att_idx] = 1 - new_labels[:, att_idx]

with torch.no_grad():
  # Encode and sample the latent vector once, then decode it with both label
  # sets. Calling model(...) twice would draw two different random latents, so
  # the two images would differ by sampling noise as well as by the attribute.
  mu, logvar = model.encode(sample_img)
  z = model.reparameterize(mu, logvar)
  z_original = torch.cat((z, model.Embeddings(sample_labels).reshape(1, -1)), 1)
  z_flipped = torch.cat((z, model.Embeddings(new_labels).reshape(1, -1)), 1)
  reconstructions = model.decode(z_original).cpu().clip(0, 1)
  reconstructions2 = model.decode(z_flipped).cpu().clip(0, 1)

# new_labels has a single row, so its attribute value lives at row 0
# (not at row sample_idx).
panels = [
    ('Original', sample_img),
    (f'{att_names[att_idx]} {int(labels[sample_idx, att_idx])}', reconstructions),
    (f'{att_names[att_idx]} {int(new_labels[0, att_idx])}', reconstructions2),
]

fig, axes = plt.subplots(1, 3)
for ax, (panel_title, panel_img) in zip(axes, panels):
  # imshow expects channels last.
  ax.imshow(panel_img[0].permute(1, 2, 0).cpu().numpy())
  ax.set_title(panel_title)
  ax.axis('off')
plt.show()
