In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils as utils
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.utils as v_utils
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets
from tqdm.notebook import tqdm

In [2]:
def conv_block(in_dim, out_dim, act_fn):
    """Build a 3x3 convolution -> batch norm -> activation stage.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        act_fn: Activation module placed last. The instance passed in is
            used directly (it is not copied).

    Returns:
        nn.Sequential containing the three layers in that order.
    """
    # kernel 3 / stride 1 / padding 1 leaves height and width unchanged.
    conv = nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1)
    norm = nn.BatchNorm2d(out_dim)
    return nn.Sequential(conv, norm, act_fn)

In [3]:
def conv_trans_block(in_dim, out_dim, act_fn):
    """Build a transposed-convolution upsampling stage.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        act_fn: Activation module placed last. The instance passed in is
            used directly (it is not copied).

    Returns:
        nn.Sequential of ConvTranspose2d -> BatchNorm2d -> act_fn.
    """
    # kernel 3 / stride 2 / padding 1 / output_padding 1 doubles H and W.
    upsample = nn.ConvTranspose2d(
        in_dim,
        out_dim,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    )
    norm = nn.BatchNorm2d(out_dim)
    return nn.Sequential(upsample, norm, act_fn)

In [4]:
def maxpool():
    """Return a 2x2, stride-2, unpadded max-pooling layer (halves H and W)."""
    return nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

In [5]:
def conv_block_2(in_dim,out_dim,act_fn):
    """Build a double-convolution stage.

    The stage is one ``conv_block`` (conv -> batch norm -> activation)
    followed by a second 3x3 convolution and batch norm. No activation is
    applied after the second batch norm.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels of both convolutions.
        act_fn: Activation module handed to the inner ``conv_block``.

    Returns:
        nn.Sequential with the nested block at index 0, the second
        convolution at index 1 and its batch norm at index 2.
    """
    layers = [
        conv_block(in_dim, out_dim, act_fn),
        nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
    ]
    return nn.Sequential(*layers)

In [6]:
class UnetGenerator(nn.Module):
    """U-Net style generator with four down stages, a bridge and four up stages.

    Every down stage is a ``conv_block_2`` followed by 2x2 max pooling, and
    the channel count doubles at each stage (num_filter, 2x, 4x, 8x, then
    16x in the bridge). Every up stage doubles the spatial size with a
    transposed convolution, concatenates the matching down-stage output
    along the channel axis and reduces the channels with a ``conv_block_2``.

    Because of the four halvings and doublings, the input height and width
    need to be multiples of 16 for the skip concatenations to line up.

    Args:
        in_dim: Number of channels of the input image.
        out_dim: Number of channels of the generated image.
        num_filter: Channel count of the first stage.
    """

    def __init__(self,in_dim,out_dim,num_filter):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter

        nf = num_filter
        # One LeakyReLU instance is shared by every stage.
        act_fn = nn.LeakyReLU(0.2, inplace=True)

        print("\n------Initiating U-Net------\n")

        # Contracting path.
        self.down_1 = conv_block_2(in_dim, nf, act_fn)
        self.pool_1 = maxpool()
        self.down_2 = conv_block_2(nf, nf * 2, act_fn)
        self.pool_2 = maxpool()
        self.down_3 = conv_block_2(nf * 2, nf * 4, act_fn)
        self.pool_3 = maxpool()
        self.down_4 = conv_block_2(nf * 4, nf * 8, act_fn)
        self.pool_4 = maxpool()

        # Bottleneck between the two paths.
        self.bridge = conv_block_2(nf * 8, nf * 16, act_fn)

        # Expanding path. Each up_N takes twice the channels produced by
        # trans_N because the skip tensor is concatenated before it.
        self.trans_1 = conv_trans_block(nf * 16, nf * 8, act_fn)
        self.up_1 = conv_block_2(nf * 16, nf * 8, act_fn)
        self.trans_2 = conv_trans_block(nf * 8, nf * 4, act_fn)
        self.up_2 = conv_block_2(nf * 8, nf * 4, act_fn)
        self.trans_3 = conv_trans_block(nf * 4, nf * 2, act_fn)
        self.up_3 = conv_block_2(nf * 4, nf * 2, act_fn)
        self.trans_4 = conv_trans_block(nf * 2, nf, act_fn)
        self.up_4 = conv_block_2(nf * 2, nf, act_fn)

        # Output head.
        self.out = nn.Sequential(
            nn.Conv2d(nf, out_dim, 3, 1, 1),
            nn.Tanh(),  # not strictly required
        )

    def forward(self,input):
        """Run the generator on a batch and return the Tanh-activated output."""
        # Contracting path: keep each pre-pooling tensor for its skip link.
        skip_1 = self.down_1(input)
        skip_2 = self.down_2(self.pool_1(skip_1))
        skip_3 = self.down_3(self.pool_2(skip_2))
        skip_4 = self.down_4(self.pool_3(skip_3))

        features = self.bridge(self.pool_4(skip_4))

        # Expanding path: upsample, join with the skip tensor, then convolve.
        features = self.up_1(torch.cat([self.trans_1(features), skip_4], dim=1))
        features = self.up_2(torch.cat([self.trans_2(features), skip_3], dim=1))
        features = self.up_3(torch.cat([self.trans_3(features), skip_2], dim=1))
        features = self.up_4(torch.cat([self.trans_4(features), skip_1], dim=1))

        return self.out(features)

