<a href="https://colab.research.google.com/github/Yohan0358/Study_GAN/blob/main/cGAN_GrayToColor(210719).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference: [가짜연구소 (Pseudo Lab) Tutorial Book, Chapter 3: GAN](https://pseudo-lab.github.io/Tutorial-Book/chapters/GAN/Ch3-GAN.html)

In [None]:
# Clone the Pseudo-Lab tutorial utilities, run their data-loader script for the
# "GAN-Colorization" dataset (presumably this downloads the zip used below), and
# unpack the archive quietly into the working directory.
!git clone https://github.com/Pseudo-Lab/Tutorial-Book-Utils
!python Tutorial-Book-Utils/PL_data_loader.py --data GAN-Colorization
!unzip -q Victorian400-GAN-colorization-data.zip

In [None]:
import os
import glob

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid

import time

import cv2

from PIL import Image

In [None]:
# Directories holding the three image sets used in this notebook
# (presumably created by unpacking the dataset archive — verify after unzip).
path_origin = './original/'  # original images
path_gray = './gray/'  # grayscale images
path_resized = './resized/'  # resized images

In [None]:
# Collect every file in each image directory. Sorting keeps the ordering
# deterministic, so the same index in each list comes from the same sort position.
_origin, _gray, _resized = (
    sorted(glob.glob(folder + '*'))
    for folder in (path_origin, path_gray, path_resized)
)

In [None]:
data = [_origin, _gray, _resized]

# Print the array shape of the first image of each set as a quick sanity check.
for paths in data:
    first_image = cv2.imread(paths[0])
    print(first_image.shape)

In [None]:
def plot_img(img):
    """Read the image file at path ``img`` with OpenCV and show it with matplotlib.

    Args:
        img: Path of the image file to display.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    pixels = cv2.imread(img)
    if pixels is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an obscure cvtColor error.
        raise FileNotFoundError(f'Could not read image: {img}')

    # OpenCV loads channels as BGR; convert to RGB before handing to matplotlib.
    plt.imshow(cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB))
    plt.show()

In [None]:
# Display the first image of each set.
for paths in (_origin, _gray, _resized):
    plot_img(paths[0])

In [None]:
def get_mean_std(file):
    """Compute per-channel mean and standard deviation over a list of image files.

    Each image is read with OpenCV, converted to RGB and scaled to [0, 1].
    Every image contributes equally (its per-channel mean is averaged over the
    images). The statistics are accumulated in a single pass, using
    Var[x] = E[x^2] - E[x]^2, so decoded images are not kept in memory.

    Args:
        file: Sequence of image file paths.

    Returns:
        Tuple ``(mean, std)``, each an array of three values in RGB order.
        For an empty ``file`` the result is ``(0, 0.0)``.

    Raises:
        FileNotFoundError: If OpenCV cannot read one of the files.
    """
    if not file:
        return 0, 0.0

    channel_sum = np.zeros(3)
    channel_sq_sum = np.zeros(3)
    for path in file:
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None rather than raising on unreadable files.
            raise FileNotFoundError(f'Could not read image: {path}')
        pixels = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255).reshape(-1, 3)
        channel_sum += pixels.mean(axis=0)
        channel_sq_sum += (pixels ** 2).mean(axis=0)

    mean = channel_sum / len(file)
    var = channel_sq_sum / len(file) - mean ** 2
    # Guard against tiny negative values caused by floating-point rounding.
    std = np.maximum(var, 0) ** 0.5

    return mean, std

In [None]:
# dataset load
class Custom_dataset(Dataset):
    """Dataset yielding ``(gray_img, color_img)`` pairs from two parallel path lists.

    The item at index ``idx`` pairs ``gray_path[idx]`` with ``color_path[idx]``,
    so both lists are expected to be in the same order. The dataset length is
    taken from the color list only.

    Args:
        color_path: Sequence of color image file paths.
        gray_path: Sequence of grayscale image file paths.
        color_transform: Callable applied to each loaded color image.
        gray_transform: Callable applied to each loaded grayscale image.
    """

    def __init__(self, color_path, gray_path, color_transform, gray_transform):
        super(Custom_dataset, self).__init__()

        self.color_file = color_path
        self.gray_file = gray_path

        self.color_transform = color_transform
        self.gray_transform = gray_transform

    def __len__(self):
        """Return the number of color image paths."""
        return len(self.color_file)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image pair at ``idx``."""
        # Open through context managers so the file handles are closed right
        # away; convert() returns an independent in-memory image.
        with Image.open(self.gray_file[idx]) as gray_src:
            gray_img = gray_src.convert('RGB')
        with Image.open(self.color_file[idx]) as color_src:
            color_img = color_src.convert('RGB')

        gray_img = self.gray_transform(gray_img)
        color_img = self.color_transform(color_img)

        return gray_img, color_img

In [None]:
# Per-channel statistics of the resized color images and of the grayscale images.
color_mean, color_std = get_mean_std(_resized)
gray_mean, gray_std = get_mean_std(_gray)

