In [None]:
'''
    kaggle = https://www.kaggle.com/code/leejin11/pix2pix-code
'''

In [None]:
# Practice image-to-image translation (pix2pix): grayscale -> RGB colourisation

In [None]:
import os
import cv2
import numpy as np
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

import matplotlib.pyplot as plt

In [None]:
## network class

In [None]:
class EncodingBlock(nn.Module):
    '''Downsampling convolution block used by both the generator and the discriminator.

    Layer order: Conv2d (bias disabled) -> optional BatchNorm2d -> LeakyReLU(0.2).
    With the default kernel_size=4, stride=2, padding=1 an even-sized input
    comes out with half the height and width.

        Args:
            in_dim(int) : number of input channels
            out_dim(int) : number of output channels
            kernel_size(int) : convolution kernel size (keyword-only)
            stride(int) : convolution stride (keyword-only)
            padding(int) : convolution zero-padding (keyword-only)
            normalize(bool) : insert BatchNorm2d after the convolution when True
    '''
    def __init__(self, in_dim, out_dim, *, kernel_size=4, stride=2, padding=1, normalize=True):
        super().__init__()

        conv = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        # Empty tuple when normalisation is disabled, so the Sequential indices
        # are conv=0, [norm=1], activation=last.
        norm = (nn.BatchNorm2d(out_dim),) if normalize else ()

        self.model = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        '''Run the block on a [b, in_dim, h, w] tensor and return the encoded feature map.'''
        return self.model(x)

class DecodingBlock(nn.Module):
    '''Upsampling block for the generator's decoder.

    Layer order: ConvTranspose2d (bias disabled) -> BatchNorm2d -> ReLU ->
    Dropout2d(p=0.3) or Identity.

        Args:
            in_dim(int) : number of input channels (after any skip concatenation)
            out_dim(int) : number of output channels
            kernel_size(int) : transposed-convolution kernel size (keyword-only)
            stride(int) : transposed-convolution stride (keyword-only)
            padding(int) : transposed-convolution padding (keyword-only)
            dropout(bool) : use Dropout2d(p=0.3) as the last layer when True
    '''
    def __init__(self, in_dim, out_dim, *, kernel_size=4, stride=2, padding=1,dropout=False):
        super().__init__()

        # Always keep a final layer so every block has the same layout;
        # it is a no-op Identity when dropout is disabled.
        if dropout:
            self.dropout = nn.Dropout2d(p=0.3)
        else:
            self.dropout = nn.Identity()

        upsample = nn.ConvTranspose2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.model = nn.Sequential(upsample, nn.BatchNorm2d(out_dim), nn.ReLU(), self.dropout)

    def forward(self, x, skip_input=None):
        '''Upsample x, first concatenating skip_input along the channel axis when given.'''
        # The skip connection brings back the encoder feature map of the same
        # resolution so that spatial detail can be restored.
        merged = x if skip_input is None else torch.cat((x, skip_input), dim=1)
        return self.model(merged)

