In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import random
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
from glob import glob
from torchinfo import summary
from torchvision.transforms.functional import to_pil_image
import torch
import torch.nn as nn
import torch.cuda.amp as amp  # Mixed Precision Training
# Report visible GPUs and choose the training device.
print(f"GPUs used:\t{torch.cuda.device_count()}")
# Preferred GPU ordinal. When this machine does not have that many GPUs, fall back to
# the first GPU (or the CPU) instead of failing later at `.to(device)`, which would
# only happen after the expensive data-loading cell has already run.
PREFERRED_CUDA_INDEX = 5
if torch.cuda.device_count() > PREFERRED_CUDA_INDEX:
    device = torch.device("cuda", PREFERRED_CUDA_INDEX)
elif torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
print(f"Device:\t\t{device}")

In [4]:
# Training configuration shared by the cells below.
# NOTE(review): `lr` here is 5e-5, but the optimizer cell builds Adam with a
# hard-coded lr of 2e-5 — confirm which value is intended.
# NOTE(review): `n_classes` and `image_count` are not referenced by the visible code.
params = dict(
    image_size=1024,
    lr=5e-5,
    batch_size=4,  # kept small to reduce memory pressure
    epochs=1000,
    n_classes=None,
    data_path='../../data/origin_type/tar/**/**/',
    image_count=5000,
    gradient_accumulation_steps=2,  # gradient-accumulation setting (batches per optimizer step)
)
# PIL image -> float tensor (C x H x W) in [0, 1]; the loading code rescales to [-1, 1].
tf = transforms.ToTensor()

In [None]:
class CustomDataset(Dataset):
    """Paired colourisation dataset over pre-loaded images.

    Item ``index`` is returned as ``(RGB_image, L_image)`` — colour target first,
    grayscale input second — after the same random flips are applied to both.

    Args:
        parmas: configuration dict (name kept as-is for caller compatibility);
            stored on ``self.args`` and not otherwise used here.
        L_image: indexable collection of grayscale inputs (the notebook passes a
            tensor of shape N x 3 x S x S).
        RGB_image: indexable collection of colour targets aligned with ``L_image``.
    """

    def __init__(self, parmas, L_image, RGB_image):
        self.args = parmas
        self.L_images = L_image
        self.RGB_image = RGB_image

    def trans(self, L_image, RGB_image):
        """Randomly flip an (L, RGB) pair identically.

        Horizontal and vertical flips are each applied with probability 0.5, and
        always to both images together so the pair stays pixel-aligned.
        """
        def flip_pair(flip, first, second):
            return flip(first), flip(second)

        if random.random() > 0.5:
            L_image, RGB_image = flip_pair(transforms.RandomHorizontalFlip(1), L_image, RGB_image)
        if random.random() > 0.5:
            L_image, RGB_image = flip_pair(transforms.RandomVerticalFlip(1), L_image, RGB_image)
        return L_image, RGB_image

    def __getitem__(self, index):
        gray, color = self.trans(self.L_images[index], self.RGB_image[index])
        # Note the order: colour target first, grayscale input second.
        return color, gray

    def __len__(self):
        count = len(self.L_images)
        return count

# Collect the grayscale-source JPEGs. Each colour target lives at the same path with
# the `/tar/` directory replaced by `/nor/`.
image_list = glob(params['data_path'] + '*.jpeg')
if not image_list:
    # Fail early with a clear message rather than DataLoader's "num_samples=0" error.
    raise FileNotFoundError(f"no '*.jpeg' files matched {params['data_path']!r}")
# Replace only the whole `/tar/` path segment, so directories such as `/target/`
# elsewhere in the path are not rewritten.
RGB_image_list = [f.replace('/tar/', '/nor/') for f in image_list]