# Color targets: tensor conversion followed by per-channel normalisation.
color_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=color_mean, std=color_std),
])

# Grayscale inputs: only the first channel's statistics are passed, as scalars
# (presumably all three channels are identical after convert('RGB') — verify).
gray_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=gray_mean[0], std=gray_std[0]),
])

dataset = Custom_dataset(_resized, _gray, color_transform, gray_transform)

In [None]:
# Shuffled mini-batches of (gray, color) pairs.
batch_size = 16
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
def imshow_grid(img, mean, std):
    """Show a batch of normalised images as one grid.

    The batch is tiled with ``make_grid``, moved to channel-last layout,
    de-normalised with ``std`` and ``mean``, clipped to [0, 1] and plotted.

    Args:
        img: Batch of image tensors.
        mean: Mean that was used to normalise the images.
        std: Standard deviation that was used to normalise the images.
    """
    grid = make_grid(img.detach().cpu())
    # (C, H, W) -> (H, W, C), the layout matplotlib expects.
    channel_last = grid.numpy().transpose(1, 2, 0)
    restored = np.clip(channel_last * std + mean, 0, 1)

    plt.figure(figsize=(10, 4))
    plt.imshow(restored)
    plt.show()

# Pull one batch from the loader and preview it: grayscale inputs first, then
# the matching color targets, each de-normalised with its own statistics.
sample_g, sample_c = next(iter(loader))

imshow_grid(sample_g, gray_mean, gray_std)
imshow_grid(sample_c, color_mean, color_std)

In [None]:
# Generates 256 x 256 images

class Generator(nn.Module):
    """Convolutional generator mapping a 2-channel input to a 3-channel image.

    The first convolution (stride 2) halves the spatial size, and the
    transposed-convolution block (stride 2) doubles it again, so the output has
    the same height and width as the input. The final ``Tanh`` keeps the
    output within [-1, 1].
    """

    def __init__(self):
        super().__init__()

        layers = [
            nn.Conv2d(2, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            self._conv_block(64, 128),
            self._convT_block(128, 64),
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def _conv_block(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1):
        """Build a Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2) block."""
        conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=False)
        norm = nn.BatchNorm2d(out_ch)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def _convT_block(self, in_ch, out_ch):
        """Build a stride-2 ConvTranspose2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2) block."""
        upsample = nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False)
        norm = nn.BatchNorm2d(out_ch)
        return nn.Sequential(upsample, norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.main(x)

'''
noise : (batch, 100, 1, 1) -> (batch, 1, 256, 256)
img   :                       (batch, 1, 256, 256)   
==> Generator that builds a (batch, 2, 256, 256) input this way; training took too long and performance was poor

'''
# class Generator(nn.Module):
#     def __init__(self, latent):
#         super(Generator, self).__init__()

#         self.noise_up = nn.Sequential(
#             nn.ConvTranspose2d(latent, 512, 8, 1, 0),
#             nn.LeakyReLU(0.2),
#             self._convT_block(512, 256),
#             self._convT_block(256, 128),
#             self._convT_block(128, 64),
#             self._convT_block(64, 32),

#         )

#         self.img_down = nn.Sequential(
#             nn.Conv2d(1, 32, 4, 2, 1),
#             nn.LeakyReLU(0.2),
#             # nn.MaxPool2d((2,2))
#         )

#         self.conv = nn.Sequential(
#             self._convT_block(64, 128),
#             self._conv_block(128, 128, 3, 1, 1),
#             self._conv_block(128, 64, 3, 1, 1),
#             nn.Conv2d(64, 3, 3, 1, 1),

#             nn.Tanh()            
#         )

#     def _convT_block(self, in_ch, out_ch):
#         return nn.Sequential(
#             nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
#             nn.BatchNorm2d(out_ch),
#             nn.LeakyReLU(0.2)
#         )

#     def _conv_block(self, in_ch, out_ch, kernel_size = 4, stride = 2, padding = 1):
#         return nn.Sequential(
#             nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding),
#             nn.BatchNorm2d(out_ch),
#             nn.LeakyReLU(0.2)
#         )

#     def forward(self, img, z):
#         img = self.img_down(img)
#         z = self.noise_up(z)
#         x = torch.cat([img, z], dim = 1)
#         x = self.conv(x)
#         return x

def test():
    """Smoke-test the generator: print the output shape for a random 4-image batch."""
    gen = Generator()
    gray = torch.randn(4, 1, 256, 256)
    noise = torch.randn(4, 1, 256, 256)
    # Stack the two single-channel tensors into the 2-channel generator input.
    result = gen(torch.cat([gray, noise], dim=1))
    print(result.shape)

test()

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores images with a sigmoid output.

    Eight stride-2 convolutions each halve the spatial size, so a 256 x 256
    input is reduced to a single value per image.

    Args:
        image_channel: Number of channels of the input images.
    """

    def __init__(self, image_channel = 3):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(image_channel, 64, 4, 2, 1, bias= False),
            nn.LeakyReLU(0.2),

            self._block(64, 128),
            self._block(128, 256),
            self._block(256, 512),
            self._block(512, 512),
            self._block(512, 256),
            self._block(256, 128),

            nn.Conv2d(128, 1, 4, 2, 1, bias = False),
            nn.Sigmoid()
        )

    def _block(self, in_ch, out_ch):
        """Build a stride-2 Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2) block."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias = False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        """Score a batch of images.

        Returns:
            The sigmoid scores with singleton channel and spatial dimensions
            removed; shape ``(batch,)`` for 256 x 256 inputs.
        """
        x = self.main(x)
        # Squeeze only the non-batch dimensions. A bare squeeze() would also
        # drop the batch dimension for a batch of one, returning a 0-d tensor.
        for dim in (3, 2, 1):
            x = x.squeeze(dim)
        return x

def test():
    """Smoke-test the discriminator: print the output shape for a random 16-image batch."""
    disc = Discriminator()
    batch = torch.randn(16, 3, 256, 256)
    scores = disc(batch)
    print(scores.shape)

test()

In [None]:
def weight_init(m):
    """Initialise one layer in place, chosen by its class name.

    Only ``m`` itself is inspected; child modules are not visited.

    - Class name contains ``'Conv'``: weights drawn from N(0, 0.02).
    - Class name contains ``'BatchNorm'``: weights drawn from N(1, 0.01),
      bias set to 0.
    - Anything else is left untouched.

    Args:
        m: The module to initialise.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.01)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
