In [None]:
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable

In [None]:
# Folder for model parameter checkpoints (state_dicts).
# exist_ok=True makes this idempotent without a separate exists() check.
os.makedirs('checkpoints', exist_ok=True)


In [None]:
# Folder for generated sample images written during training.
# exist_ok=True makes this idempotent without a separate exists() check.
os.makedirs('outputs', exist_ok=True)

In [None]:
# Discriminator ("critic"): it only has to judge how real an image looks,
# so it outputs a single scalar per image.
class Discriminator(nn.Module):
  """Convolutional critic mapping images to one unbounded score each.

  For a 64x64 input, the four stride-2 convolutions reduce the spatial size
  64 -> 32 -> 16 -> 8 -> 4, and the final 4x4 convolution yields a 1x1 map.

  Args:
    in_dim: number of input image channels.
    dim: channel width of the first convolution (doubled at each later stage).
  """

  def __init__(self, in_dim=3, dim=64):
    super(Discriminator, self).__init__()

    def conv_ln_lrelu(c_in, c_out):
      # Downsampling stage: strided conv -> normalisation -> LeakyReLU(0.2).
      return nn.Sequential(
          nn.Conv2d(c_in, c_out, 5, 2, 2),
          # LayerNorm was reportedly not effective here, so InstanceNorm2d
          # (with a learnable affine transform) is used in its place.
          nn.InstanceNorm2d(c_out, affine=True),
          nn.LeakyReLU(0.2),
      )

    self.ls = nn.Sequential(
        nn.Conv2d(in_dim, dim, 5, 2, 2),
        nn.LeakyReLU(0.2),  # negative_slope=0.2 (LeakyReLU(0, 2) would be slope 0, inplace=2)
        conv_ln_lrelu(dim, dim * 2),
        conv_ln_lrelu(dim * 2, dim * 4),
        conv_ln_lrelu(dim * 4, dim * 8),
        nn.Conv2d(dim * 8, 1, 4),
    )

  def forward(self, x):
    """Return the critic scores flattened with view(-1) (shape (N,) for 64x64 inputs).

    No sigmoid is applied: instead of classifying real/fake as 0/1, the critic
    returns an unbounded score for use in a Wasserstein-style loss.
    """
    y = self.ls(x)
    return y.view(-1)


In [None]:
# Generator: maps a latent vector to a 3-channel image with values in [-1, 1].
class Generator(nn.Module):
  """DCGAN-style generator.

  A linear layer projects the latent code to a (dim*8, 4, 4) feature map, then
  four transposed convolutions (stride 2, output_padding 1) double the spatial
  size each time: 4 -> 8 -> 16 -> 32 -> 64. The final Tanh bounds the output
  to [-1, 1].

  Args:
    in_dim: size of the latent vector.
    dim: base channel width (the first feature map has dim*8 channels).
  """

  def __init__(self, in_dim=128, dim=64):
    super(Generator, self).__init__()

    def up_block(c_in, c_out):
      # Upsampling stage: transposed conv (bias-free, BN follows) -> BN -> ReLU.
      deconv = nn.ConvTranspose2d(c_in, c_out, 5, 2, padding=2, output_padding=1, bias=False)
      return nn.Sequential(deconv, nn.BatchNorm2d(c_out), nn.ReLU())

    projected = dim * 8 * 4 * 4
    self.l1 = nn.Sequential(
        nn.Linear(in_dim, projected, bias=False),
        nn.BatchNorm1d(projected),
        nn.ReLU(),
    )
    stages = [
        up_block(dim * 8, dim * 4),
        up_block(dim * 4, dim * 2),
        up_block(dim * 2, dim),
    ]
    self.l2_5 = nn.Sequential(
        *stages,
        nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
        nn.Tanh(),
    )

  def forward(self, x):
    """Map latent codes of shape (N, in_dim) to images of shape (N, 3, 64, 64)."""
    h = self.l1(x)
    return self.l2_5(h.view(h.size(0), -1, 4, 4))