# Generator
class Generator(nn.Module):
    '''U-Net style generator: six encoding blocks, five decoding blocks with
    skip connections, and a final transposed convolution followed by Tanh.

        Args:
            in_dim(int) : number of input channels
            out_dim(int) : number of output channels
            features(int) : base hidden width; deeper layers use multiples of it
    '''
    def __init__(self, in_dim=3, out_dim=3, features=64):
        super().__init__()
        f = features

        # Encoder: the first and last blocks skip batch normalisation.
        self.enc1 = EncodingBlock(in_dim, f, normalize=False)
        self.enc2 = EncodingBlock(f, f * 2)
        self.enc3 = EncodingBlock(f * 2, f * 4)
        self.enc4 = EncodingBlock(f * 4, f * 8)
        self.enc5 = EncodingBlock(f * 8, f * 8)
        self.enc6 = EncodingBlock(f * 8, f * 8, normalize=False)

        # Decoder: from dec2 on, the input width is doubled because the matching
        # encoder output is concatenated onto the previous decoder output.
        self.dec1 = DecodingBlock(f * 8, f * 8, dropout=True)
        self.dec2 = DecodingBlock(f * 16, f * 8, dropout=True)
        self.dec3 = DecodingBlock(f * 16, f * 4)
        self.dec4 = DecodingBlock(f * 8, f * 2)
        self.dec5 = DecodingBlock(f * 4, f)
        self.final = nn.Sequential(
            nn.ConvTranspose2d(f * 2, out_dim, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        '''Map an input image batch to an output image batch of the same spatial size.

        Shapes below are for a [b, in_dim, 64, 64] input with features=64.
        '''
        # Encoder pass, remembering every intermediate map for the skip connections:
        # [b, 64, 32, 32] -> [b, 128, 16, 16] -> [b, 256, 8, 8] -> [b, 512, 4, 4] -> [b, 512, 2, 2]
        skips = []
        h = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            h = encoder(h)
            skips.append(h)
        h = self.enc6(h)  # bottleneck: [b, 512, 1, 1]

        h = self.dec1(h)  # [b, 512, 2, 2]
        # Each decoder consumes the most recent unused encoder map (deepest first):
        # dec2: [b, 1024, 2, 2] -> [b, 512, 4, 4]
        # dec3: [b, 1024, 4, 4] -> [b, 256, 8, 8]
        # dec4: [b, 512, 8, 8]  -> [b, 128, 16, 16]
        # dec5: [b, 256, 16, 16] -> [b, 64, 32, 32]
        for decoder in (self.dec2, self.dec3, self.dec4, self.dec5):
            h = decoder(h, skips.pop())

        # Only the first encoder map is left: [b, 128, 32, 32] -> [b, out_dim, 64, 64]
        return self.final(torch.cat((h, skips.pop()), dim=1))

# Discriminator
class Discriminator(nn.Module):
    '''Conditional discriminator: three strided encoding blocks followed by a
    strided convolution, producing a grid of logits per image pair.

        Args:
            in_dim(int) : number of channels of the conditioning image
            out_dim(int) : number of channels of the real/generated image;
                           also the number of channels of the logit map
            features(int) : base hidden width
    '''
    def __init__(self, in_dim=3, out_dim=3,features=64):
        super().__init__()

        # Shape comments are for in_dim=1, out_dim=3, features=64 and 64x64 images.
        stages = [
            # input: [b, 4, 64, 64]
            EncodingBlock(in_dim + out_dim, features, normalize=False),  # [b, 64, 32, 32]
            EncodingBlock(features, features * 2),  # [b, 128, 16, 16]
            EncodingBlock(features * 2, features * 4),  # [b, 256, 8, 8]
            nn.Conv2d(features * 4, out_dim, kernel_size=4, stride=2, padding=1),  # [b, 3, 4, 4]
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, img_A, img_B):
        '''Score an image pair.

        The two images are concatenated along the channel axis, so together they
        must provide in_dim + out_dim channels, e.g. [b, 1, 64, 64] and
        [b, 3, 64, 64] -> [b, 4, 64, 64]. Returns raw logits, e.g. [b, 3, 4, 4].
        '''
        pair = torch.cat((img_A, img_B), dim=1)
        return self.model(pair)

In [None]:
# Loading Data

In [None]:
# Collect the relative path ("<class dir>/<file name>") of every image under the
# dataset root, then keep only the first 50,000 entries.
print('loading data')
img_list = list()

image_path = '/kaggle/input/imagenet1k0'
# os.listdir returns entries in arbitrary order; sorting both levels makes the
# 50,000-image subset selected below the same on every run.
for cls_name in sorted(os.listdir(image_path)):
    cls_dir = os.path.join(image_path, cls_name)
    # Skip stray files sitting next to the class directories; listing them
    # would raise NotADirectoryError.
    if not os.path.isdir(cls_dir):
        continue
    for img_name in sorted(os.listdir(cls_dir)):
        # Store the path relative to image_path.
        img_list.append(os.path.join(cls_name, img_name))
print('finised loading data')

# Cap the number of training images.
img_list = img_list[:50_000]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'device type : {device}\n')

## Dataset: images are read from disk one item at a time because the whole set cannot be loaded into RAM.
class CustomDataset(Dataset):
    '''Dataset of (grayscale, colour) image pairs read from disk on access.

        Args:
            img_list(list) : image paths relative to the dataset root
            size(int) : side length both images are resized to (square)
            root(str or None) : dataset root directory; when None, the
                                module-level ``image_path`` is used
    '''
    def __init__(self, img_list, size=256, root=None):
        self.img_list = img_list
        # Fall back to the notebook-level image_path so existing
        # CustomDataset(img_list, size) calls keep working.
        self.path = image_path if root is None else root
        self.size = size

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        '''Return (gray, color) float32 tensors of shape [1, size, size] and
        [3, size, size], both scaled to [-1, 1].

        Raises:
            FileNotFoundError: if the image cannot be read.
        '''
        img_path = os.path.join(self.path, self.img_list[idx])
        img = cv2.imread(img_path)
        # cv2.imread signals a missing/undecodable file by returning None
        # rather than raising; fail here with a message that names the file.
        if img is None:
            raise FileNotFoundError(f'could not read image: {img_path}')

        # Network input (gray) and target (colour). cv2.imread yields BGR
        # channel order, so both conversions must start from BGR.
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img_color = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Resize both images to size x size.
        img_gray = cv2.resize(img_gray, (self.size, self.size))
        img_color = cv2.resize(img_color, (self.size, self.size))

        # Scale uint8 [0, 255] to float32 [-1, 1].
        img_gray = img_gray.astype(np.float32) / 255.0 * 2 - 1
        img_color = img_color.astype(np.float32) / 255.0 * 2 - 1

        # Convert to channel-first tensors.
        img_gray = torch.from_numpy(img_gray).unsqueeze(0)  # [h, w] -> [1, h, w]
        img_color = torch.from_numpy(img_color).permute(2, 0, 1)  # [h, w, 3] -> [3, h, w]

        return img_gray, img_color

In [None]:
# train

In [None]:
# Hyperparameters
batch_size = 1
lr = 2e-4
epochs = 10
betas = (0.5, 0.999)
gamma = 100  # weight of the L1 term relative to the adversarial term in the generator loss
size = 64  # images are resized to size x size

dataset = CustomDataset(img_list, size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# The generator takes the 1-channel grayscale image; the discriminator scores
# (grayscale, colour) pairs.
g_model = Generator(in_dim=1).to(device)
d_model = Discriminator(in_dim=1).to(device)

g_optimizer = torch.optim.Adam(g_model.parameters(), lr=lr, betas=betas)
d_optimizer = torch.optim.Adam(d_model.parameters(), lr=lr, betas=betas)

criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

start_time = time.time()

for e in range(epochs):
    # Open the log once per epoch instead of once per iteration. Line buffering
    # (buffering=1) still flushes every line as soon as it is written.
    with open('output.txt', 'a', buffering=1) as f:
        for i, (gray, color) in enumerate(dataloader):
            gray, color = gray.to(device), color.to(device)

            ## generator
            # Gradients must be cleared BEFORE backward(); clearing them between
            # backward() and step() would throw away the update.
            g_optimizer.zero_grad()

            fake_img = g_model(gray)  # generate a colour image
            pred_fake = d_model(gray, fake_img)  # discriminator scores (input, generated)

            valid = torch.ones_like(pred_fake)
            fake = torch.zeros_like(pred_fake)

            loss_GAN = criterion_GAN(pred_fake, valid)  # loss for fooling the discriminator
            loss_L1 = criterion_L1(fake_img, color)  # similarity to the real colour image

            g_loss = loss_GAN + (gamma * loss_L1)  # L1 term weighted by gamma
            g_loss.backward()
            g_optimizer.step()

            ## discriminator
            # This also discards the gradients that g_loss.backward() accumulated
            # on the discriminator's parameters.
            d_optimizer.zero_grad()

            pred_real = d_model(gray, color)  # real pair
            pred_fake = d_model(gray, fake_img.detach())  # generated pair, cut off from the generator graph

            d_loss_real = criterion_GAN(pred_real, valid)
            d_loss_fake = criterion_GAN(pred_fake, fake)

            d_loss = 0.5 * (d_loss_real + d_loss_fake)  # average of the two losses
            d_loss.backward()
            d_optimizer.step()

            info = f'epoch : {e}    iter : {i+1:5d}    d_loss : {d_loss.item():.4f}  g_loss : {g_loss.item():.4f}'
            f.write(info + '\n')

            if (i+1) % 1000 == 0:
                # Elapsed time since the previous report (or since the start).
                print(info +f'     time : {time.time()-start_time:5.3f}')
                start_time = time.time()

    # Save and display the last batch of the epoch: input, generated, target.
    save_image([gray.repeat(1, 3, 1, 1)[0], fake_img[0], color[0]], f'output{e}.jpg', nrow=3, normalize=True)
    img = cv2.imread(f'output{e}.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

torch.save({
    'discriminator': d_model.state_dict(),
    'd_optimizer': d_optimizer.state_dict(),
    'generator': g_model.state_dict(),
    'g_optimizer': g_optimizer.state_dict(),
    'epoch': epochs,

}, 'checkpoint.pth')