## 5.2 CycleGAN으로 모네 그림 그리기
- 데이터셋을 [링크](https://www.kaggle.com/datasets/balraj98/monet2photo)\[1]에서 다운로드한 뒤, 로컬에 `./data/monet2photo` 폴더를 만들어 그 안에 넣어주세요.

[1] Zhu et al., Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, ICCV 2017

### 5.2.1 데이터 살펴보기

In [None]:
import os
import matplotlib.pyplot as plt

# Take the first ten file names listed for each domain folder
# (trainA is read into `monet`, trainB into `photo`).
monet = os.listdir('./data/monet2photo/trainA/')[0:10]
photo = os.listdir('./data/monet2photo/trainB/')[0:10]

# Two rows of ten thumbnails: row 0 shows domain A, row 1 shows domain B.
plt.figure(figsize=(40,10))
names_by_style = {"A":monet, "B":photo}
for col, style in enumerate(['A','B']):
    file_names = names_by_style[style]
    for row in range(10):
        img_name = file_names[row]
        # Divide the loaded pixel values by 255 before displaying them.
        img = plt.imread(f'./data/monet2photo/train{style}/{img_name}')/255
        plt.subplot(2,10,col*10+row+1)
        plt.imshow(img)
        # Show the domain letter together with the array shape of the image.
        plt.title(f'{style} : {img.shape}')

### 5.2.2 생성자(ResNet)

In [None]:
from torchsummary import summary
import torch
import torch.nn as nn
from torchvision import transforms


class ResNet(nn.Module):
    """ResNet-style generator: downsampling convs, residual blocks, upsampling convs.

    Args:
        num_residuals: Number of residual blocks applied at the 128-channel
            bottleneck between the downsampling and upsampling stages.
    """
    def __init__(self,num_residuals=9):
        super().__init__()
        # The up path does not restore the input size exactly (a 256x256 input
        # comes back as 255x255), so the output is resized to 256x256.
        self.resize = transforms.Resize((256,256))
        # Residual branches (conv -> norm -> ReLU -> conv -> norm). They are
        # built before the other stages so parameter-initialisation order,
        # and therefore seeded results, stay the same.
        residual = [nn.Sequential(nn.Conv2d(128,128,3,1,1),
                                  nn.InstanceNorm2d(128),
                                  nn.ReLU(),
                                  nn.Conv2d(128,128,3,1,1),
                                  nn.InstanceNorm2d(128))
                    for _ in range(num_residuals)]
        # Downsampling: 3 -> 32 -> 64 -> 128 channels, two stride-2 convolutions.
        self.down = nn.Sequential(nn.Conv2d(3,32,7,1,1),
                                  nn.InstanceNorm2d(32),
                                  nn.ReLU(),
                                  nn.Conv2d(32,64,3,2,1),
                                  nn.InstanceNorm2d(64),
                                  nn.ReLU(),
                                  nn.Conv2d(64,128,3,2,2),
                                  nn.InstanceNorm2d(128),
                                  nn.ReLU())
        self.residual = nn.ModuleList(residual)
        # Upsampling: 128 -> 64 -> 32 -> 3 channels, ending in a sigmoid.
        self.up = nn.Sequential(nn.ConvTranspose2d(128,64,3,2,2),
                                nn.InstanceNorm2d(64),
                                nn.ReLU(),
                                nn.ConvTranspose2d(64,32,3,2),
                                nn.InstanceNorm2d(32),
                                nn.ReLU(),
                                nn.ConvTranspose2d(32,3,7,1,1),
                                nn.InstanceNorm2d(3),
                                nn.Sigmoid())

    def forward(self,x):
        """Translate a batch of 3-channel images.

        Args:
            x: Input batch with 3 channels in dimension 1.

        Returns:
            The sigmoid-activated output of the up path, resized to 256x256.
        """
        x = self.down(x)
        for layer in self.residual:
            # None of the layers in a branch is in-place, so x itself can act
            # as the skip connection; no clone of the activation is needed.
            x = layer(x) + x
        x = self.up(x)
        return self.resize(x)


if __name__ == '__main__':
    # Smoke test: run a random batch of 16 images through a fresh generator,
    # report the output shape, then print a per-layer summary on the CPU.
    x = torch.randn(16, 3, 256, 256)
    model = ResNet()
    y = model(x)
    print(f'output shape : {y.shape}')
    summary(model, input_size=(3,256,256), device='cpu')

### 5.2.3 데이터셋

In [None]:
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt

class Monet2Photo(Dataset):
    """Index-aligned view over the two image folders of the monet2photo dataset.

    Item ``idx`` is the ``idx``-th file of the '<split>A' folder (Monet)
    together with the ``idx``-th file of the '<split>B' folder (photo).

    Args:
        split: Either 'train' or 'test'; selects the '<split>A' and
            '<split>B' sub-folders.
        root: Directory that contains those sub-folders.
    """
    def __init__(self, split, root='./data/monet2photo'):
        assert split in ['train', 'test'], f"split must be 'train' or 'test', got {split!r}"
        self.root = root
        self.folder_name_monet = split + 'A'
        self.folder_name_photo = split + 'B'
        # os.listdir returns entries in arbitrary order; sort them so an index
        # always refers to the same file across runs and machines.
        self.data_list_monet = sorted(os.listdir(os.path.join(root, self.folder_name_monet)))
        self.data_list_photo = sorted(os.listdir(os.path.join(root, self.folder_name_photo)))

    def __len__(self):
        """Return the size of the smaller folder, so every index is valid for both."""
        return min(len(self.data_list_monet), len(self.data_list_photo))

    def _load(self, folder_name, file_name):
        """Read one image, divide it by 255 and return it channel-first as float32."""
        img = plt.imread(os.path.join(self.root, folder_name, file_name))/255
        # Move the trailing channel axis to the front: (..., C) -> (C, ...).
        img = np.einsum('...c->c...', img)
        return torch.from_numpy(img).type(torch.float32)

    def __getitem__(self, idx):
        """Return the ``(monet, photo)`` tensor pair stored at position ``idx``."""
        img_monet = self._load(self.folder_name_monet, self.data_list_monet[idx])
        img_photo = self._load(self.folder_name_photo, self.data_list_photo[idx])
        return img_monet, img_photo

if __name__ == '__main__':
    # Smoke test: fetch the first batch of 32 pairs and print both tensor shapes.
    dataset = Monet2Photo('train')
    dataloader = DataLoader(dataset, batch_size=32)
    for img_monet, img_photo in dataloader:
        print(img_monet.shape, img_photo.shape)
        break

    # Remove the temporaries created by this check.
    del dataset, dataloader, img_monet, img_photo

### 5.2.4 학습 (w.o. identity)

In [None]:
# Hyper-parameters
# Learning rates of the two generators and the two discriminators.
g_AB_lr = g_BA_lr = 2e-4
d_A_lr = d_B_lr = 2e-4

# Batch size: 16 is what fits on a single 16 GiB GPU; raise or lower it to suit your setup.
bs = 16
epochs = 200
# The generators are trained once for every g_d_ratio discriminator updates.
g_d_ratio = 5

identity = False

# CUDA device index; change it to match your machine.
gpu = 1

In [None]:
# Model declaration
import torch.optim as optim

# Two generators, one per translation direction, placed on the chosen GPU.
g_AB = ResNet().cuda(gpu)
g_BA = ResNet().cuda(gpu)


def _build_patch_discriminator():
    """Create one PatchDiscriminator on the chosen GPU, with fresh argument lists."""
    return PatchDiscriminator(in_channels=[3,16,32,16],
                              out_channels=[16,32,16,3],
                              kernel_size=[4,4,4,4],
                              stride=[2,2,2,2],
                              padding=[2,2,2,2]).cuda(gpu)


# One discriminator per domain; both use the same configuration.
d_A = _build_patch_discriminator()
d_B = _build_patch_discriminator()

# Optimizers: a separate Adam instance per network, each with its own learning rate.
g_AB_optimizer = optim.Adam(params=g_AB.parameters(), lr=g_AB_lr)
g_BA_optimizer = optim.Adam(params=g_BA.parameters(), lr=g_BA_lr)
d_A_optimizer = optim.Adam(params=d_A.parameters(), lr=d_A_lr)
d_B_optimizer = optim.Adam(params=d_B.parameters(), lr=d_B_lr)

# Loss functions for the generator and discriminator updates.
g_criterion = CycleGANGeneratorLoss(identity)
d_criterion = CycleGANDiscrimonatorLoss()

In [None]:
# Training order: discriminators first, then generators.
from tqdm import tqdm

train_dataset = Monet2Photo('train')
test_dataset = Monet2Photo('test')

# Both loaders use batch size `bs` and shuffle their samples.
train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=bs, shuffle=True)