In [7]:
class Discriminator(nn.Module):
    """Convolutional discriminator that maps an image to one logit.

    Six unpadded convolution stages shrink the image. The shape comments in
    ``__init__`` assume a 256 x 256 input and the default ``hidden_dim``; for
    that input the last stage yields a single value per image.

    Args:
        in_dim: Number of channels of the input image.
        hidden_dim: Channel count of the first stage; later stages use
            2x, 4x, 8x and 16x this value.
    """

    def __init__(self, in_dim=3,hidden_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim

        # One LeakyReLU instance is shared by all non-final stages.
        act_fn = nn.LeakyReLU(0.2)

        # (in_channels, out_channels, kernel_size, stride) per hidden stage.
        stage_specs = [
            (in_dim, hidden_dim, 4, 2),               # 16 * 127 * 127
            (hidden_dim, hidden_dim * 2, 5, 2),       # 32 * 62 * 62
            (hidden_dim * 2, hidden_dim * 4, 4, 2),   # 64 * 30 * 30
            (hidden_dim * 4, hidden_dim * 8, 4, 2),   # 128 * 14 * 14
            (hidden_dim * 8, hidden_dim * 16, 5, 3),  # 256 * 4 * 4
        ]
        stages = [
            self.disc_block(c_in, c_out, kernel_size=k, stride=s, act_fn=act_fn)
            for c_in, c_out, k, s in stage_specs
        ]
        # Final stage: bare convolution producing the logit.  # 1 * 1 * 1
        stages.append(
            self.disc_block(
                hidden_dim * 16, 1, kernel_size=4, stride=1,
                act_fn=act_fn, final_layer=True,
            )
        )
        self.disc = nn.Sequential(*stages)

    def disc_block(self, in_dim, out_dim,kernel_size,stride, act_fn, final_layer=False):
        """Build one stage: conv only when ``final_layer``, else conv -> BN -> act_fn."""
        conv = nn.Conv2d(in_dim, out_dim, kernel_size, stride)
        if final_layer:
            return nn.Sequential(conv)
        return nn.Sequential(conv, nn.BatchNorm2d(out_dim), act_fn)

    def forward(self, image):
        """Return the stage output flattened to (batch, -1)."""
        logits = self.disc(image)
        return logits.view(logits.size(0), -1)

In [8]:
# Shared settings used by the data pipeline and the models below.
batch_size = 16    # images per mini-batch
img_size = 256     # height and width of the square images
in_dim = 1         # generator input channels (grayscale)
out_dim = 3        # generator output channels (colour)
num_filters = 16   # channel count of the first U-Net stage

# All-ones dummy batch shaped like one generator input. The channel count
# follows in_dim so it stays in sync if in_dim is ever changed.
sample_input = torch.ones(size=(batch_size, in_dim, img_size, img_size))

In [9]:
# Preprocessing pipelines.
# Colour images: resize, centre-crop to 256 x 256, convert to a tensor and
# normalise each of the three channels.
transform=transforms.Compose([
                            transforms.Resize(256),
                            transforms.CenterCrop(256),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # map pixel values into [-1, 1]
                        ])
# Grayscale images: same geometry, but reduced to a single channel first.
transform_gray=transforms.Compose([
                            transforms.Grayscale(num_output_channels=1),
                            transforms.Resize(256),
                            transforms.CenterCrop(256),
                            transforms.ToTensor(),
                            # (0.5,) is a real 1-tuple; the bare (0.5) is just the float 0.5.
                            transforms.Normalize((0.5,), (0.5,)), # map pixel values into [-1, 1]
                        ])

In [10]:
# Both datasets read the same image folder; only the transform differs, so
# index i of dataset_gray is the grayscale version of index i of dataset_color.
dataset_color = datasets.ImageFolder(
    root="./colorization_data/color_image",
    transform=transform,
)
dataset_gray = datasets.ImageFolder(
    root="./colorization_data/color_image",
    transform=transform_gray,
)

In [11]:
# Use the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [12]:
# Split by index rather than handing the datasets to train_test_split:
# splitting an ImageFolder directly makes scikit-learn index every sample,
# which decodes and transforms the whole image folder into in-memory lists
# (once per dataset). Subset keeps loading lazy and yields the samples in the
# same order as the index lists.
assert len(dataset_color) == len(dataset_gray), "colour/gray datasets must be paired"
all_indices = list(range(len(dataset_color)))
train_indices, test_indices = train_test_split(all_indices, test_size=0.2, random_state=42)

# The same index lists are used for both datasets so colour/gray pairs stay aligned.
dataset_color_train = data.Subset(dataset_color, train_indices)
dataset_color_test = data.Subset(dataset_color, test_indices)
dataset_gray_train = data.Subset(dataset_gray, train_indices)
dataset_gray_test = data.Subset(dataset_gray, test_indices)

In [13]:
# The colour and grayscale loaders are iterated in lockstep with zip(), so
# they must yield samples in the same order: shuffling stays off.
train_color_loader = DataLoader(dataset_color_train, batch_size=batch_size, shuffle=False)
train_gray_loader = DataLoader(dataset_gray_train, batch_size=batch_size, shuffle=False)
test_color_loader = DataLoader(dataset_color_test, batch_size=batch_size, shuffle=False)
test_gray_loader = DataLoader(dataset_gray_test, batch_size=batch_size, shuffle=False)

In [14]:
def show_tensor_images(image_tensor, num_images=25, size=(3,64,64)):
    '''
    Display the first images of a batch as a grid (4 images per row).

    image_tensor : batch of images as a tensor; values are rescaled with
                   (x + 1) / 2 before plotting, i.e. [-1, 1] maps to [0, 1]
    num_images : maximum number of images taken from the front of the batch
    size : image size (accepted for callers, not used by the body)
    '''
    rescaled = (image_tensor + 1) / 2
    batch_cpu = rescaled.detach().cpu()
    grid = v_utils.make_grid(batch_cpu[:num_images], nrow=4)
    # Move the channel axis last, which is the layout imshow expects.
    grid_hwc = grid.permute(1, 2, 0).squeeze()
    plt.imshow(grid_hwc)
    plt.show()

In [15]:
def weights_init(m):
    """Re-initialise convolution and batch-norm parameters in place.

    Meant to be passed to ``nn.Module.apply``. The layer kind is detected
    from the class name:

    * name contains ``Conv``: weights drawn from N(0, 0.02); bias untouched.
    * name contains ``BatchNorm``: weights drawn from N(1, 0.02), bias set to 0.

    Any other module is left unchanged.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [18]:
# Training hyper-parameters.
display_step = 50
Learning_rate = 0.0001
n_epochs = 100

# Generator, discriminator and the adversarial loss (BCE on raw logits).
Unet = UnetGenerator(in_dim=in_dim, out_dim=out_dim, num_filter=num_filters).to(device)
disc = Discriminator().to(device)
loss_func = nn.BCEWithLogitsLoss()

# One Adam optimizer per network, sharing the same learning rate.
unet_optimizer = Adam(Unet.parameters(), lr=Learning_rate)
disc_optimizer = Adam(disc.parameters(), lr=Learning_rate)

# apply() re-initialises the parameters in place and returns the module
# itself, so the existing names (and the optimizers) keep pointing at the
# same objects.
Unet.apply(weights_init)
disc.apply(weights_init)


------Initiating U-Net------



In [None]:
# Adversarial training loop: one discriminator update followed by one
# generator update per mini-batch.
mean_disc_loss = 0        # running sum of discriminator losses since the last report
mean_unet_loss = 0        # running sum of generator losses since the last report
steps_since_display = 0   # number of steps accumulated in the two sums above
cur_step = 0
display_step = 100

# Make sure both networks are in training mode (matters for BatchNorm) even
# if an evaluation cell was run earlier in this kernel session.
Unet.train()
disc.train()

# zip() stops at the shorter loader, so that is the real number of steps.
steps_per_epoch = min(len(train_color_loader), len(train_gray_loader))

for epoch in range(n_epochs):
    print(f'epoch: {epoch + 1}/{n_epochs}')
    paired_batches = zip(train_color_loader, train_gray_loader)
    # total= gives the progress bar a length; zip() has none of its own.
    for [color , _],[gray , _] in tqdm(paired_batches, total=steps_per_epoch):
        gray = gray.to(device)
        color = color.to(device)

        # ---- Discriminator step: real images -> 1, generated images -> 0 ----
        disc_optimizer.zero_grad()

        output = Unet(gray)
        # detach() keeps the discriminator loss from back-propagating into the generator.
        logits_output = disc(output.detach())
        logits_color = disc(color)

        disc_loss_output = loss_func(logits_output, torch.zeros_like(logits_output))
        disc_loss_color = loss_func(logits_color, torch.ones_like(logits_color))
        disc_loss = (disc_loss_output + disc_loss_color) / 2

        mean_disc_loss += disc_loss.item()

        # No retain_graph: the generator step below builds a fresh graph, so
        # keeping this one alive would only hold on to memory.
        disc_loss.backward()
        disc_optimizer.step()

        # ---- Generator step: try to make the discriminator answer 1 ----
        unet_optimizer.zero_grad()

        output = Unet(gray)
        logits_output = disc(output)

        unet_loss = loss_func(logits_output, torch.ones_like(logits_output))

        mean_unet_loss += unet_loss.item()

        unet_loss.backward()
        unet_optimizer.step()

        steps_since_display += 1

        # Report every display_step steps (this includes step 0). The former
        # `epoch == n_epochs` clause could never be true inside
        # range(n_epochs) and has been dropped.
        if cur_step % display_step == 0:
            # Average over the steps actually accumulated, so the step-0
            # report is not divided by a full display_step.
            print(f"Step {cur_step}: Generator loss: {mean_unet_loss / steps_since_display}, discriminator loss: {mean_disc_loss / steps_since_display}")
            show_tensor_images(output, num_images=16, size=(3,256,256))
            show_tensor_images(color,num_images=16, size=(3,256,256))
            mean_unet_loss = 0
            mean_disc_loss = 0
            steps_since_display = 0
        cur_step += 1

In [22]:
def save_tensor_images(image_tensor, num_images=25, size=(3,64,64), dir="", num=0):
    '''
    Save the first images of a batch as individual JPEG files.

    image_tensor : batch of images as a tensor; values are rescaled with
                   (x + 1) / 2 before saving, i.e. [-1, 1] maps to [0, 1]
    num_images : maximum number of images to save; fewer are written when the
                 batch is smaller (e.g. the last batch of a loader)
    size : image size (accepted for callers, not used by the body)
    dir : target folder; created if it does not exist. An empty string means
          the current working directory.
    num : batch counter used as the file-name prefix
    -> writes files named "<num>_<index>.jpg" into dir.
    '''
    image_unflat = ((image_tensor + 1) / 2).detach().cpu()

    # save_image does not create missing folders.
    if dir:
        os.makedirs(dir, exist_ok=True)

    # Never index past the end of a short batch.
    count = min(num_images, len(image_unflat))
    for i in range(count):
        image_name = f'{num}_{i}.jpg'
        v_utils.save_image(image_unflat[i].squeeze(), os.path.join(dir, image_name))

In [None]:
# Run the trained generator over the test split, preview some batches and
# write every grayscale input, colour target and generated image to disk.
gray_path = "./output_image_GAN_200/gray/"
color_path = "./output_image_GAN_200/color/"
output_path = "./output_image_GAN_200/output/"

# The image writer does not create missing folders, so make sure they exist.
for out_dir in (gray_path, color_path, output_path):
    os.makedirs(out_dir, exist_ok=True)

# Kept for compatibility with any later cell that refers to these names;
# nothing in this cell appends to them.
gray_image = []
color_image = []
output_image = []
display_step = 10

# Evaluation mode: BatchNorm layers use their running statistics instead of
# the statistics of each test batch, and stop updating them from test data.
Unet.eval()

# zip() stops at the shorter loader, so that is the real number of steps.
test_steps = min(len(test_color_loader), len(test_gray_loader))

with torch.no_grad():
    cur_step = 0
    paired_batches = zip(test_color_loader, test_gray_loader)
    for [color, _], [gray, _] in tqdm(paired_batches, total=test_steps):
        color = color.to(device)
        gray = gray.to(device)

        output = Unet(gray)

        # The last batch can hold fewer than 16 images.
        n_in_batch = min(16, gray.size(0))

        if cur_step % display_step == 0:
            show_tensor_images(output, num_images=16, size=(3,256,256))
            show_tensor_images(color,num_images=16, size=(3,256,256))
            show_tensor_images(gray,num_images=16, size=(256,256))

        save_tensor_images(gray, num_images=n_in_batch, size=(256,256), dir=gray_path, num=cur_step)
        save_tensor_images(color, num_images=n_in_batch, size=(3,256,256), dir=color_path, num=cur_step)
        save_tensor_images(output, num_images=n_in_batch, size=(3,256,256), dir=output_path, num=cur_step)
        cur_step += 1

In [21]:
# torch.save fails if the parent folder is missing, so create it first.
os.makedirs("./model", exist_ok=True)

# Each call pickles the whole module object (not just a state_dict), so
# loading these files later requires the same class definitions to be importable.
torch.save(Unet,"./model/colorization_GAN_200_Unet.pt")
torch.save(disc, "./model/colorization_GAN_200_disc.pt")