In [None]:
def show(img, renorm=False, nrow=8, interpolation='bicubic'):
  """Display a batch of images as a single grid with matplotlib.

  Args:
    img: image tensor in the layout accepted by torchvision.utils.make_grid
      (presumably (B, C, H, W)). It may live on any device or require grad;
      it is detached and copied to the CPU first.
    renorm: if True, map values from [-1, 1] back to [0, 1] via x * 0.5 + 0.5.
    nrow: number of images per row in the grid.
    interpolation: passed through to plt.imshow.
  """
  # .numpy() below requires a CPU tensor that does not track gradients.
  img = img.detach().cpu()
  if renorm:
    img = img * 0.5 + 0.5
  img_grid = torchvision.utils.make_grid(img, nrow=nrow).numpy()
  plt.figure()
  # make_grid returns (C, H, W); matplotlib expects (H, W, C).
  plt.imshow(np.transpose(img_grid, (1, 2, 0)), interpolation=interpolation)
  plt.axis('off')
  plt.show()

In [None]:
# Sanity check: count the raw CelebA images on disk.
# CelebA's aligned-image folder is named 'img_align_celeba'.
# NOTE(review): confirm this matches the folder name in your local copy.
root = 'data_faces/img_align_celeba'
img_list = os.listdir(root)
print(len(img_list))

In [None]:
# Preprocessing: crop a crop_size x crop_size window from the centre of the image,
# resize it to re_size x re_size, and normalise each channel from [0, 1] to [-1, 1].
crop_size = 128
re_size = 64