# Speed-up: enable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Loss logs: one entry is appended per generator / discriminator update.
train_g_loss, train_d_loss = [], []


# Main training loop: the discriminators are updated on every batch, the
# generators on every g_d_ratio-th batch of an epoch.
for epoch in tqdm(range(epochs)):
    for i, (img_monet, img_photo) in enumerate(train_dataloader):
        # ---- Discriminator update ----
        img_monet, img_photo = img_monet.cuda(gpu), img_photo.cuda(gpu)

        # The generators are not updated in this step, so the fakes are built
        # without an autograd graph. d_loss.backward() then stops at the
        # discriminators instead of also back-propagating through both
        # generators (those gradients were zeroed before ever being used).
        with torch.no_grad():
            fake_B = g_AB(img_monet)
            fake_A = g_BA(img_photo)

        pred_real_B = d_B(img_photo)
        pred_fake_B = d_B(fake_B)
        pred_real_A = d_A(img_monet)
        pred_fake_A = d_A(fake_A)

        # Sum of the two per-domain discriminator losses.
        d_loss = d_criterion(pred_real_A, pred_fake_A) + d_criterion(pred_real_B, pred_fake_B)
        d_A_optimizer.zero_grad()
        d_B_optimizer.zero_grad()
        d_loss.backward()
        d_A_optimizer.step()
        d_B_optimizer.step()
        train_d_loss.append(d_loss.item())

        if (i+1)%g_d_ratio==0:
            # ---- Generator update ----
            # Release the tensors of the discriminator step before the larger
            # generator graph is built.
            del fake_B, fake_A, pred_real_B, pred_fake_B, pred_real_A, pred_fake_A
            torch.cuda.empty_cache()

            # NOTE(review): the inputs are copied and marked as requiring grad;
            # this looks unnecessary for updating the generator parameters,
            # but it is kept because g_criterion is defined elsewhere - confirm.
            img_monet = img_monet.detach().clone().cuda(gpu)
            img_photo = img_photo.detach().clone().cuda(gpu)
            img_monet.requires_grad, img_photo.requires_grad = True, True

            # Forward translations and their round trips.
            fake_B = g_AB(img_monet)
            fake_A = g_BA(img_photo)
            cycle_B = g_AB(fake_A) # should look like B
            cycle_A = g_BA(fake_B) # should look like A

            pred_real_B = d_B(img_photo)
            pred_fake_B = d_B(fake_B)
            pred_real_A = d_A(img_monet)
            pred_fake_A = d_A(fake_A)

            if identity:
                id_B = g_AB(img_photo) # a photo (B) should come out unchanged
                id_A = g_BA(img_monet) # a Monet (A) should come out unchanged
            else:
                id_B = None
                id_A = None

            g_loss = g_criterion(img_monet,
                                img_photo,
                                pred_real_A,
                                pred_real_B,
                                pred_fake_A,
                                pred_fake_B,
                                cycle_A,
                                cycle_B,
                                id_A,
                                id_B)
            g_AB_optimizer.zero_grad()
            g_BA_optimizer.zero_grad()
            g_loss.backward()
            g_AB_optimizer.step()
            g_BA_optimizer.step()
            train_g_loss.append(g_loss.item())

