In [23]:
import os
import csv
import json
import easydict
import cv2
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
from torch.autograd import Variable

from PIL import Image

from tqdm import tqdm
from time import sleep

import matplotlib.pyplot as plt

In [24]:
# Image preprocessing pipeline:
#   ToTensor  -> float tensor with pixel values in [0, 1]
#   Normalize -> (pixel - 0.5) / 0.5, i.e. values in [-1, 1] for the single channel,
#                the same range as the generator's tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [25]:
# Quick-look MNIST training loader (batch of 4). Downloads the dataset into ./data
# if needed. A later cell rebuilds `train_loader` with parser['batch_size'].
train_loader = DataLoader(
    dataset=datasets.MNIST(root='data', train=True, transform=transform, download=True),
    shuffle=True,
    batch_size=4,
)

In [26]:
# Weight-initialization helper, applied per child module by Generator/Discriminator.weight_init.
def normal_init(m, mean, std):
  """Re-initialize a fully connected layer in place.

  If ``m`` is an ``nn.Linear``, its weight is drawn from a normal distribution
  N(mean, std**2) and its bias (when the layer has one) is set to zero. Any other
  module type is left untouched.

  Args:
    m: module to (possibly) initialize.
    mean: mean of the normal distribution used for the weights.
    std: standard deviation of the normal distribution used for the weights.
  """
  # Only fully connected layers are initialized.
  if isinstance(m, nn.Linear):
    # nn.init functions run under no_grad, unlike mutating `.data` directly.
    nn.init.normal_(m.weight, mean=mean, std=std)
    # nn.Linear(..., bias=False) stores bias=None; zeroing it unconditionally would raise.
    if m.bias is not None:
      nn.init.zeros_(m.bias)

In [22]:
class Generator(nn.Module):
  """Conditional generator: maps (noise z, one-hot label y) to a flattened image.

  z and y are each embedded to 256 dims, concatenated (512 dims) and passed through
  two hidden layers with BatchNorm, ReLU and dropout, then a tanh output layer.

  Args:
    latent_dim: size of the noise vector z (default 100).
    n_classes: size of the one-hot label vector y (default 10).
    img_dim: number of output values per image (default 784 = 28 * 28).
  """

  def __init__(self, latent_dim=100, n_classes=10, img_dim=784):
    super(Generator, self).__init__()

    # Embed the latent vector z (latent_dim dims) into 256 dimensions.
    self.fc1_1 = nn.Linear(latent_dim, 256)
    self.fc1_1_bn = nn.BatchNorm1d(256)

    # Embed the class label y (n_classes dims) into 256 dimensions.
    self.fc1_2 = nn.Linear(n_classes, 256)
    self.fc1_2_bn = nn.BatchNorm1d(256)

    # Concatenating the z and y embeddings gives 512 dimensions.
    self.fc2 = nn.Linear(512, 2048)
    self.fc2_bn = nn.BatchNorm1d(2048)

    self.fc3 = nn.Linear(2048, 1024)
    self.fc3_bn = nn.BatchNorm1d(1024)

    # One output value per pixel (784 for a 28 x 28 image).
    self.fc4 = nn.Linear(1024, img_dim)
    self.dropout = nn.Dropout(p=0.2)

  def weight_init(self, mean, std):
    """Apply normal_init(module, mean, std) to every direct child module."""
    for m in self._modules:
      normal_init(self._modules[m], mean, std)

  def forward(self, input, label):
    """Generate a batch of flattened images.

    Args:
      input: noise tensor, presumably (batch, latent_dim) float.
      label: one-hot condition tensor, presumably (batch, n_classes) float.

    Returns:
      Tensor of shape (batch, img_dim) with values in [-1, 1].
    """
    x = F.relu(self.fc1_1_bn(self.fc1_1(input)))
    y = F.relu(self.fc1_2_bn(self.fc1_2(label)))

    x = torch.cat([x, y], dim=1)

    x = F.relu(self.fc2_bn(self.fc2(x)))
    x = self.dropout(x)
    x = F.relu(self.fc3_bn(self.fc3(x)))
    x = self.dropout(x)
    # torch.tanh is the preferred spelling of F.tanh; squashes outputs into [-1, 1].
    return torch.tanh(self.fc4(x))

