In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable


import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets

# ---- Training configuration ----
num_eps = 10         # number of training epochs
bsize = 32           # mini-batch size
lrate = 0.001        # Adam learning rate (used for both G and D)
lat_dimension = 64   # length of the generator's latent input vector z
image_sz = 64        # height/width of generated (and resized real) images
chnls = 1            # image channels (1 for grayscale MNIST)
logging_intv = 200   # print losses and save samples every N batches

In [2]:
class GANGenerator(nn.Module):
    """DCGAN-style generator: latent vector z -> image in (-1, 1).

    A linear layer projects z to a 128 x (S/4) x (S/4) feature map. Two
    stages of nearest-neighbour 2x upsampling + 3x3 conv grow it back to
    S x S. A final conv maps to the requested channel count and Tanh bounds
    the output to (-1, 1).

    Args:
        lat_dim: length of the latent vector z. Defaults to the notebook's
            ``lat_dimension``.
        img_size: height/width S of the generated image; must be divisible
            by 4. Defaults to the notebook's ``image_sz``.
        channels: number of output image channels. Defaults to the
            notebook's ``chnls``.

    Raises:
        ValueError: if ``img_size`` is not divisible by 4.
    """

    def __init__(self, lat_dim=None, img_size=None, channels=None):
        super(GANGenerator, self).__init__()
        # Fall back to the notebook-level config so `GANGenerator()` keeps working.
        lat_dim = lat_dimension if lat_dim is None else lat_dim
        img_size = image_sz if img_size is None else img_size
        channels = chnls if channels is None else channels
        if img_size % 4 != 0:
            # Two 2x upsamplings restore exactly 4 * (img_size // 4) pixels;
            # a remainder would silently yield a smaller image.
            raise ValueError(f"img_size must be divisible by 4, got {img_size}")

        # Spatial size of the feature map right after the linear projection.
        self.inp_sz = img_size // 4
        self.lin = nn.Linear(lat_dim, 128 * self.inp_sz ** 2)  # z -> 128 * inp_sz * inp_sz
        self.bn1 = nn.BatchNorm2d(128)
        self.up1 = nn.Upsample(scale_factor=2)  # inp_sz -> 2 * inp_sz
        self.cn1 = nn.Conv2d(128, 128, 3, stride=1, padding=1)  # size-preserving, keeps 128 channels
        # NOTE(review): the positional 0.8 is BatchNorm2d's `eps`, not
        # `momentum`; kept unchanged to preserve training behaviour — confirm intent.
        self.bn2 = nn.BatchNorm2d(128, 0.8)
        self.rl1 = nn.LeakyReLU(0.2, inplace=True)  # slope 0.2 on negative inputs
        self.up2 = nn.Upsample(scale_factor=2)  # 2 * inp_sz -> img_size
        self.cn2 = nn.Conv2d(128, 64, 3, stride=1, padding=1)  # 128 -> 64 channels
        self.bn3 = nn.BatchNorm2d(64, 0.8)  # same positional-eps caveat as bn2
        self.rl2 = nn.LeakyReLU(0.2, inplace=True)
        self.cn3 = nn.Conv2d(64, channels, 3, stride=1, padding=1)  # 64 -> output channels
        self.act = nn.Tanh()  # bounds outputs to (-1, 1)

    def forward(self, x):
        """Map latent vectors ``x`` of shape (batch, lat_dim) to images (batch, channels, S, S)."""
        x = self.lin(x)
        # Flat projection -> 4-D feature map (batch, 128, inp_sz, inp_sz).
        x = x.view(x.shape[0], 128, self.inp_sz, self.inp_sz)
        x = self.bn1(x)
        x = self.rl1(self.bn2(self.cn1(self.up1(x))))
        x = self.rl2(self.bn3(self.cn2(self.up2(x))))
        return self.act(self.cn3(x))

