In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import random
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
from glob import glob
from torchinfo import summary
from torchvision.transforms.functional import to_pil_image
# Select the training device. Index 6 is the originally hard-coded GPU; when that
# ordinal does not exist, fall back to the first visible GPU, then to the CPU,
# instead of failing later with an "invalid device ordinal" error at .to(device).
GPU_INDEX = 6
print(f"GPUs used:\t{torch.cuda.device_count()}")
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    device = torch.device("cuda", GPU_INDEX)
elif torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
print(f"Device:\t\t{device}")

In [None]:
# Experiment configuration.
params = {'image_size': 1024,      # images are resized to image_size x image_size
          'lr': 2e-4,              # same value the Adam optimizers are built with (was a stale 2e-5)
          'batch_size': 16,
          'epochs': 1000,
          'n_classes': None,       # appears unused in this notebook
          'data_path': '../../data/normalization_type/Notstandard/',
          'image_count': 5000,     # appears unused in this notebook
          'seed': 42,              # seeds python, numpy and torch below for reproducible runs
          }

# Seed every RNG used by this notebook (random flips, DataLoader shuffling, weight init).
random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

# PIL image -> float tensor (C, H, W) scaled to [0, 1].
tf = transforms.ToTensor()

In [None]:
class CustomDataset(Dataset):
    """Paired image dataset for use with ``torch.utils.data.DataLoader``.

    Holds two index-aligned collections (named ``L`` for grayscale and ``RGB`` for colour by
    the caller) and returns ``(RGB_image, L_image)`` per index. The same random horizontal
    and vertical flips are applied to both members of a pair so they stay pixel-aligned.
    """

    def __init__(self, parmas, L_image, RGB_image):
        # The parameter keeps its original (misspelled) name so keyword callers still work.
        self.args = parmas
        self.L_images = L_image
        self.RGB_image = RGB_image

    def trans(self, L_image, RGB_image):
        """Randomly flip horizontally, then vertically (each when random.random() > 0.5),
        applying the identical flip to both images."""
        for flip_factory in (transforms.RandomHorizontalFlip, transforms.RandomVerticalFlip):
            if random.random() > 0.5:
                flip_always = flip_factory(1)
                L_image, RGB_image = flip_always(L_image), flip_always(RGB_image)
        return L_image, RGB_image

    def __getitem__(self, index):
        gray, colour = self.trans(self.L_images[index], self.RGB_image[index])
        # Note the order: colour first, grayscale second.
        return colour, gray

    def __len__(self):
        return len(self.L_images)

# Collect training images. glob() returns files in arbitrary filesystem order, so sort
# to make the index -> file mapping reproducible across runs and machines.
image_list = sorted(glob(os.path.join(params['data_path'], '*.jpeg')))
if not image_list:
    raise FileNotFoundError(f"No '*.jpeg' images found in {params['data_path']!r}")

resize_to = (params['image_size'], params['image_size'])
tensor_shape = (len(image_list), 3, params['image_size'], params['image_size'])
# Everything is preloaded as float32: 2 tensors x 3 x image_size^2 x 4 bytes per image
# (~25 MB per image at image_size=1024), so large folders need a lot of host RAM.
L_images = torch.zeros(tensor_shape)
RGB_images = torch.zeros(tensor_shape)
for i, path in enumerate(tqdm(image_list)):
    with Image.open(path) as image:
        # Grayscale copy replicated to 3 channels (presumably so the 3-channel networks can consume it).
        # ToTensor yields [0, 1]; "* 2 - 1" rescales to [-1, 1], the range of the generator's Tanh.
        L_images[i] = tf(image.convert('L').convert('RGB').resize(resize_to)) * 2 - 1
        RGB_images[i] = tf(image.convert('RGB').resize(resize_to)) * 2 - 1

train_dataset = CustomDataset(params, L_images, RGB_images)
dataloader = DataLoader(
    train_dataset, batch_size=params['batch_size'], shuffle=True)