# The offsets assume 218 x 178 (H x W) inputs, the size of CelebA aligned images.
offset_height = (218 - crop_size) // 2
offset_width = (178 - crop_size) // 2
# Applied to the (C, H, W) tensor produced by ToTensor.
crop = lambda x: x[:, offset_height:offset_height + crop_size, offset_width:offset_width + crop_size]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(crop),
    transforms.ToPILImage(),
    transforms.Resize(size=(re_size, re_size), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

batch_size = 64
# ImageFolder expects <root>/<sub_folder>/<images>; the sub-folder name becomes the
# class label.
celeba_data = datasets.ImageFolder('./data_faces', transform=transform)
celeba_loader = DataLoader(celeba_data, batch_size=batch_size, shuffle=True)

In [None]:
# Preview 16 preprocessed training images. The loader yields tensors normalised
# with mean=0.5 / std=0.5, i.e. in [-1, 1], so map them back to [0, 1] for display.
batch, _ = next(iter(celeba_loader))
show(batch[0:16], renorm=True, nrow=4)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# Function that (re-)initialises model weights.
def initialize_weights(model):
  """Re-initialise the weights of the relevant sub-modules of ``model`` in place.

  - Conv2d / ConvTranspose2d: weight ~ N(mean=0, std=0.02).
  - BatchNorm1d / BatchNorm2d with affine parameters: weight ~ N(mean=1, std=0.02)
    and bias = 0, so each normalisation layer starts near an identity scale.
    A scale centred on 0 would nearly zero out every normalised activation.

  Other modules (e.g. Linear, InstanceNorm2d) keep PyTorch's default initialisation.

  Args:
    model: any nn.Module; all sub-modules are visited via model.modules().
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, 0.0, 0.02)  # mean 0, standard deviation 0.02
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)) and m.weight is not None:
      nn.init.normal_(m.weight, 1.0, 0.02)
      nn.init.zeros_(m.bias)

In [None]:
# Build the generator and the discriminator on `device` (generator first),
# then apply the custom weight initialisation to each.
generator, discriminator = Generator().to(device), Discriminator().to(device)
initialize_weights(generator)
initialize_weights(discriminator)

# Loss function.
# NOTE(review): the WGAN-GP training loop in this notebook never references this
# BCELoss (the critic has no sigmoid); it only keeps the name defined.
adversarial_loss = nn.BCELoss()

# One Adam optimiser per network, with identical settings.
# NOTE(review): these are DCGAN-style values; the WGAN-GP paper uses
# lr=1e-4, betas=(0.0, 0.9) — consider switching.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)

In [None]:
# WGAN-GP training: the critic (discriminator) outputs unbounded scores, and its
# loss is D(fake) - D(real) plus a gradient penalty.
num_epochs = 50
latent_dim = 128

G_losses = []
D_losses = []

lambda_gp = 10
n_critic = 5  # number of critic (discriminator) updates per generator update
num_steps = 0  # number of generator updates performed


def compute_gradient_penalty(critic, real, fake):
  """Return mean((||d critic(x_hat) / d x_hat||_2 - 1)^2) over the batch.

  x_hat is a random per-sample interpolation between ``real`` and ``fake``.
  Both inputs are detached, so the penalty only backpropagates into the critic.
  The caller is responsible for scaling the result by lambda_gp.
  """
  alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
  x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
  pred_hat = critic(x_hat)
  gradients = torch.autograd.grad(
      outputs=pred_hat,
      inputs=x_hat,
      grad_outputs=torch.ones_like(pred_hat),
      create_graph=True,  # the penalty must itself be differentiable w.r.t. critic params
      retain_graph=True,
  )[0]
  # The epsilon is ADDED inside the sqrt so its gradient stays finite at norm 0.
  grad_norm = torch.sqrt(gradients.view(gradients.size(0), -1).pow(2).sum(dim=1) + 1e-12)
  return torch.mean((grad_norm - 1.0) ** 2)


def save_sample_grid(images, path, rows=5, cols=5):
  """Save up to rows*cols images from a (N, C, H, W) batch in [-1, 1] as one PNG grid."""
  imgs = (images.detach().cpu().permute(0, 2, 3, 1) / 2 + 0.5).clamp(0, 1).numpy()
  fig, axs = plt.subplots(rows, cols, squeeze=False)
  for idx, ax in enumerate(axs.flat):
    if idx < len(imgs):  # the last batch may hold fewer than rows*cols images
      ax.imshow(imgs[idx])
    ax.axis('off')
  fig.savefig(path)
  plt.close(fig)


# Training loop
for epoch in range(num_epochs):
  for i, (real_images, _) in enumerate(celeba_loader):

    real_images = real_images.to(device)
    cur_batch_size = real_images.size(0)  # the last batch can be smaller

    # ---- Critic (discriminator) updates ----
    for _ in range(n_critic):
      optimizer_D.zero_grad()

      # Critic scores for real images.
      D_real_loss = torch.mean(discriminator(real_images))

      # Critic scores for generated images. The generator is not being updated here,
      # so its graph is not needed.
      z = torch.randn(cur_batch_size, latent_dim, device=device)
      with torch.no_grad():
        fake_images = generator(z)
      D_fake_loss = torch.mean(discriminator(fake_images))

      gradient_penalty = lambda_gp * compute_gradient_penalty(discriminator, real_images, fake_images)

      # Total critic loss (Wasserstein estimate with gradient penalty).
      D_loss = D_fake_loss - D_real_loss + gradient_penalty
      D_loss.backward()
      optimizer_D.step()

    # ---- Generator update ----
    optimizer_G.zero_grad()

    # Generate fake images and push their critic scores up.
    z = torch.randn(cur_batch_size, latent_dim, device=device)
    fake_images = generator(z)
    G_loss = -torch.mean(discriminator(fake_images))

    G_loss.backward()
    optimizer_G.step()
    num_steps += 1

    # Record losses (D_loss is from the last critic step of this iteration).
    D_losses.append(D_loss.item())
    G_losses.append(G_loss.item())

    if i % 100 == 0:
      print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(celeba_loader)}]")
      print(f"D Loss: {D_loss.item()}")
      print(f"G Loss: {G_loss.item()}")
      save_sample_grid(
          fake_images,
          os.path.join("outputs", f"gan_images_epoch_{epoch}_batch_{i}.png"),
      )


# state_dicts are pickled tensors, not images: use a .pth extension.
torch.save(generator.state_dict(), os.path.join('checkpoints', f'G_{epoch}.pth'))
torch.save(discriminator.state_dict(), os.path.join('checkpoints', f'D_{epoch}.pth'))