In [27]:
class Discriminator(nn.Module):
  """Conditional discriminator: scores (flattened image, one-hot label) pairs.

  The image and the label are each embedded to 1024 dims with leaky ReLU (slope 0.1),
  concatenated (2048 dims), passed through two hidden layers with BatchNorm, and
  mapped to a single sigmoid score.

  Args:
    img_dim: number of values per flattened image (default 784 = 28 * 28).
    n_classes: size of the one-hot label vector (default 10).
  """

  def __init__(self, img_dim=784, n_classes=10):
    super(Discriminator, self).__init__()

    # Embed the (real or generated) flattened image into 1024 dimensions.
    self.fc1_1 = nn.Linear(img_dim, 1024)
    # Embed the class label y into 1024 dimensions.
    self.fc1_2 = nn.Linear(n_classes, 1024)

    # Concatenating both embeddings gives 2048 dimensions.
    self.fc2 = nn.Linear(2048, 512)
    self.fc2_bn = nn.BatchNorm1d(512)

    self.fc3 = nn.Linear(512, 256)
    self.fc3_bn = nn.BatchNorm1d(256)

    self.fc4 = nn.Linear(256, 1)
    self.dropout = nn.Dropout(p=0.2)

  def weight_init(self, mean, std):
    """Apply normal_init(module, mean, std) to every direct child module."""
    for m in self._modules:
      normal_init(self._modules[m], mean, std)

  def forward(self, input, label):
    """Score a batch of (image, label) pairs.

    Args:
      input: flattened images, presumably (batch, img_dim) float.
      label: one-hot condition tensor, presumably (batch, n_classes) float.

    Returns:
      Tensor of shape (batch, 1) with values in [0, 1] (input to nn.BCELoss).
    """
    x = F.leaky_relu(self.fc1_1(input), 0.1)
    y = F.leaky_relu(self.fc1_2(label), 0.1)

    x = torch.cat([x, y], dim=1)
    x = F.leaky_relu(self.fc2_bn(self.fc2(x)), 0.1)
    x = self.dropout(x)
    x = F.leaky_relu(self.fc3_bn(self.fc3(x)), 0.1)
    # torch.sigmoid is the preferred spelling of F.sigmoid.
    return torch.sigmoid(self.fc4(x))

In [28]:
# Experiment configuration. Read both as attributes (parser.latent_dim) and as
# keys (parser['img_size']) by the cells below.
parser = easydict.EasyDict(
    n_epochs=50,         # number of passes over the training loader
    batch_size=64,       # mini-batch size
    lr=0.00001,          # learning rate
    b1=0.5,              # first beta coefficient
    b2=0.999,            # second beta coefficient
    n_cpu=8,             # worker count (not referenced by the visible cells)
    latent_dim=100,      # size of the noise vector z
    n_classes=10,        # number of conditioning classes (one-hot width)
    img_size=28,         # image height and width in pixels
    channels=1,          # image channels
    sample_interval=1,   # sample_image() runs every `sample_interval` epochs
)

In [29]:
# Whether a CUDA device is available (already a bool; no conditional needed).
# NOTE(review): neither `cuda` nor `img_shape` is referenced by the cells below.
cuda = torch.cuda.is_available()
# (channels, height, width) of one image, derived from the config.
img_shape = (parser.channels, parser.img_size, parser.img_size)

In [30]:
def sample_image(n_row, epoch):
  """Generate ``n_row`` images with random class labels and save them as one grid row.

  Uses the global ``generator`` and ``parser``. The image is written to
  ``"<epoch>.png"`` in the working directory.

  The generator's train/eval mode is restored afterwards, so calling this between
  training epochs does not leave BatchNorm and dropout in inference mode.

  Args:
    n_row: number of images to generate (also the number of images per grid row).
    epoch: epoch number, used as the output file name.
  """
  size = parser['img_size']
  # Create inputs on whatever device the generator lives on.
  device = next(generator.parameters()).device

  # The latent vector z acts as the noise input.
  z = torch.randn(n_row, parser.latent_dim, device=device)

  # One random class label per sample, one-hot encoded: shape (n_row, n_classes).
  gen_labels = F.one_hot(
      torch.randint(0, parser.n_classes, (n_row,), device=device),
      num_classes=parser.n_classes,
  ).float()

  # Sample in eval mode without recording an autograd graph, then restore the
  # previous mode.
  was_training = generator.training
  generator.eval()
  try:
    with torch.no_grad():
      gen_imgs = generator(z, gen_labels)
  finally:
    generator.train(was_training)

  save_image(gen_imgs.view(n_row, 1, size, size), "%d.png" % epoch, nrow=n_row, normalize=True)

In [31]:
# Loss, models, data loaders and optimizers.
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BCELoss expects probabilities; the Discriminator ends in a sigmoid.
cross_entropy = torch.nn.BCELoss().to(device)
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# NOTE(review): both models define weight_init(mean, std), but it is not called in the
# visible cells, so they keep PyTorch's default nn.Linear initialization.

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=parser['batch_size'], shuffle=True
)

# Evaluation data does not need to be shuffled.
val_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, download=True, transform=transform),
    batch_size=parser['batch_size'], shuffle=False
)

# Hyperparameters come from the config instead of being repeated as literals.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=parser.lr, betas=(parser.b1, parser.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=parser.lr, betas=(parser.b1, parser.b2))

In [32]:
# Adversarial training loop (conditional GAN).
# Each step first updates the generator so that the discriminator scores its images
# as real (target 1). It then updates the discriminator on real images (target 1) and
# detached generated images (target 0). All images are conditioned on one-hot labels.
device = next(generator.parameters()).device