### 5.2.5 학습결과 (w.o. identity)
#### 5.2.5.1 Train loss

In [None]:
plt.figure(figsize=(20,10))

# The generator loss is logged only once per g_d_ratio iterations, so each
# logged value is repeated g_d_ratio times before plotting.
g_loss_per_iter = []
for item in train_g_loss:
    g_loss_per_iter.extend([item] * g_d_ratio)

plt.plot(g_loss_per_iter, label='Generator Loss')
plt.plot(train_d_loss, label='Discriminator Loss')
plt.xlabel('Iter')
plt.ylabel('Loss')
plt.legend()
plt.show()

#### 5.2.5.2 Transfer

In [None]:
# Draw one test batch and translate it in both directions.
monet, photo = next(iter(test_dataloader))
monet, photo = monet.cuda(gpu), photo.cuda(gpu)

with torch.no_grad():
    g_BA.eval()
    g_AB.eval()
    fake_photo = g_AB(monet)
    fake_monet = g_BA(photo)
    # Stack the four groups along the batch axis: originals and translations.
    imgs_ = torch.cat([monet, fake_photo, photo, fake_monet],dim=0)
    imgs = imgs_.detach().cpu().numpy()

# Lay out the grid with the real batch length rather than `bs`: a batch can be
# shorter than `bs`, and indexing with `bs` would then pick the wrong images
# or run past the end of the array.
n_cols = monet.shape[0]
row_names = ['Original monet', 'monet -> photo', 'Original photo', 'photo -> monet']