In [None]:
# U-Net down-sampling stage: halves height and width.
class UNetDown(nn.Module):
    """Conv2d(4x4, stride 2, no bias) -> [InstanceNorm2d] -> LeakyReLU(0.2) -> [Dropout].

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        normalize: include InstanceNorm2d after the convolution when True.
        dropout: dropout probability; a falsy value omits the Dropout layer.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        norm = [nn.InstanceNorm2d(out_channels)] if normalize else []
        drop = [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2), *drop)

    def forward(self, x):
        return self.model(x)


# U-Net up-sampling stage: doubles height and width, then appends the skip connection.
class UNetUp(nn.Module):
    """ConvTranspose2d(4x4, stride 2, no bias) -> InstanceNorm2d -> ReLU -> [Dropout],
    followed by channel-wise concatenation with the encoder feature map ``skip_input``.

    The returned tensor has ``out_channels + skip_input.size(1)`` channels.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            stages.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        upsampled = self.model(x)
        # Concatenate along the channel dimension.
        return torch.cat((upsampled, skip_input), dim=1)


# U-Net 생성자(Generator) 아키텍처
class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)  # 출력: [64 x 512 x 512]
        self.down2 = UNetDown(64, 128)                           # 출력: [128 x 256 x 256]
        self.down3 = UNetDown(128, 256)                          # 출력: [256 x 128 x 128]
        self.down4 = UNetDown(256, 512, dropout=0.5)             # 출력: [512 x 64 x 64]
        self.down5 = UNetDown(512, 512, dropout=0.5)             # 출력: [512 x 32 x 32]
        self.down6 = UNetDown(512, 512, dropout=0.5)             # 출력: [512 x 16 x 16]
        self.down7 = UNetDown(512, 512, dropout=0.5)             # 출력: [512 x 8 x 8]
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)  # 출력: [512 x 4 x 4]

        self.up1 = UNetUp(512, 512, dropout=0.5)                 # 출력: [1024 x 8 x 8]
        self.up2 = UNetUp(1024, 512, dropout=0.5)                # 출력: [1024 x 16 x 16]
        self.up3 = UNetUp(1024, 512, dropout=0.5)                # 출력: [1024 x 32 x 32]
        self.up4 = UNetUp(1024, 512, dropout=0.5)                # 출력: [1024 x 64 x 64]
        self.up5 = UNetUp(1024, 256)                             # 출력: [512 x 128 x 128]
        self.up6 = UNetUp(512, 128)                              # 출력: [256 x 256 x 256]
        self.up7 = UNetUp(256, 64)                               # 출력: [128 x 512 x 512]

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),  # 출력: [128 x 1024 x 1024]
            nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1),  # 출력: [3 x 1024 x 1024]
            nn.Tanh(),
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)


# PatchGAN discriminator conditioned on a second image.
class Discriminator(nn.Module):
    """Scores (img_A, img_B) pairs patch-wise.

    The two inputs are concatenated along channels, so each must have ``in_channels``
    channels. Five stride-2 conv blocks (the first without InstanceNorm) are followed by
    a one-pixel top/left zero pad and a 4x4 conv to a single-channel score map;
    a 1024x1024 input yields a [1 x 32 x 32] map.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        widths = [in_channels * 2, 64, 128, 256, 512, 512]
        layers = []
        for block_idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            if block_idx > 0:  # the first block is not normalized
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.extend([
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1, bias=False),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        return self.model(torch.cat((img_A, img_B), dim=1))
     

In [None]:

def weights_init_normal(m):
    """DCGAN-style initializer intended for ``nn.Module.apply``.

    - Convolution / transposed-convolution weights ~ N(0, 0.02) (biases untouched).
    - BatchNorm scale ~ N(1, 0.02) and shift = 0, only when the layer is affine.
    - Every other module is left unchanged.

    Matching by type rather than by a substring of the class name avoids touching
    unrelated modules whose names merely contain "Conv" (which have no ``weight``
    and previously raised AttributeError). The affine guard avoids a crash on
    ``BatchNorm2d(affine=False)``, whose ``weight`` is None.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                  nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    if isinstance(m, conv_types):
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, batchnorm_types) and m.weight is not None:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# Build the generator and discriminator, move them to the selected device, then
# re-initialize their conv weights (construction -> device -> init order is kept
# so random-number consumption matches the original cell).
generator, discriminator = GeneratorUNet(), Discriminator()
for net in (generator, discriminator):
    net.to(device)
for net in (generator, discriminator):
    net.apply(weights_init_normal)

# Losses: least-squares adversarial loss and pixel-wise L1 reconstruction loss.
criterion_GAN = torch.nn.MSELoss().to(device)
criterion_pixelwise = torch.nn.L1Loss().to(device)