side = params['image_size']
# NOTE: both tensors are held entirely in memory (N x 3 x side x side float32 each).
L_images = torch.zeros((len(image_list), 3, side, side))
RGB_images = torch.zeros((len(image_list), 3, side, side))
for i in tqdm(range(len(image_list))):
    # `with` closes each file promptly instead of leaving handles open until GC.
    with Image.open(image_list[i]) as image:
        # Grayscale replicated to 3 channels, resized, scaled from [0, 1] to [-1, 1].
        L_images[i] = tf(image.convert('L').convert('RGB').resize((side, side))) * 2 - 1
    with Image.open(RGB_image_list[i]) as image:
        RGB_images[i] = tf(image.convert('RGB').resize((side, side))) * 2 - 1

train_dataset = CustomDataset(params, L_images, RGB_images)
dataloader = DataLoader(
    train_dataset, batch_size=params['batch_size'], shuffle=True)

In [6]:
# Down-sampling block of the U-Net encoder.
class UNetDown(nn.Module):
    """Strided conv block that halves height and width.

    Layout (inside ``self.model``): Conv2d(k=4, s=2, p=1, no bias) ->
    [InstanceNorm2d if ``normalize``] -> LeakyReLU(0.2) -> [Dropout if ``dropout``].
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        norm = [nn.InstanceNorm2d(out_channels)] if normalize else []
        drop = [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2), *drop)

    def forward(self, x):
        out = self.model(x)
        return out


# Up-sampling block of the U-Net decoder; concatenates a skip-connection input.
class UNetUp(nn.Module):
    """Doubles height and width, then appends the skip tensor along channels.

    ``self.model`` is: bilinear Upsample(x2) -> Conv2d(3x3) -> InstanceNorm2d ->
    ReLU -> [Dropout if ``dropout``]. The output therefore has
    ``out_channels + skip_input.shape[1]`` channels.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        # Upsample + Conv2d is used here instead of ConvTranspose2d.
        stages = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        stages += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        upsampled = self.model(x)
        # Join along the channel dimension.
        return torch.cat([upsampled, skip_input], dim=1)


# U-Net generator architecture.
class GeneratorUNet(nn.Module):
    """U-Net generator with eight down blocks, seven skip-connected up blocks and a
    final upsample + 3x3 conv + Tanh (so outputs lie in [-1, 1]).

    Submodules are registered as ``down1..down8``, ``up1..up7`` and ``final``.
    Shape comments below assume a 3 x 1024 x 1024 input.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # (in, out, extra UNetDown kwargs)
        encoder = [
            (in_channels, 64, {'normalize': False}),            # -> [64 x 512 x 512]
            (64, 128, {}),                                      # -> [128 x 256 x 256]
            (128, 256, {}),                                     # -> [256 x 128 x 128]
            (256, 512, {'dropout': 0.5}),                       # -> [512 x 64 x 64]
            (512, 512, {'dropout': 0.5}),                       # -> [512 x 32 x 32]
            (512, 512, {'dropout': 0.5}),                       # -> [512 x 16 x 16]
            (512, 512, {'dropout': 0.5}),                       # -> [512 x 8 x 8]
            (512, 512, {'normalize': False, 'dropout': 0.5}),   # -> [512 x 4 x 4]
        ]
        for idx, (c_in, c_out, extra) in enumerate(encoder, start=1):
            setattr(self, f"down{idx}", UNetDown(c_in, c_out, **extra))

        # (in, out, extra UNetUp kwargs); shapes include the concatenated skip input.
        decoder = [
            (512, 512, {'dropout': 0.5}),    # -> [1024 x 8 x 8]
            (1024, 512, {'dropout': 0.5}),   # -> [1024 x 16 x 16]
            (1024, 512, {'dropout': 0.5}),   # -> [1024 x 32 x 32]
            (1024, 512, {'dropout': 0.5}),   # -> [1024 x 64 x 64]
            (1024, 256, {}),                 # -> [512 x 128 x 128]
            (512, 128, {}),                  # -> [256 x 256 x 256]
            (256, 64, {}),                   # -> [128 x 512 x 512]
        ]
        for idx, (c_in, c_out, extra) in enumerate(decoder, start=1):
            setattr(self, f"up{idx}", UNetUp(c_in, c_out, **extra))

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),                                       # -> [128 x 1024 x 1024]
            nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1),  # -> [3 x 1024 x 1024]
            nn.Tanh(),
        )

    def forward(self, x):
        # Run the encoder, keeping every intermediate output for the skip connections.
        skips = []
        for idx in range(1, 9):
            x = getattr(self, f"down{idx}")(x)
            skips.append(x)
        # Bottleneck = down8 output; the remaining entries are consumed deepest-first
        # (up1 pairs with down7, ..., up7 pairs with down1).
        x = skips.pop()
        for idx in range(1, 8):
            x = getattr(self, f"up{idx}")(x, skips.pop())
        return self.final(x)


# Discriminator architecture (conditional, patch-level output).
class Discriminator(nn.Module):
    """Scores an image pair with a spatial grid of logits (no final activation).

    The two inputs are concatenated along channels, passed through five strided
    conv blocks (each halves H and W), then padded and reduced to one channel.
    For 1024 x 1024 inputs the output is 1 x 32 x 32.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        def discriminator_block(c_in, c_out, normalization=True):
            block = [nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)]
            if normalization:
                block.append(nn.InstanceNorm2d(c_out))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        # Channel widths; shapes for a 1024 input: 512, 256, 128, 64, 32.
        widths = [in_channels * 2, 64, 128, 256, 512, 512]
        layers = []
        for depth, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The first block has no normalization.
            layers.extend(discriminator_block(c_in, c_out, normalization=depth > 0))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1, bias=False))  # -> [1 x 32 x 32]
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        paired = torch.cat((img_A, img_B), 1)
        return self.model(paired)
     