g_loss = torch.zeros(1)
d_loss = torch.zeros(1)

for epoch in range(parser.n_epochs):
  # sample_image() calls generator.eval(); (re)enable training behaviour
  # (BatchNorm batch statistics, dropout) for both models every epoch.
  generator.train()
  discriminator.train()

  for batch_idx, (x, y) in enumerate(train_loader):
    # The last batch can be smaller than parser.batch_size, so size everything from the data.
    batch_size = x.size(0)

    # Flatten each image (e.g. 28 * 28) into a single vector (784).
    real_imgs = x.view(batch_size, -1).float().to(device)
    # One-hot encode the class labels.
    labels = F.one_hot(y, num_classes=parser['n_classes']).float().to(device)

    valid = torch.ones(batch_size, 1, device=device)
    fake = torch.zeros(batch_size, 1, device=device)

    # ---- Train generator ----
    optimizer_G.zero_grad()

    z = torch.randn(batch_size, parser.latent_dim, device=device)
    # Random target class for each generated sample: shape (batch_size, n_classes).
    gen_labels = F.one_hot(
        torch.randint(0, parser.n_classes, (batch_size,), device=device),
        num_classes=parser.n_classes,
    ).float()

    gen_imgs = generator(z, gen_labels)
    g_loss = cross_entropy(discriminator(gen_imgs, gen_labels), valid)

    g_loss.backward()
    optimizer_G.step()

    # ---- Train discriminator ----
    optimizer_D.zero_grad()

    d_real_loss = cross_entropy(discriminator(real_imgs, labels), valid)
    # detach() keeps the discriminator update from backpropagating into the generator.
    d_fake_loss = cross_entropy(discriminator(gen_imgs.detach(), gen_labels), fake)
    d_loss = (d_real_loss + d_fake_loss) / 2

    d_loss.backward()
    optimizer_D.step()

    if batch_idx % 100 == 0:
      print('{:<13s}{:<8s}{:<6s}{:<10s}{:<8s}{:<9.5f}{:<8s}{:<9.5f}'
            .format('Train Epoch: ', '[' + str(epoch) + '/' + str(parser['n_epochs']) + ']',
                    'Step: ', '[' + str(batch_idx) + '/' + str(len(train_loader)) + ']',
                    'G loss: ', g_loss.item(),
                    'D loss: ', d_loss.item()))

  if epoch % parser.sample_interval == 0:
    sample_image(n_row=10, epoch=epoch)



Train Epoch: [0/50]  Step: [0/938]   G loss: 0.70223  D loss: 0.70925  
Train Epoch: [0/50]  Step: [100/938] G loss: 0.70903  D loss: 0.69175  
Train Epoch: [0/50]  Step: [200/938] G loss: 0.71599  D loss: 0.68960  
Train Epoch: [0/50]  Step: [300/938] G loss: 0.72896  D loss: 0.66992  
Train Epoch: [0/50]  Step: [400/938] G loss: 0.74036  D loss: 0.66657  
Train Epoch: [0/50]  Step: [500/938] G loss: 0.75312  D loss: 0.64150  
Train Epoch: [0/50]  Step: [600/938] G loss: 0.76056  D loss: 0.63693  
Train Epoch: [0/50]  Step: [700/938] G loss: 0.78457  D loss: 0.62080  
Train Epoch: [0/50]  Step: [800/938] G loss: 0.79767  D loss: 0.60652  
Train Epoch: [0/50]  Step: [900/938] G loss: 0.81074  D loss: 0.60507  
Train Epoch: [1/50]  Step: [0/938]   G loss: 0.83261  D loss: 0.58429  
Train Epoch: [1/50]  Step: [100/938] G loss: 0.78922  D loss: 0.61770  
Train Epoch: [1/50]  Step: [200/938] G loss: 0.81139  D loss: 0.61687  
Train Epoch: [1/50]  Step: [300/938] G loss: 0.80839  D loss: 0.

😓 I think mode collapse might have occurred...

In [33]:
def show_image(condition: int):
  """Generate and display one image of class ``condition`` with matplotlib.

  Uses the global ``generator`` and ``parser``. The generator's train/eval mode is
  restored afterwards.

  Args:
    condition: class index in ``[0, parser.n_classes)`` to condition on.
  """
  size = parser.img_size
  device = next(generator.parameters()).device

  z = torch.randn(1, parser.latent_dim, device=device)
  # One-hot condition of shape (1, n_classes): fc1_2 expects n_classes input features
  # in the last dimension, so a (n_classes, 1) column vector would be rejected.
  condition_vector = torch.eye(parser.n_classes, device=device)[condition].unsqueeze(0)

  was_training = generator.training
  generator.eval()
  try:
    with torch.no_grad():
      gen_imgs = generator(z, condition_vector)
  finally:
    generator.train(was_training)

  plt.imshow(gen_imgs.view(1, 1, size, size)[0][0].cpu().numpy(), cmap='gray')