# NOTE: this hard-coded learning rate, not params['lr'], is what the optimizers use.
lr = 2e-4

# One Adam optimizer per network, with the usual GAN betas.
optimizer_G, optimizer_D = (
    torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))
    for model in (generator, discriminator)
)

In [48]:

# Weight of the L1 reconstruction term relative to the adversarial term.
lambda_pixel = 100

result_dir = '../../result/pix2pix/origin'
model_dir = '../../model/pix2pix/origin'
# Create output folders up front so saving at the end of the first epoch cannot fail.
os.makedirs(result_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

for epoch in range(params['epochs']):
    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for RGB_image, L_image in tqdmDataLoader:
            # Colorization: the grayscale image is the condition / generator input (A) and the
            # colour image is the target (B). This matches the sampling code below and the
            # "Colorization" checkpoint names (previously training ran colour -> grayscale).
            real_A = L_image.to(device)    # grayscale, replicated to 3 channels
            real_B = RGB_image.to(device)  # colour

            # ---- Generator step ----
            optimizer_G.zero_grad()
            fake_B = generator(real_A)

            # Adversarial loss: the generator wants its output scored as real.
            pred_fake = discriminator(fake_B, real_A)
            # Patch-label maps are shaped like the discriminator output, so they stay correct
            # for any image_size (32x32 at 1024) and are created directly on the right device.
            real = torch.ones_like(pred_fake)
            fake = torch.zeros_like(pred_fake)
            loss_GAN = criterion_GAN(pred_fake, real)

            # Pixel-wise L1 loss against the ground-truth colour image.
            loss_pixel = criterion_pixelwise(fake_B, real_B)

            loss_G = loss_GAN + lambda_pixel * loss_pixel
            loss_G.backward()
            optimizer_G.step()

            # ---- Discriminator step (condition: real_A) ----
            optimizer_D.zero_grad()
            loss_real = criterion_GAN(discriminator(real_B, real_A), real)
            loss_fake = criterion_GAN(discriminator(fake_B.detach(), real_A), fake)
            loss_D = (loss_real + loss_fake) / 2
            loss_D.backward()
            optimizer_D.step()

            tqdmDataLoader.set_postfix(
                ordered_dict={
                    "epoch": epoch + 1,
                    "D loss: ": loss_D.item(),
                    "G pixel loss: ": loss_pixel.item(),
                    "adv loss: ": loss_GAN.item(),
                }
            )

    # Save one sample per epoch: condition | generated | ground truth, stacked vertically.
    # Only the first image of a random batch is run; no_grad avoids building an autograd graph.
    # The generator stays in train mode (dropout active), as in the original cell.
    with torch.no_grad():
        imgs = next(iter(dataloader))
        real_A = imgs[1][:1].to(device)  # grayscale condition
        real_B = imgs[0][:1].to(device)  # colour ground truth
        fake_B = generator(real_A)
    img_sample = torch.cat((real_A, fake_B, real_B), -2)  # concatenate along height
    # Tensors are in [-1, 1]; to_pil_image expects floats in [0, 1] (negatives would wrap in uint8).
    sample_image = (img_sample[0] * 0.5 + 0.5).clamp(0, 1).cpu()
    to_pil_image(sample_image).save(os.path.join(result_dir, f'{epoch}.png'))
    torch.save(generator.state_dict(),
               os.path.join(model_dir, f"Pix2Pix_Generator_for_Colorization_{epoch}.pt"))
    torch.save(discriminator.state_dict(),
               os.path.join(model_dir, f"Pix2Pix_Discriminator_for_Colorization_{epoch}.pt"))
    
     

  6%|▌         | 262/4478 [00:13<03:38, 19.31it/s, epoch=1, D loss: =0.325, G pixel loss: =0.0463, adv loss: =0.324]


KeyboardInterrupt: 

In [47]:
# Scratch check: CPU-side PatchGAN target maps for the current batch size.
real = torch.full((real_A.size(0), 1, 32, 32), 1.0, dtype=torch.float32)  # real: 1
fake = torch.full((real_A.size(0), 1, 32, 32), 0.0, dtype=torch.float32)  # fake: 0

In [44]:
# Sanity check of the discriminator's score-map shape; no gradients are needed here.
with torch.no_grad():
    disc_output_shape = discriminator(real_B, real_A).shape
disc_output_shape

torch.Size([1, 1, 32, 32])