In [None]:
import os
# Expose only GPU index 1 to this process. The variable only takes effect if it is
# set before CUDA is initialised, so this cell has to run before any CUDA use.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

# Acquire dataset

## Download dataset

In [1]:
import os
if not os.path.exists('miscs/m0811/logs'):
    os.makedirs('miscs/m0811/logs')
if not os.path.exists('miscs/m0811/datasets'):
    os.makedirs('miscs/m0811/datasets')


In [5]:
""" Download the Horse2zebra dataset"""
# NOTE(review): the recorded output below complains about ./download_cyclegan_dataset.sh,
# so it predates this miscs/m0811/ path -- confirm the script exists at this location.
!bash miscs/m0811/download_cyclegan_dataset.sh horse2zebra

/bin/bash: ./download_cyclegan_dataset.sh: No such file or directory


Collecting torchvision==0.1.8
  Downloading torchvision-0.1.8-py2.py3-none-any.whl (37 kB)
Collecting pillow
  Using cached Pillow-7.2.0-cp38-cp38-win_amd64.whl (2.1 MB)
Collecting future
  Downloading future-0.18.2.tar.gz (829 kB)
Building wheels for collected packages: future
  Building wheel for future (setup.py): started
  Building wheel for future (setup.py): finished with status 'done'
  Created wheel for future: filename=future-0.18.2-py3-none-any.whl size=491062 sha256=ebdf69f6048c9cf9d809a56d4170c662712b8a2121093db32763223b0b89b85a
  Stored in directory: c:\users\lsh\appdata\local\pip\cache\wheels\8e\70\28\3d6ccd6e315f65f245da085482a2e1c7d14b90b30f239e2cf4
Successfully built future
Installing collected packages: pillow, torchvision, future
Successfully installed future-0.18.2 pillow-7.2.0 torchvision-0.1.8


# Model definition & hyperparameters

In [6]:
import torch.nn as nn
import torch.nn.functional
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image
from torch.optim.lr_scheduler import StepLR
from torchsummary import summary
import numpy as np


In [7]:
# --- Image / architecture ---
img_size = 256   # input images are resized and center-cropped to img_size x img_size
channels = 3
ngf = 32         # generator channels after the first conv layer
ndf = 64         # discriminator channels after the first conv layer

# --- Optimisation ---
epochs = 15      # number of training epochs; ~200 gives better results, 15 keeps the run short.
                 # Change it to 200 to train longer and compare the results.
batch_size = 4
lambda_X, lambda_Y = 10, 10                       # cycle-consistency loss weights
lambda_identity_X, lambda_identity_Y = 0.5, 0.5   # identity-loss weights (multiplied by lambda_*)
lr = 0.0002                                       # learning rate
betas = (0.5, 0.999)                              # Adam betas

# --- Weight initialisation: N(mean_init, std_init) ---
mean_init, std_init = 0.0, 0.02

In [8]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print(f"Device is {device}.")

Device is cpu.


# CycleGAN Model

### Resblock


![Resblock](miscs/m0811/description/Resnet_block.png)


In [6]:
# Residual block used in the generator's transformation stage.
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv + InstanceNorm layers with an identity skip.

    Each conv is preceded by a 1-pixel reflection pad and uses stride 1 with no
    zero padding, so the channel count ``c`` and the spatial size are preserved
    and blocks can be stacked freely.

    Args:
        c: number of input (and output) channels.
    """

    def __init__(self, c):
        super(ResidualBlock, self).__init__()

        block = [nn.ReflectionPad2d(1),
                 nn.Conv2d(c, c, 3, 1, 0),
                 nn.InstanceNorm2d(c),
                 nn.ReLU(),
                 nn.ReflectionPad2d(1),
                 nn.Conv2d(c, c, 3, 1, 0),
                 nn.InstanceNorm2d(c)]

        self.block = nn.Sequential(*block)

    def forward(self, x):
        """Return ``x + block(x)``: the residual connection; shape is unchanged."""
        return x + self.block(x)
### Testing code ####
# Shape check: a residual block must preserve its input shape.
# Use a well-defined random input (torch.Tensor(...) is uninitialised memory)
# and skip autograd, which a shape check does not need.
R = ResidualBlock(3)
with torch.no_grad():
    test_tensor = torch.randn(1, 3, 64, 64)
    assert list(R(test_tensor).size()) == [1, 3, 64, 64]
print('test1 통과')

test1 통과


### Generator
![Generator](miscs/m0811/description/Generator.png)


In [7]:
# Weight-initialisation helper used by Generator.weight_init / Discriminator.weight_init.
def normal_init(m, mean, std):
    """Initialise a (transposed) 2-D conv layer in place: N(mean, std) weights, zero bias.

    Any other module type is left untouched, so the function can be applied to
    every sub-module of a network.

    Args:
        m: module to initialise.
        mean: mean of the normal distribution used for the weights.
        std: standard deviation of the normal distribution used for the weights.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        nn.init.normal_(m.weight, mean, std)
        # Layers constructed with bias=False have m.bias = None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
        
