In [None]:
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable

In [None]:
# Create the folder that stores model parameters (checkpoints).
# exist_ok=True makes the cell idempotent and avoids the exists()/makedirs() race.
os.makedirs('checkpoints', exist_ok=True)


In [None]:
# Create the folder that stores model outputs (sample image grids).
# exist_ok=True makes the cell idempotent and avoids the exists()/makedirs() race.
os.makedirs('outputs', exist_ok=True)

In [None]:
# Discriminator network: it only has to decide real vs. fake, so a single
# (1-dimensional) output per image is enough.
class Discriminator(nn.Module):
  """Four stride-2 convolutions followed by a linear layer and a sigmoid.

  Args:
    image_size: side length of the square 3-channel input images. Each
      Conv2d(kernel_size=4, stride=2, padding=1) maps a side of n to n // 2,
      so the last feature map is 256 x (image_size // 16) x (image_size // 16).
  """

  def __init__(self, image_size=128):
    super(Discriminator, self).__init__()
    # Spatial side length remaining after the four halving convolutions.
    feat_size = image_size // 16
    self.cls = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Flatten(),
        # Parenthesized on purpose: `256 * image_size // 16 * image_size // 16`
        # evaluates left-to-right to image_size**2, which only matches the real
        # flattened size when image_size is a multiple of 16.
        nn.Linear(256 * feat_size * feat_size, 1),
        nn.Sigmoid()
    )

  def forward(self, img):
    """Return a (batch, 1) tensor of probabilities that each image is real."""
    validity = self.cls(img)
    return validity

In [None]:
# Generator network: latent vector -> RGB image.
class Generator(nn.Module):
  """Decode a latent vector into an image via four stride-2 transposed convs.

  A linear layer projects z to a 256 x (image_size // 16) x (image_size // 16)
  feature map. Each ConvTranspose2d(kernel_size=4, stride=2, padding=1) doubles
  the spatial side, and a final Sigmoid keeps pixel values in [0, 1].

  Args:
    image_size: target side length (output side is 16 * (image_size // 16)).
    latent_dim: length of the input latent vector z.
  """

  def __init__(self, image_size=128, latent_dim=512):
    super(Generator, self).__init__()
    self.image_size = image_size
    side = image_size // 16
    self.decoder_input = nn.Linear(latent_dim, 256 * side * side)

    widths = (256, 128, 64, 32, 3)
    last_stage = len(widths) - 2
    stages = []
    for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
      stages.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
      # Hidden stages use ReLU; the output stage squashes to [0, 1].
      stages.append(nn.Sigmoid() if idx == last_stage else nn.ReLU())
    self.decoder = nn.Sequential(*stages)

  def forward(self, z):
    """Map a batch of latent vectors to a batch of (3, H, W) images."""
    side = self.image_size // 16
    feature_map = self.decoder_input(z).view(-1, 256, side, side)
    return self.decoder(feature_map)

In [None]:
def show(img, renorm=False, nrow=8, interpolation='bicubic'):
  """Display a batch of images as a single grid.

  Args:
    img: image batch tensor, presumably (N, C, H, W) as accepted by
      torchvision.utils.make_grid -- TODO confirm for other shapes.
    renorm: if True, map values via x * 0.5 + 0.5 (i.e. from [-1, 1] to [0, 1]).
    nrow: number of images per grid row.
    interpolation: matplotlib imshow interpolation mode.
  """
  # .numpy() rejects tensors that require grad (e.g. raw generator output) and
  # CUDA tensors, so detach and move to CPU first.
  img = img.detach().cpu()
  if renorm:
    img = img*0.5 + 0.5
  img_grid = torchvision.utils.make_grid(img, nrow=nrow).numpy()
  plt.figure()
  # make_grid returns (C, H, W); imshow expects (H, W, C).
  plt.imshow(np.transpose(img_grid, (1,2,0)), interpolation=interpolation)
  plt.axis('off')
  plt.show()