In [3]:
class GANDiscriminator(nn.Module):
    """DCGAN-style discriminator: image -> probability that it is real.

    Four stride-2 conv blocks (3x3 kernel, padding 1) each map the spatial
    size n to ceil(n / 2) while growing channels 16 -> 32 -> 64 -> 128. The
    flattened features go through a linear layer and a sigmoid.

    Args:
        channels: input image channels. Defaults to the notebook's ``chnls``.
        img_size: input image height/width. Defaults to the notebook's
            ``image_sz``.
    """

    def __init__(self, channels=None, img_size=None):
        super(GANDiscriminator, self).__init__()
        # Fall back to the notebook-level config so `GANDiscriminator()` keeps working.
        channels = chnls if channels is None else channels
        img_size = image_sz if img_size is None else img_size

        def disc_module(ip_chnls, op_chnls, bnorm=True):
            """One block: stride-2 3x3 conv -> LeakyReLU -> Dropout2d [-> BatchNorm2d]."""
            mod = [nn.Conv2d(ip_chnls, op_chnls, 3, 2, 1),  # kernel 3, stride 2, padding 1
                   nn.LeakyReLU(0.2, inplace=True),
                   nn.Dropout2d(0.25)]  # regularisation against overfitting
            if bnorm:
                # NOTE(review): positional 0.8 is BatchNorm2d's `eps`, not
                # `momentum`; kept unchanged — confirm intent.
                mod += [nn.BatchNorm2d(op_chnls, 0.8)]
            return mod

        # Four blocks applied in sequence: channels grow, resolution shrinks.
        self.disc_model = nn.Sequential(
            *disc_module(channels, 16, bnorm=False),
            *disc_module(16, 32),
            *disc_module(32, 64),
            *disc_module(64, 128),
        )

        # Spatial size after the four blocks. Each conv maps n -> ceil(n / 2),
        # so compute it step by step; `img_size // 16` is only correct when
        # img_size is a multiple of 16 (e.g. 28 -> 14 -> 7 -> 4 -> 2, not 1).
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adverse_lyr = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1),
                                         nn.Sigmoid())

    def forward(self, x):
        """Return a (batch, 1) tensor of real-image probabilities for images ``x``."""
        feats = self.disc_model(x)  # conv blocks
        flat = feats.view(feats.shape[0], -1)  # flatten per sample
        return self.adverse_lyr(flat)  # linear + sigmoid

In [4]:
# Instantiate the two adversaries from the notebook configuration.
gen = GANGenerator()        # random latent vector -> fake image
disc = GANDiscriminator()   # image -> probability that it is real

# Both players are trained with binary cross-entropy on the
# discriminator's sigmoid output.
adv_loss_func = nn.BCELoss()

In [5]:
# MNIST images resized to image_sz x image_sz and scaled to [-1, 1], the
# same range as the generator's Tanh output.
mnist_transform = transforms.Compose([
    transforms.Resize((image_sz, image_sz)),
    transforms.ToTensor(),               # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 -> [-1, 1]
])
mnist_train = datasets.MNIST("./data/mnist/", download=True, transform=mnist_transform)
dloader = DataLoader(mnist_train, batch_size=bsize, shuffle=True)

# One Adam optimizer per network, both using the configured learning rate.
opt_gen = torch.optim.Adam(gen.parameters(), lr=lrate)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lrate)

In [6]:
os.makedirs("./images_mnist", exist_ok=True)

# epoch 10회 반복
for ep in range(num_eps):
    # batch 반복
    for idx, (images, _) in enumerate(dloader):

        # truth/fake Ground Truth Label 생성
        good_img = Variable(torch.FloatTensor(images.shape[0], 1).fill_(1.0), requires_grad=False)
        bad_img = Variable(torch.FloatTensor(images.shape[0], 1).fill_(0.0), requires_grad=False)

        # get a real image
        actual_images = Variable(images.type(torch.FloatTensor))

        # train the generator model
        opt_gen.zero_grad()

        # generate a batch of images based on random noise as input
        noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (images.shape[0], lat_dimension))))
        gen_images = gen(noise)

        # generator model optimization - how well can it fool the discriminator
        generator_loss = adv_loss_func(disc(gen_images), good_img)
        generator_loss.backward()
        opt_gen.step()

        # train the discriminator model
        opt_disc.zero_grad()

        # calculate discriminator loss as average of mistakes(losses) in confusing real images as fake and vice versa
        actual_image_loss = adv_loss_func(disc(actual_images), good_img)
        fake_image_loss = adv_loss_func(disc(gen_images.detach()), bad_img)
        discriminator_loss = (actual_image_loss + fake_image_loss) / 2

        # discriminator model optimization
        discriminator_loss.backward()
        opt_disc.step()

        batches_completed = ep * len(dloader) + idx
        if batches_completed % logging_intv == 0:
            print(f"epoch number {ep} | batch number {idx} | generator loss = {generator_loss.item()} | discriminator loss = {discriminator_loss.item()}")
            save_image(gen_images.data[:25], f"images_mnist/{batches_completed}.png", nrow=5, normalize=True)