class Generator(nn.Module):
    """ResNet-style CycleGAN generator: encoder -> residual blocks -> decoder.

    Maps a 3-channel image to a 3-channel image of the same spatial size, with
    outputs in [-1, 1] (final Tanh). The two stride-2 convolutions shrink H and W
    by 4 in total and the two stride-2 transposed convolutions restore them, so
    the output size equals the input size when H and W are divisible by 4.

    Args:
        n_res_blocks: number of ResidualBlock layers in the transformation
            stage (default 6).
    """

    def __init__(self, n_res_blocks=6):
        super(Generator, self).__init__()

        # Encoding: reflection pad 4 + 9x9 conv keeps H x W, then two stride-2
        # 4x4 convs each halve H x W and double the channels.
        model = []
        model += [nn.ReflectionPad2d(4),
                  nn.Conv2d(3, ngf, 9, 1, 0),
                  nn.InstanceNorm2d(ngf),
                  nn.ReLU()]
        model += [nn.Conv2d(ngf, ngf*2, 4, 2, 1),
                  nn.InstanceNorm2d(ngf*2),
                  nn.ReLU()]
        model += [nn.Conv2d(ngf*2, ngf*4, 4, 2, 1),
                  nn.InstanceNorm2d(ngf*4),
                  nn.ReLU()]

        # Transformation: residual blocks keep both the channel count and H x W.
        model += [ResidualBlock(ngf*4) for _ in range(n_res_blocks)]

        # Decoding: each transposed conv doubles H x W and halves the channels.
        # H_out = (H_in - 1)*stride - 2*padding + (kernel_size - 1) + output_padding + 1
        #       = 2*H_in for kernel_size=3, stride=2, padding=1, output_padding=1.
        model += [nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(ngf*2),
                  nn.ReLU()]

        model += [nn.ConvTranspose2d(ngf*2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(ngf),
                  nn.ReLU()]

        model += [nn.ReflectionPad2d(4),
                  nn.Conv2d(ngf, 3, 9, 1, 0),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def weight_init(self, mean, std):
        """Apply normal_init to every sub-module, recursively.

        self._modules only contains the top-level Sequential, so the
        conv layers must be reached through self.modules().
        """
        for m in self.modules():
            normal_init(m, mean, std)

    def forward(self, x):
        return self.model(x)
### Testing code
# Shape check: the generator must map 256x256 inputs back to 256x256.
# Defined random input instead of uninitialised memory; no autograd graph needed.
G = Generator()
with torch.no_grad():
    test_tensor = torch.randn(1, 3, 256, 256)
    assert list(G(test_tensor).size()) == [1, 3, 256, 256]
print('테스트2 통과')

테스트2 통과


### Discriminator
![Discriminator](miscs/m0811/description/Discriminator.png)


In [8]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of per-patch scores.

    Three stride-2 4x4 convs (3 -> ndf -> 2*ndf -> 4*ndf channels), then two
    stride-1 4x4 convs (-> 8*ndf -> 1 channel). For a 256x256 input the output
    is 1x30x30.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # First layer: 3 -> ndf channels, halves H x W, no normalisation.
        model = []
        model += [nn.Conv2d(3, ndf, 4, 2, 1),
                  nn.LeakyReLU(0.2)]

        in_channels = ndf
        out_channels = ndf*2

        # Two more downsampling stages; the channel count doubles each time.
        for _ in range(2):
            model += [nn.Conv2d(in_channels, out_channels, 4, 2, 1),
                      nn.InstanceNorm2d(out_channels),
                      nn.LeakyReLU(0.2)]
            in_channels = out_channels
            out_channels = out_channels * 2

        # Stride-1 layers: 4*ndf -> 8*ndf channels, then a single score channel.
        model += [nn.Conv2d(in_channels, out_channels, 4, 1, 1),
                  nn.InstanceNorm2d(out_channels),
                  nn.LeakyReLU(0.2)]

        model += [nn.Conv2d(out_channels, 1, 4, 1, 1)]

        self.model = nn.Sequential(*model)

    def weight_init(self, mean, std):
        """Apply normal_init to every sub-module, recursively.

        self._modules only contains the top-level Sequential, so the
        conv layers must be reached through self.modules().
        """
        for m in self.modules():
            normal_init(m, mean, std)

    def forward(self, x):
        return self.model(x)
### Testing code
# Shape check: a 256x256 input must give a 1x30x30 score map.
# Defined random input instead of uninitialised memory; no autograd graph needed.
D = Discriminator()
with torch.no_grad():
    test_tensor = torch.randn(1, 3, 256, 256)
    assert list(D(test_tensor).size()) == [1, 1, 30, 30]
print('테스트3 통과')


테스트3 통과


# Data Load

In [23]:
# Dataset Code

import os
from PIL import Image
import random

class UnallignedDataset(Dataset):
    """Unpaired two-domain image dataset.

    Reads every file in ``<root>/<phase>A`` and ``<root>/<phase>B``. Item
    ``index`` pairs the ``index``-th A image (wrapping around) with a uniformly
    random B image, so the two domains need not be aligned or equally sized.

    Args:
        root: dataset root directory.
        transform: callable applied to each RGB PIL image.
        phase: directory prefix, e.g. 'train' or 'test'.

    Raises:
        ValueError: if either domain directory contains no files.
    """

    def __init__(self, root, transform, phase='train'):
        dir_A = os.path.join(root, phase + 'A')
        dir_B = os.path.join(root, phase + 'B')

        self.A_paths = self._list_files(dir_A)
        self.B_paths = self._list_files(dir_B)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        # Fail early instead of with ZeroDivisionError / randint errors in __getitem__.
        if self.A_size == 0 or self.B_size == 0:
            raise ValueError('No image files found in %s or %s' % (dir_A, dir_B))

        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Return the sorted full paths of the regular files in ``directory``.

        os.listdir order is arbitrary, so sorting makes the A ordering reproducible;
        sub-directories are skipped because they cannot be opened as images.
        """
        return sorted(os.path.join(directory, f) for f in os.listdir(directory)
                      if os.path.isfile(os.path.join(directory, f)))

    @staticmethod
    def _load_rgb(path):
        """Open an image file, return an RGB copy and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[random.randint(0, self.B_size - 1)]

        A = self.transform(self._load_rgb(A_path))
        B = self.transform(self._load_rgb(B_path))
        return A, B

    def __len__(self):
        return max(self.A_size, self.B_size)

In [24]:
# Extra training trick: a history buffer of generated images for the discriminators
# (not required for the assignment; see
# https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/75).

class ImagePool():
    """Buffer of previously generated tensors used when updating a discriminator.

    Until ``pool_size`` tensors are stored, each incoming tensor is stored and
    returned unchanged. After that, with probability 0.5 a randomly chosen stored
    tensor is returned and replaced by the incoming one; otherwise the incoming
    tensor is returned. Each call handles whatever tensor it is given (in this
    notebook a whole batch), not individual images.

    Args:
        pool_size: maximum number of stored tensors; 0 (or less) disables the pool.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def get(self, img):
        """Return ``img`` or a previously stored tensor, as described on the class."""
        if self.pool_size <= 0:
            return img
        # Store a detached tensor so the buffer does not keep the generator's
        # autograd graph (and all of its activations) alive.
        stored = img.detach()
        if len(self.images) < self.pool_size:
            self.images.append(stored)
            return img
        if random.random() > 0.5:
            idx = random.randint(0, self.pool_size - 1)
            tmp = self.images[idx]
            self.images[idx] = stored
            return tmp
        return img

# Training

In [26]:
# Two generators (G: X -> Y, F: Y -> X) and one discriminator per domain.
G = Generator().to(device)
F = Generator().to(device)
D_X = Discriminator().to(device)
D_Y = Discriminator().to(device)
for net in (G, F, D_X, D_Y):
    net.weight_init(mean_init, std_init)
    net.train()

transform = transforms.Compose([transforms.Resize(img_size), transforms.CenterCrop(img_size), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
# Pinned host memory only helps when batches are copied to a CUDA device.
use_pin_memory = device.type == 'cuda'
train_loader = torch.utils.data.DataLoader(dataset=UnallignedDataset('miscs/m0811/datasets/horse2zebra', transform),
                                           batch_size=batch_size,
                                           shuffle=True,   # draw training batches in random order each epoch
                                           pin_memory=use_pin_memory,
                                           num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=UnallignedDataset('miscs/m0811/datasets/horse2zebra', transform, phase='test'),
                                          batch_size=batch_size,
                                          shuffle=False,  # fixed order: the preview uses the first test batches
                                          pin_memory=use_pin_memory,
                                          num_workers=2)

X_pool = ImagePool(50)
Y_pool = ImagePool(50)

mse_criterion = nn.MSELoss()
l1_criterion = nn.L1Loss()

GF_optimizer = torch.optim.Adam(list(G.parameters()) + list(F.parameters()), lr=lr, betas=betas)
D_X_optimizer = torch.optim.Adam(D_X.parameters(), lr=lr, betas=betas)
D_Y_optimizer = torch.optim.Adam(D_Y.parameters(), lr=lr, betas=betas)

# Linear learning-rate decay. The training loop calls scheduler.step() once per
# epoch after epoch 99; the k-th call scales the base lr by 1 - k / (n_decay_epochs + 1),
# so the lr shrinks by a constant amount per epoch towards zero. (StepLR with
# gamma=lr/100 would instead multiply the lr by 2e-6 on the first step.)
n_decay_epochs = 100

def lr_lambda(step):
    """Multiplicative lr factor after ``step`` scheduler steps (linear decay)."""
    return max(0.0, 1.0 - step / float(n_decay_epochs + 1))

GF_scheduler = torch.optim.lr_scheduler.LambdaLR(GF_optimizer, lr_lambda)
D_X_scheduler = torch.optim.lr_scheduler.LambdaLR(D_X_optimizer, lr_lambda)
D_Y_scheduler = torch.optim.lr_scheduler.LambdaLR(D_Y_optimizer, lr_lambda)

In [None]:
# Layer-by-layer summaries. torchsummary.summary defaults to device="cuda", which
# fails when the models are on the CPU, so pass the active device type explicitly.
summary(G, (channels, img_size, img_size), device=device.type)
summary(D_X, (channels, img_size, img_size), device=device.type)

In [None]:
def mean(lst):
    """Return the arithmetic mean of ``lst`` (ZeroDivisionError when it is empty)."""
    total = sum(lst)
    return total / len(lst)

import itertools

# Fixed preview data: the first 5 test batches. islice stops the loader after
# 5 batches instead of iterating (and discarding) the whole test set.
test_data = [(x.to(device), y.to(device)) for x, y in itertools.islice(test_loader, 5)]

# Targets for the least-squares GAN losses.
fake_target = 0.0
real_target = 1.0


def _denorm(t):
    """Map tensors from the [-1, 1] training range (Normalize(0.5, 0.5) / Tanh) to [0, 1]."""
    return (t + 1.0) * 0.5


for epoch in range(epochs):
    G_gan_loss_epoch = []
    G_cycle_loss_epoch = []
    G_ident_loss_epoch = []
    D_X_gan_loss_epoch = []
    D_Y_gan_loss_epoch = []

    # Linear lr decay during the epochs after the first 100.
    if epoch > 99:
        GF_scheduler.step()
        D_X_scheduler.step()
        D_Y_scheduler.step()

    for i, (X, Y) in enumerate(train_loader):
        X = X.to(device)
        Y = Y.to(device)
        #########################################################
        # Update generators
        #########################################################
        GF_optimizer.zero_grad()

        # Translate from X to Y; G tries to make D_Y output "real".
        G_out = G(X)
        D_Y_out = D_Y(G_out)
        G_gan_loss = mse_criterion(D_Y_out, torch.full_like(D_Y_out, real_target))

        # Translate from Y to X; F tries to make D_X output "real".
        F_out = F(Y)
        D_X_out = D_X(F_out)
        F_gan_loss = mse_criterion(D_X_out, torch.full_like(D_X_out, real_target))

        # X -> Y -> X reconstruction error.
        X_recon = F(G_out)
        G_cycle_loss = l1_criterion(X_recon, X) * lambda_X

        # Y -> X -> Y reconstruction error.
        Y_recon = G(F_out)
        F_cycle_loss = l1_criterion(Y_recon, Y) * lambda_Y

        # Identity on domain Y: G applied to a real Y should return Y.
        Y_ident = G(Y)
        G_ident_loss = l1_criterion(Y_ident, Y) * lambda_identity_Y * lambda_Y

        # Identity on domain X: F applied to a real X should return X.
        X_ident = F(X)
        F_ident_loss = l1_criterion(X_ident, X) * lambda_identity_X * lambda_X

        GF_loss = (G_cycle_loss + F_cycle_loss
                   + G_ident_loss + F_ident_loss
                   + G_gan_loss + F_gan_loss)
        GF_loss.backward()
        GF_optimizer.step()

        #########################################################
        # Update discriminators
        # D_Y, minimize L_D_Y = E_y (D_Y(y) - 1)^2 + E_x (D_Y(G(x)))^2
        #########################################################
        # zero_grad also discards the gradients GF_loss.backward() left in D_Y.
        D_Y_optimizer.zero_grad()

        # Fake input comes from the history pool; detach so neither the pool nor
        # this loss holds on to the generator graph.
        D_Y_out_fake = D_Y(Y_pool.get(G_out.detach()))
        D_Y_out_real = D_Y(Y)
        D_Y_loss_fake = mse_criterion(D_Y_out_fake, torch.full_like(D_Y_out_fake, fake_target))
        D_Y_loss_real = mse_criterion(D_Y_out_real, torch.full_like(D_Y_out_real, real_target))
        D_Y_gan_loss = (D_Y_loss_real + D_Y_loss_fake) * 0.5

        D_Y_gan_loss.backward()
        D_Y_optimizer.step()

        #########################################################
        # D_X, minimize L_D_X = E_x (D_X(x) - 1)^2 + E_y (D_X(F(y)))^2
        #########################################################
        D_X_optimizer.zero_grad()

        D_X_out_fake = D_X(X_pool.get(F_out.detach()))
        D_X_out_real = D_X(X)
        D_X_loss_fake = mse_criterion(D_X_out_fake, torch.full_like(D_X_out_fake, fake_target))
        D_X_loss_real = mse_criterion(D_X_out_real, torch.full_like(D_X_out_real, real_target))
        D_X_gan_loss = (D_X_loss_real + D_X_loss_fake) * 0.5

        D_X_gan_loss.backward()
        D_X_optimizer.step()

        # Save losses
        G_gan_loss_epoch.append(G_gan_loss.item())
        G_cycle_loss_epoch.append(G_cycle_loss.item())
        G_ident_loss_epoch.append(G_ident_loss.item())
        D_X_gan_loss_epoch.append(D_X_gan_loss.item())
        D_Y_gan_loss_epoch.append(D_Y_gan_loss.item())

        # Save a preview grid every 100 batches.
        if i % 100 == 0:
            checkname = 'Epoch [%d/%d], Batch [%d/%d]' % (epoch+1, epochs, i, len(train_loader))
            savename = 'miscs/m0811/logs/Epoch%dBatch%d' % (epoch+1, i)
            print(checkname)

            # Rows of the grid: X batch, G(X), Y batch, F(Y) for each preview pair.
            G.eval()
            F.eval()
            with torch.no_grad():
                rows = []
                for X_test, Y_test in test_data:
                    rows += [X_test, G(X_test), Y_test, F(Y_test)]
            G.train()
            F.train()
            if rows:
                # save_image expects values in [0, 1]; the tensors are in [-1, 1].
                save_image(_denorm(torch.cat(rows, 0)), savename + '.png', nrow=4, padding=50)

    # Per-epoch mean losses
    G_gan_loss_epoch = mean(G_gan_loss_epoch)
    G_cycle_loss_epoch = mean(G_cycle_loss_epoch)
    G_ident_loss_epoch = mean(G_ident_loss_epoch)
    G_loss_epoch = G_gan_loss_epoch + G_cycle_loss_epoch + G_ident_loss_epoch
    D_X_gan_loss_epoch = mean(D_X_gan_loss_epoch)
    D_Y_gan_loss_epoch = mean(D_Y_gan_loss_epoch)
    print('Epoch [%d/%d] G loss %.4f (gan %.4f, cycle %.4f, ident %.4f), D_X loss %.4f, D_Y loss %.4f'
          % (epoch+1, epochs, G_loss_epoch, G_gan_loss_epoch, G_cycle_loss_epoch,
             G_ident_loss_epoch, D_X_gan_loss_epoch, D_Y_gan_loss_epoch))

print('학습 완료')


In [None]:
# Persist the trained parameters (state_dict) of all four networks.
checkpoints = (
    (G, 'miscs/m0811/G.pt'),
    (F, 'miscs/m0811/F.pt'),
    (D_X, 'miscs/m0811/D_X.pt'),
    (D_Y, 'miscs/m0811/D_Y.pt'),
)
for net, path in checkpoints:
    torch.save(net.state_dict(), path)