In [None]:
# Count the raw image files as a quick sanity check of the dataset download.
# NOTE(review): the official CelebA archive unpacks to 'img_align_celeba';
# 'img_aling_celeba' looks like a typo -- confirm it matches the folder on disk.
root = 'data_faces/img_aling_celeba'
img_list = os.listdir(path=root)
print(len(img_list))

In [None]:
# Resize every image to re_size x re_size and convert it to a float tensor in
# [0, 1] (ToTensor), the same range as the Generator's Sigmoid output.
re_size = 128

transform = transforms.Compose([
    # InterpolationMode is torchvision's documented enum for Resize; PIL's integer
    # constants (Image.BICUBIC) are only accepted for backward compatibility.
    transforms.Resize(size=(re_size, re_size), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor()
])
batch_size = 64
# ImageFolder expects images inside sub-directories of the root; the labels it
# produces are ignored by the training loop.
celeba_data = datasets.ImageFolder('./data_faces', transform=transform)
celeba_loader = DataLoader(celeba_data, batch_size=batch_size, shuffle=True)

In [None]:
# Sanity check: pull one shuffled batch and preview its first 16 images, 4 per row.
(batch, _) = next(iter(celeba_loader))
show(batch[:16], nrow=4)

In [None]:
# Create the two adversaries: generator first, then discriminator.
generator, discriminator = Generator(), Discriminator()

# Loss: binary cross-entropy on the discriminator's sigmoid probabilities.
adversarial_loss = nn.BCELoss()

# Optimizers: one Adam per network with identical hyper-parameters
# (lr = 2e-4, betas = (0.5, 0.999)).
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# GAN training (generator vs. discriminator; this is not a VAE).
num_epochs = 50
latent_dim = 512

G_losses = []
D_losses = []

# Move both networks to the selected device. Module.to() updates parameters in
# place, so the optimizers created earlier keep referring to the same tensors.
generator.to(device)
discriminator.to(device)

# Training loop
for epoch in range(num_epochs):
  for i, (imgs, _) in enumerate(celeba_loader):
    imgs = imgs.to(device)
    # size is a method: imgs.size(0), not imgs.size[0] (which raises TypeError).
    cur_batch_size = imgs.size(0)

    # Adversarial ground truths: 1 = real, 0 = fake.
    valid = torch.ones(cur_batch_size, 1, device=device)
    fake = torch.zeros(cur_batch_size, 1, device=device)

    # Generator step: push the discriminator to label generated images as real.
    optimizer_G.zero_grad()
    z = torch.randn(cur_batch_size, latent_dim, device=device)
    gen_imgs = generator(z)
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
    g_loss.backward()
    optimizer_G.step()

    # Discriminator step: real -> 1, generated -> 0. detach() stops this loss
    # from back-propagating into the generator.
    optimizer_D.zero_grad()
    real_loss = adversarial_loss(discriminator(imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    optimizer_D.step()

    # Record losses.
    G_losses.append(g_loss.item())
    D_losses.append(d_loss.item())

    # Progress report plus a 5x5 sample grid every 100 batches.
    if i % 100 == 0:
      print(f"[Epoch {epoch}/{num_epochs}] [Batch {i} /  {len(celeba_loader)}] [D Loss: {d_loss.item()}][G Loss: {g_loss.item()}]")

      with torch.no_grad():
        # Generator output is already in [0, 1] (Sigmoid); clip guards the display.
        samples = generator(z).permute(0, 2, 3, 1).clip(0, 1).cpu().numpy()
        fig, axs = plt.subplots(5, 5)
        # Separate index name so the batch counter `i` is not shadowed.
        for k, ax in enumerate(axs.flat):
          if k < len(samples):
            ax.imshow(samples[k])
          ax.axis('off')
        fig.savefig(os.path.join('outputs', f'gan_images_epoch_{epoch}.png'))
        plt.close(fig)

  # Checkpoint after every epoch; these are state_dicts, so use .pth, not .png.
  torch.save(generator.state_dict(), os.path.join('checkpoints', f'G_{epoch}.pth'))
  torch.save(discriminator.state_dict(), os.path.join('checkpoints', f'D_{epoch}.pth'))
# As you will see later, this basic GAN architecture has problems of its own...