In [7]:
# Loss scaler for mixed-precision (autocast) training of the generator/discriminator.
scaler = amp.GradScaler(enabled=True)
def weights_init_normal(m):
    """Initialise a module in place, intended for use with ``nn.Module.apply``.

    - Any class whose name contains ``"Conv"``: weight ~ N(0, 0.02); bias untouched.
    - Any class whose name contains ``"BatchNorm2d"``: weight ~ N(1, 0.02), bias = 0.
    - Everything else is left unchanged (including InstanceNorm layers).
    """
    name = type(m).__name__
    if "Conv" in name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# Build the generator and discriminator on the training device.
generator = GeneratorUNet().to(device)
discriminator = Discriminator().to(device)

# Re-initialise weights (conv weights ~ N(0, 0.02)), generator first.
for net in (generator, discriminator):
    net.apply(weights_init_normal)

# Loss functions: MSE for the adversarial term, L1 for pixel-wise reconstruction.
criterion_GAN = torch.nn.MSELoss().to(device)
criterion_pixelwise = torch.nn.L1Loss().to(device)

# Learning rate.
# NOTE(review): params['lr'] is 5e-5, but this hard-coded 2e-5 is what the
# optimizers actually use — confirm which value is intended.
lr = 2e-5

# Optimizers for the generator and the discriminator.
adam_kwargs = dict(lr=lr, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)
# summary(generator, input_size=(params['batch_size'], 3, params['image_size'], params['image_size']))

In [None]:

lambda_pixel = 50
lambda1 = 1.01
accum_steps = params['gradient_accumulation_steps']

# Dedicated loss scaler for the discriminator. `scaler` (generator) is updated every
# batch; if it were shared, its scale could change between the micro-batches whose
# scaled gradients D accumulates, and the D step would unscale them incorrectly.
scaler_D = amp.GradScaler()

sample_dir = '../../result/colorization/pix2pix_r'
model_dir = '../../model/colorization/pix2pix_r'
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