epoch number 0 | batch number 0 | generator loss = 0.6853175163269043 | discriminator loss = 0.6935094594955444
epoch number 0 | batch number 200 | generator loss = 0.9600651860237122 | discriminator loss = 0.596145510673523
epoch number 0 | batch number 400 | generator loss = 0.9147008657455444 | discriminator loss = 0.5393365621566772
epoch number 0 | batch number 600 | generator loss = 1.2645645141601562 | discriminator loss = 0.4504905343055725
epoch number 0 | batch number 800 | generator loss = 1.3696907758712769 | discriminator loss = 0.8016576170921326
epoch number 0 | batch number 1000 | generator loss = 0.6775225400924683 | discriminator loss = 0.4123033881187439
epoch number 0 | batch number 1200 | generator loss = 3.7660908699035645 | discriminator loss = 0.13774575293064117
epoch number 0 | batch number 1400 | generator loss = 3.423884153366089 | discriminator loss = 0.18035462498664856
epoch number 0 | batch number 1600 | generator loss = 3.87838077545166 | discriminator 

In [7]:
# U-Net generator used for image-to-image translation.
class UNetGenerator(nn.Module):
    """U-Net generator (pix2pix style).

    Eight stride-2 encoder blocks compress the input. Seven decoder blocks
    expand it again, and each decoder concatenates the matching encoder
    activation (skip connection). A final upsample + zero-pad + conv maps
    back to ``chnls_op`` channels, and Tanh bounds the output to (-1, 1).
    Dropout (p=0.5) is used in encoder stages 4-8 and decoder stages 1-4.

    Args:
        chnls_in: channels of the input image.
        chnls_op: channels of the generated image.

    NOTE(review): DownConvBlock / UpConvBlock are defined in later cells and
    must exist before this class is instantiated. Eight halvings plus exact
    skip concatenation suggest H and W must be multiples of 256 — TODO confirm.
    """

    def __init__(self, chnls_in=3, chnls_op=3):
        super(UNetGenerator, self).__init__()

        # Encoder stages: (in_channels, out_channels, norm, dropout),
        # registered in order as down_conv_layer_1 .. down_conv_layer_8.
        encoder_specs = [
            (chnls_in, 64, False, 0.0),
            (64, 128, True, 0.0),
            (128, 256, True, 0.0),
            (256, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, False, 0.5),
        ]
        for stage, (c_in, c_out, use_norm, p_drop) in enumerate(encoder_specs, start=1):
            setattr(self, f"down_conv_layer_{stage}",
                    DownConvBlock(c_in, c_out, norm=use_norm, dropout=p_drop))

        # Decoder stages: (in_channels, out_channels, dropout). From stage 2 on,
        # the input is the previous decoder output concatenated with an encoder
        # output, hence the doubled channel counts. Registered as
        # up_conv_layer_1 .. up_conv_layer_7.
        decoder_specs = [
            (512, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 256, 0.0),
            (512, 128, 0.0),
            (256, 64, 0.0),
        ]
        for stage, (c_in, c_out, p_drop) in enumerate(decoder_specs, start=1):
            setattr(self, f"up_conv_layer_{stage}",
                    UpConvBlock(c_in, c_out, dropout=p_drop))

        # Output head: 2x upsample, pad one row/column on top/left, then a
        # 4x4 conv (padding 1) that keeps the upsampled size.
        self.upsample_layer = nn.Upsample(scale_factor=2)
        self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.conv_layer_1 = nn.Conv2d(128, chnls_op, 4, padding=1)
        self.activation = nn.Tanh()  # output range (-1, 1)

    def forward(self, x):
        # Encoder pass: keep every intermediate activation for the skips.
        skips = []
        feat = x
        for stage in range(1, 9):
            feat = getattr(self, f"down_conv_layer_{stage}")(feat)
            skips.append(feat)

        # The deepest encoder output seeds the decoder; each decoder stage
        # then consumes the next-shallower encoder output as its skip input.
        feat = skips.pop()
        for stage in range(1, 8):
            feat = getattr(self, f"up_conv_layer_{stage}")(feat, skips.pop())

        feat = self.conv_layer_1(self.zero_pad(self.upsample_layer(feat)))
        return self.activation(feat)