# hyper parameter
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device :', device)

lr = 2e-4
epochs = 50
latent = 256    # side length of the square noise map concatenated with the gray image
img_size = 256
G = Generator().to(device)
D = Discriminator().to(device)

# weight_init only inspects the module it is given, so it has to be applied
# to every sub-module. Calling weight_init(G) directly sees only the top-level
# class name and initialises nothing.
G.apply(weight_init)
D.apply(weight_init)

G_optim = optim.Adam(G.parameters(), lr = lr)
D_optim = optim.Adam(D.parameters(), lr = lr)

# L2 loss
criterion = nn.MSELoss()

def generate_noise(batch_size, latent, device=None):
    """Sample a batch of single-channel standard-normal noise maps.

    Args:
        batch_size: Number of noise maps to generate.
        latent: Height and width of each noise map.
        device: Optional device to allocate the tensor on. With the default
            ``None`` the tensor is created on PyTorch's default device, as before.

    Returns:
        Tensor of shape ``(batch_size, 1, latent, latent)``.
    """
    return torch.randn(batch_size, 1, latent, latent, device=device)

In [None]:
# Adversarial training with least-squares losses: D is pushed towards 1 on
# real color images and 0 on generated ones; G is pushed to make D output 1.
G.train()
D.train()

for epoch in range(epochs):
    # Running per-epoch averages of the two losses.
    D_losses = 0
    G_losses = 0

    t = time.time()
    for img_g, img_c in loader:
        # Size of this batch, kept in its own name so the global
        # batch_size setting is not overwritten.
        n_samples = len(img_g)

        # Keep only the first channel of the gray image as the condition.
        img_g, img_c = img_g[:, 0:1, :, :].to(device), img_c.to(device)

        # Generator input: gray channel stacked with a noise channel.
        z = generate_noise(n_samples, latent).to(device)
        z = torch.cat([img_g, z], dim = 1)

        # Train the discriminator.
        # detach() stops the discriminator loss from back-propagating through
        # the generator, which is not updated in this step.
        output_z = D(G(z).detach())
        output_img = D(img_c)

        D_loss = torch.mean((output_img - 1) ** 2) + torch.mean(output_z ** 2)

        D_optim.zero_grad()
        D_loss.backward()
        D_optim.step()

        # Train the generator.
        fake_img = G(z)
        output = D(fake_img)
        G_loss = torch.mean((output - 1 ) **2 )

        G_optim.zero_grad()
        G_loss.backward()
        G_optim.step()

        D_losses += D_loss.item() / len(loader)
        G_losses += G_loss.item() / len(loader)

    print(f'[{epoch + 1} / {epochs}] epochs \t D_loss : {D_losses:.4f} \t G_loss : {G_losses:.4f} \t time : {time.time() - t}')

    # Every 5 epochs, preview the last batch: input, target and generated output.
    if (epoch + 1) % 5 == 0:
        G.eval()

        z = generate_noise(len(img_g), latent).to(device)
        z = torch.cat([img_g, z], dim = 1)
        print('====== GRAY ======')
        imshow_grid(img_g, gray_mean, gray_std)
        print('====== COLOR ======')
        imshow_grid(img_c, color_mean, color_std)
        # Preview only: no autograd graph is needed.
        with torch.no_grad():
            output = G(z)
        print('====== FAKE ======')
        imshow_grid(output, color_mean, color_std)

        G.train()

In [None]:
# Save the models: write the generator and discriminator state dicts to
# checkpoint files in the working directory.
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')