for epoch in range(params['epochs']):
    total_loss_D = 0
    total_loss_pixel = 0
    total_loss_GAN = 0
    steps = 0

    optimizer_G.zero_grad()
    optimizer_D.zero_grad()

    # Pixel-loss weight grows geometrically with the epoch.
    # NOTE(review): with lambda1=1.01 this reaches ~1e6 by epoch 999 — confirm intended.
    pixel_weight = lambda_pixel * lambda1 ** epoch

    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for i, (RGB_image, L_image) in enumerate(tqdmDataLoader):
            # Load model inputs.
            real_A = L_image.to(device)  # grayscale input
            real_B = RGB_image.to(device)  # colour target

            # --- Generator update (every batch, mixed precision) ---
            # Freeze D so the generator's adversarial loss does not write gradients
            # into D's parameters; they would otherwise accumulate and be applied at
            # the next optimizer_D step. Gradients still flow through D into fake_B.
            discriminator.requires_grad_(False)
            with amp.autocast():
                fake_B = generator(real_A)
                pred_fake = discriminator(fake_B, real_A)
                # Label maps sized from D's output (1 x 32 x 32 for 1024 inputs).
                real = torch.full_like(pred_fake, 0.9, dtype=torch.float32)  # real (smoothed)
                fake = torch.zeros_like(pred_fake, dtype=torch.float32)      # fake
                loss_GAN = criterion_GAN(pred_fake, real)
                loss_pixel = criterion_pixelwise(fake_B, real_B)
                loss_G = loss_GAN + pixel_weight * loss_pixel

            scaler.scale(loss_G).backward()
            discriminator.requires_grad_(True)
            scaler.step(optimizer_G)
            scaler.update()
            optimizer_G.zero_grad()

            # --- Discriminator update (gradient accumulation over accum_steps batches) ---
            with amp.autocast():
                loss_real = criterion_GAN(discriminator(real_B, real_A), real)
                loss_fake = criterion_GAN(discriminator(fake_B.detach(), real_A), fake)
                loss_D = (loss_real + loss_fake) / 2

            # Divide so the accumulated gradient is the mean over the micro-batches.
            scaler_D.scale(loss_D / accum_steps).backward()

            if (i + 1) % accum_steps == 0:
                scaler_D.step(optimizer_D)
                scaler_D.update()
                optimizer_D.zero_grad()

            total_loss_D += loss_D.item()
            total_loss_pixel += loss_pixel.item()
            total_loss_GAN += loss_GAN.item()
            steps += 1

            tqdmDataLoader.set_postfix(
                ordered_dict={
                    "epoch": epoch + 1,
                    "D loss: ": total_loss_D / steps,
                    "G pixel loss: ": total_loss_pixel / steps,
                    "adv loss: ": total_loss_GAN / steps,
                }
            )

    # Apply leftover accumulated D gradients when the batch count is not a multiple
    # of accum_steps (otherwise they would be discarded by the next zero_grad).
    if steps % accum_steps != 0:
        scaler_D.step(optimizer_D)
        scaler_D.update()
        optimizer_D.zero_grad()

    # Save an image sample: input | generated | target, stacked vertically.
    # no_grad avoids building an autograd graph for this full-resolution pass.
    with torch.no_grad():
        imgs = next(iter(dataloader))
        gen_input = imgs[1].to(device)  # [-1, 1], the range the generator was trained on
        real_A = (gen_input + 1) / 2
        real_B = (imgs[0].to(device) + 1) / 2
        fake_B = (generator(gen_input) + 1) / 2  # rescale to [0, 1] only for display
        img_sample = torch.cat((real_A, fake_B, real_B), -2)
    to_pil_image(img_sample[0]).save(f'{sample_dir}/{epoch}.png')

    # Save model checkpoints.
    torch.save(generator.state_dict(), f"{model_dir}/Pix2Pix_Generator_for_Colorization_{epoch}.pt")
    torch.save(discriminator.state_dict(), f"{model_dir}/Pix2Pix_Discriminator_for_Colorization_{epoch}.pt")