In [8]:
# U-Net decoder block: transposed-conv upsampling -> InstanceNorm -> ReLU
# (-> optional Dropout), then concatenation with the matching encoder
# output along the channel axis (skip connection).
class UpConvBlock(nn.Module):
    """Decoder stage of the U-Net generator.

    Args:
        ip_sz: input channels.
        op_sz: output channels of the transposed convolution. The block's
            result has ``op_sz`` plus the encoder tensor's channels.
        dropout: dropout probability; 0 disables the dropout layer.
    """

    def __init__(self, ip_sz, op_sz, dropout=0.0):
        super(UpConvBlock, self).__init__()
        layers = [
            nn.ConvTranspose2d(ip_sz, op_sz, 4, 2, 1),  # kernel 4, stride 2, pad 1: doubles H and W
            nn.InstanceNorm2d(op_sz),
            nn.ReLU(),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        # nn.Sequential registers the layers as submodules. A plain Python
        # list would hide their parameters from .parameters()/optimizers/
        # state_dict and exclude them from .to() and .train()/.eval().
        self.layers = nn.Sequential(*layers)

    def forward(self, x, enc_ip):
        """Upsample ``x`` and concatenate the encoder activation ``enc_ip`` on dim 1.

        Args:
            x: output of the previous decoder stage.
            enc_ip: matching encoder output used as the skip connection.
        """
        x = self.layers(x)
        return torch.cat((x, enc_ip), 1)

In [9]:
# U-Net encoder block: Conv -> optional InstanceNorm -> LeakyReLU -> optional Dropout.
class DownConvBlock(nn.Module):
    """Encoder stage of the U-Net generator.

    Args:
        ip_sz: input channels.
        op_sz: output channels.
        norm: whether to apply InstanceNorm2d after the convolution.
        dropout: dropout probability; 0 disables the dropout layer.
    """

    def __init__(self, ip_sz, op_sz, norm=True, dropout=0.0):
        super(DownConvBlock, self).__init__()
        layers = [nn.Conv2d(ip_sz, op_sz, 4, 2, 1)]  # kernel 4, stride 2, pad 1: halves even H and W
        if norm:
            layers.append(nn.InstanceNorm2d(op_sz))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        # nn.Sequential registers the layers as submodules so their parameters
        # are trained, saved, moved by .to() and switched by .train()/.eval().
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [10]:
# PatchGAN discriminator for pix2pix: classifies an image pair patch-by-patch.
# Input: two images concatenated on the channel axis.
# Output: a map with one real/fake score per patch.
class Pix2PixDiscriminator(nn.Module):
    """PatchGAN discriminator.

    Four stride-2 conv blocks (kernel 4), then a one-pixel top/left zero pad
    and a final 4x4 conv to one channel. For H and W divisible by 16, the
    output has shape (batch, 1, H / 16, W / 16).

    Args:
        chnls_in: channels of each of the two input images.
    """

    def __init__(self, chnls_in=3):
        super(Pix2PixDiscriminator, self).__init__()

        def disc_conv_block(chnls_in, chnls_op, norm=1):
            """Stride-2 4x4 conv -> optional InstanceNorm -> LeakyReLU(0.2), as one Sequential."""
            layers = [nn.Conv2d(chnls_in, chnls_op, 4, stride=2, padding=1)]
            if norm:
                layers.append(nn.InstanceNorm2d(chnls_op))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            # Sequential makes the block callable and registers its parameters.
            return nn.Sequential(*layers)

        # The input is the two images concatenated on channels, hence chnls_in * 2.
        self.lyr1 = disc_conv_block(chnls_in * 2, 64, norm=0)
        self.lyr2 = disc_conv_block(64, 128)
        self.lyr3 = disc_conv_block(128, 256)
        self.lyr4 = disc_conv_block(256, 512)
        # The head is built once here so the final conv is a registered,
        # trainable module. A conv constructed inside forward() would get
        # fresh random weights on every call.
        self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.final_conv = nn.Conv2d(512, 1, 4, padding=1)

    def forward(self, real_image, translated_image):
        """Return the per-patch score map for the concatenated image pair."""
        ip = torch.cat((real_image, translated_image), 1)
        # Each block halves the spatial resolution while extracting features.
        op = self.lyr1(ip)
        op = self.lyr2(op)
        op = self.lyr3(op)
        op = self.lyr4(op)
        op = self.zero_pad(op)
        return self.final_conv(op)