plt.figure(figsize=(60,20))
for row, name in enumerate(row_names):
    for col in range(n_cols):
        # Channel-first -> channel-last for imshow.
        img = np.einsum('c...->...c', imgs[row*n_cols+col])
        plt.subplot(4,n_cols,row*n_cols+col+1)
        plt.imshow(img)
        plt.title(name)

# fake_photo / fake_monet are kept: the reconstruction step reuses them.
del imgs_, imgs
torch.cuda.empty_cache()

#### 5.2.5.3 Reconstruction

In [None]:
# Map each translated image back to its source domain with the other generator.
with torch.no_grad():
    g_BA.eval()
    g_AB.eval()
    cycle_monet = g_BA(fake_photo)
    cycle_photo = g_AB(fake_monet)
    # Stack originals and their reconstructions along the batch axis.
    imgs = torch.cat([monet, cycle_monet, photo, cycle_photo],dim=0)
    imgs = imgs.detach().cpu().numpy()

# Lay out the grid with the real batch length rather than `bs`: a batch can be
# shorter than `bs`, and indexing with `bs` would then pick the wrong images
# or run past the end of the array.
n_cols = monet.shape[0]
row_names = ['monet', 'Recon monet', 'photo', 'Recon photo']

plt.figure(figsize=(60,20))
for row, name in enumerate(row_names):
    for col in range(n_cols):
        # Channel-first -> channel-last for imshow.
        img = np.einsum('c...->...c', imgs[row*n_cols+col])
        plt.subplot(4,n_cols,row*n_cols+col+1)
        plt.imshow(img)
        plt.title(name)

del cycle_monet, cycle_photo, imgs
torch.cuda.empty_cache()

#### 5.2.5.4 Identity

In [None]:
# Feed each generator an image that is already in its output domain.
with torch.no_grad():
    g_BA.eval()
    g_AB.eval()
    id_monet = g_BA(monet)
    id_photo = g_AB(photo)
    # Stack originals and their identity mappings along the batch axis.
    imgs = torch.cat([monet, id_monet, photo, id_photo],dim=0)
    imgs = imgs.detach().cpu().numpy()

# Lay out the grid with the real batch length rather than `bs`: a batch can be
# shorter than `bs`, and indexing with `bs` would then pick the wrong images
# or run past the end of the array.
n_cols = monet.shape[0]
row_names = ['monet', 'id monet', 'photo', 'id photo']

plt.figure(figsize=(60,20))
for row, name in enumerate(row_names):
    for col in range(n_cols):
        # Channel-first -> channel-last for imshow.
        img = np.einsum('c...->...c', imgs[row*n_cols+col])
        plt.subplot(4,n_cols,row*n_cols+col+1)
        plt.imshow(img)
        plt.title(name)

del id_monet, id_photo, imgs
torch.cuda.empty_cache()

### 5.2.6 학습(w. identity)

### 5.2.7 학습결과 (w. identity)
#### 5.2.7.1 Train loss

#### 5.2.7.2 Transfer

#### 5.2.7.3 Reconstruction

#### 5.2.7.4 Identity