In [None]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from datetime import datetime
from torch.autograd import Variable
import matplotlib
import matplotlib.pyplot as plt
from torchsummary import summary as summary_
from tqdm import tqdm
import re

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (checkpoints are saved under it).
drive.mount('/content/drive')

In [None]:
!pip install kaggle  
from google.colab import files  
files.upload() 

In [None]:
# Move the uploaded API token (kaggle.json) into ~/.kaggle/, where the kaggle CLI
# looks for credentials, and restrict it to owner read/write (mode 600).
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

### Dataset 다운로드 (download the dataset)

In [None]:
# Download the dataset archive (image-colorization-dataset.zip, extracted in the next cell)
# using the Kaggle credentials installed above.
!kaggle datasets download -d aayush9753/image-colorization-dataset

In [None]:
!unzip image-colorization-dataset.zip  

### Custom Dataset 만들기 (build a custom Dataset)

In [None]:
from PIL import Image
class ColorDataSet(Dataset):
    """Paired gray / colour images from the image-colorization dataset.

    Each item is ``(img_l, img_ab, img_gray, img_color)``:
      * img_l     - L channel of the colour image, float32 tensor (1, 256, 256), scaled to [-1, 1]
      * img_ab    - a/b channels of the colour image, float32 tensor (2, 256, 256), divided by 110
      * img_gray  - the dataset's gray image as float32 array (3, 256, 256) in [0, 1]
      * img_color - the colour image as float32 array (3, 256, 256) in [0, 1]

    Parameters:
        idx_set: indices into the (sorted) file lists that form this split;
            ``None`` uses every file. Previously this argument was ignored, so
            the train and validation sets were both the whole directory.
        gray_path, color_path: directories holding the gray and colour images.
            Files are paired by position after sorting, which assumes both
            directories contain matching file names -- TODO confirm.
    """

    def __init__(self, idx_set=None,
                 gray_path='/content/data/train_black',
                 color_path='/content/data/train_color'):
        self.gray_path = gray_path
        self.color_path = color_path
        # os.listdir order is arbitrary; sort so position k refers to the same
        # pair in both directories and the split is reproducible.
        self.gray_files = sorted(os.listdir(self.gray_path))
        self.color_files = sorted(os.listdir(self.color_path))
        if idx_set is None:
            idx_set = range(len(self.gray_files))
        self.idx_set = [int(k) for k in idx_set]

    def __len__(self):
        return len(self.idx_set)

    def _file_paths(self, idx):
        """Return ``(gray_file_path, color_file_path)`` for dataset position ``idx``."""
        file_idx = self.idx_set[idx]
        return (os.path.join(self.gray_path, self.gray_files[file_idx]),
                os.path.join(self.color_path, self.color_files[file_idx]))

    @staticmethod
    def _read_rgb_chw(path):
        """Read an image with OpenCV as float32 RGB, shape (3, 256, 256), values in [0, 1]."""
        img = cv2.imread(path)
        if img is None:  # cv2.imread reports failure by returning None, not raising
            raise FileNotFoundError("Could not read image: {}".format(path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (256, 256))
        img = np.transpose(img, (2, 0, 1))
        return img.astype('float32') / 255.

    def __getitem__(self, idx):
        gray_file, color_file = self._file_paths(idx)

        # Lab training targets from the colour image; the context manager closes the file handle.
        with Image.open(color_file) as src:
            im = src.convert('RGB').resize((256, 256))
        lab = rgb2lab(np.array(im)).astype(np.float32)
        lab_t = transforms.ToTensor()(lab)      # (H, W, 3) float array -> (3, H, W), no rescaling
        img_l = lab_t[[0], ...] / 50.0 - 1.0    # L in [0, 100] -> [-1, 1]
        img_ab = lab_t[[1, 2], ...] / 110       # a, b -> roughly [-1, 1]

        # Display images (colour first, then gray, as before).
        img_color = self._read_rgb_chw(color_file)
        img_gray = self._read_rgb_chw(gray_file)

        return img_l, img_ab, img_gray, img_color

In [None]:
# Use indices 0..N_SAMPLES-1 and hold out `train_val_split` of them for validation.
N_SAMPLES = 1000
train_val_split = 0.2
np.random.seed(0)  # reproducible split
idx_list = list(np.arange(N_SAMPLES))
idx_split = int(train_val_split * N_SAMPLES)
np.random.shuffle(idx_list)
train_idx = idx_list[idx_split:]
val_idx = idx_list[:idx_split]

train_set = ColorDataSet(idx_set=train_idx)
valid_set = ColorDataSet(idx_set=val_idx)

# Reshuffle training batches every epoch (standard for SGD), using a dedicated
# seeded generator so the order is reproducible without touching the global RNG.
# Validation stays unshuffled so the per-epoch preview batch is always the same.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True,
                          generator=torch.Generator().manual_seed(0))
valid_loader = DataLoader(valid_set, batch_size=32, shuffle=False)

In [None]:
# Grab one batch from each loader (train first, then validation) for the checks and previews below.
train_data, test_data = (next(iter(loader)) for loader in (train_loader, valid_loader))

In [None]:
# Value ranges, then shapes, of the first two batch entries (L and ab), one value per line.
print(
    torch.max(train_data[0]), torch.min(train_data[0]),
    torch.max(train_data[1]), torch.min(train_data[1]),
    train_data[0].shape, train_data[1].shape,
    sep='\n',
)

### Generator & Discriminator

In [None]:
class Generator(nn.Module):
    """U-Net colourisation generator (normalised L channel -> normalised ab channels).

    Encoder: eight 4x4, stride-2, padding-1 convolutions, each halving H and W
    (256x256 -> 1x1). Decoder: eight 4x4, stride-2, padding-1 transposed
    convolutions, each doubling H and W. Every decoder stage after the first
    takes the previous decoder output concatenated (dim 1) with the encoder
    output of the same resolution (skip connection). The final Tanh bounds the
    output to [-1, 1].

    Input:  (N, 1, H, W). H and W must be divisible by 256 (2**8) so every skip
            connection has matching spatial size; the summary cell uses 256x256.
    Output: (N, 2, H, W).

    NOTE(review): encod1 and encod8 apply no activation, and encod6 runs
    LeakyReLU *before* BatchNorm whereas the other encoder stages use
    Conv -> BatchNorm -> LeakyReLU. Reordering encod6 would move its BatchNorm
    from index 2 to index 1 (state_dict keys encod6.2.* -> encod6.1.*) and break
    loading existing checkpoints, so any fix needs a key remap.
    """

    def __init__(self) -> None:
        super(Generator, self).__init__()

        # --------- Encoder ---------
        # Outermost stage: convolution only (no BatchNorm, no activation).
        self.encod1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, padding=1, stride=2),
        )
        self.encod2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.encod3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.encod4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.encod5 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.encod6 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, padding=1, stride=2),
            # NOTE(review): activation before BatchNorm here, unlike the other stages.
            nn.LeakyReLU(negative_slope=0.2),
            nn.BatchNorm2d(512),
        )
        self.encod7 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2)
        )
        # Innermost (bottleneck) stage: convolution only; 1x1 output for a 256x256 input.
        self.encod8 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, padding=1, stride=2),
        )

        # --------- Decoder ---------
        # in_channels is 2 * C for every stage after decod8 because each one
        # receives cat(previous decoder output, same-resolution encoder output).
        # Dropout2d(0.5) is used only in the three innermost stages (decod8-6).
        self.decod8 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(512),
            nn.Dropout2d(0.5),
            nn.ReLU()
        )
        self.decod7 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=2 * 512, out_channels=512, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(512),
            nn.Dropout2d(0.5),
            nn.ReLU()
        )
        self.decod6 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=2 * 512, out_channels=512, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(512),
            nn.Dropout2d(0.5),
            nn.ReLU()
        )
        self.decod5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=2 * 512, out_channels=512, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.decod4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=2 * 512, out_channels=256, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.decod3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=2 * 256, out_channels=128, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.decod2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=2 * 128, out_channels=64, kernel_size=4, padding=1, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        # Output stage: 2 channels (a, b), squashed to [-1, 1] to match the dataset's ab / 110 scaling.
        self.decodout = nn.Sequential(
            nn.ConvTranspose2d(in_channels=2 * 64, out_channels=2, kernel_size=4, padding=1, stride=2),
            nn.Tanh())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shapes below are for a (N, 1, 256, 256) input.
        # --------- Encoder ---------
        e1 = self.encod1(x)    # (N,  64, 128, 128)
        e2 = self.encod2(e1)   # (N, 128,  64,  64)
        e3 = self.encod3(e2)   # (N, 256,  32,  32)
        e4 = self.encod4(e3)   # (N, 512,  16,  16)
        e5 = self.encod5(e4)   # (N, 512,   8,   8)
        e6 = self.encod6(e5)   # (N, 512,   4,   4)
        e7 = self.encod7(e6)   # (N, 512,   2,   2)
        e8 = self.encod8(e7)   # (N, 512,   1,   1)

        # --------- Decoder ---------
        d8 = self.decod8(e8)                         # (N, 512,   2,   2)
        d7 = self.decod7(torch.cat([d8, e7], 1))     # (N, 512,   4,   4)
        d6 = self.decod6(torch.cat([d7, e6], 1))     # (N, 512,   8,   8)
        d5 = self.decod5(torch.cat([d6, e5], 1))     # (N, 512,  16,  16)
        d4 = self.decod4(torch.cat([d5, e4], 1))     # (N, 256,  32,  32)
        d3 = self.decod3(torch.cat([d4, e3], 1))     # (N, 128,  64,  64)
        d2 = self.decod2(torch.cat([d3, e2], 1))     # (N,  64, 128, 128)

        out = self.decodout(torch.cat([d2, e1], 1))  # (N,   2, 256, 256)

        return out

In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style conditional discriminator that returns raw logits.

    ``forward(clr, bw)`` concatenates the colour channels and the conditioning
    channel(s) along dim 1 and maps them to a one-channel grid of scores, one
    per receptive-field patch (16x16 for a 256x256 input).

    The output is deliberately *not* passed through a sigmoid: the training loop
    scores it with ``nn.BCEWithLogitsLoss``, which applies the sigmoid itself.
    Applying one here as well squashed every effective probability into
    (0.5, 0.73) and weakened the adversarial gradients. Use
    ``torch.sigmoid(D(clr, bw))`` when probabilities are needed.

    Parameters:
        in_channels: channel count of ``cat(clr, bw)``; the default 3 is
            2 ab channels + 1 L channel, as passed by the training loop.
    """

    def __init__(self, in_channels: int = 3):
        super(Discriminator, self).__init__()

        self.image_size = 256

        self.conv = nn.Sequential(
            # (N, in_channels, 256, 256) -> (N, 64, 128, 128); no BatchNorm on the first layer
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, padding=1, stride=2),
            nn.LeakyReLU(negative_slope=0.2),

            # (N, 64, 128, 128) -> (N, 128, 64, 64)
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2),

            # (N, 128, 64, 64) -> (N, 256, 32, 32)
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2),

            # (N, 256, 32, 32) -> (N, 512, 16, 16)
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2),

            # (N, 512, 16, 16) -> (N, 1, 16, 16): stride 1 keeps the patch grid size
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=3, padding=1, stride=1)
        )

    def forward(self, clr: torch.Tensor, bw: torch.Tensor) -> torch.Tensor:
        """Return per-patch logits for the (clr, bw) pair, shape (N, 1, H/16, W/16)."""
        cat_clr_bw = torch.cat((clr, bw), 1)
        return self.conv(cat_clr_bw)

### 모델 (model summaries)

In [None]:
# `device` is otherwise first defined a few cells below; define it here too so
# this cell also runs on a fresh kernel (Restart & Run All).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = Generator().to(device)
summary_(model, (1, 256, 256), batch_size=1)

In [None]:
# Define `device` here too so this cell runs on a fresh kernel.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = Discriminator().to(device)
# forward(clr, bw) concatenates its inputs along channels and the first conv
# expects 3 channels: 2 ab channels + 1 L channel, matching how training calls D.
# (The previous (3, 256, 256) x 2 sizes gave 6 channels and failed.)
summary_(model, [(2, 256, 256), (1, 256, 256)], batch_size=1)

In [None]:
# Use the first GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [None]:
# Networks used for training: generator first, then discriminator, both moved to `device`.
G, D = Generator().to(device), Discriminator().to(device)

### 시각화 (visualisation)

In [None]:
from skimage.color import lab2rgb, rgb2lab, rgb2gray
def Lab2Rgb(L, AB):
    """Convert one normalised Lab image (torch tensors) to an RGB numpy image.

    Undoes the normalisation applied in ``ColorDataSet.__getitem__``: L goes
    from [-1, 1] back to [0, 100] and ab is multiplied by 110, then the result
    is converted with ``skimage.color.lab2rgb``.

    Parameters:
        L  (tensor, shape (1, H, W)): normalised L channel, range [-1, 1].
        AB (tensor, shape (2, H, W)): normalised ab channels (e.g. generator output),
            range roughly [-1, 1].
    Returns:
        numpy.ndarray of shape (H, W, 3) holding ``lab2rgb``'s float RGB output,
        i.e. values in [0, 1] (not [0, 255]); suitable for ``plt.imshow``.
    """
    AB2 = AB * 110.0
    L2 = (L + 1.0) * 50.0
    Lab = torch.cat([L2, AB2], dim=0)                      # (3, H, W)
    Lab = Lab.detach().cpu().float().numpy()
    Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))  # channels-last for skimage
    rgb = lab2rgb(Lab)
    return rgb

In [None]:
def generate_images(model, test_input, tar, test_gray, test_rgb):
    """Plot input, ground truth and the model's colourisation for the first sample of a batch.

    Parameters:
        model: generator mapping a normalised L batch (N, 1, H, W) to normalised ab (N, 2, H, W).
        test_input: normalised L batch.
        tar: ground-truth ab batch (unused; kept so existing calls still work).
        test_gray: gray image batch (N, 3, H, W), as produced by ColorDataSet.
        test_rgb: ground-truth RGB batch (N, 3, H, W), as produced by ColorDataSet.
    """
    # Tensor.to is not in-place; keep the returned tensor.
    test_input = test_input.to(device)
    # Inference only: skip building an autograd graph (this runs every epoch).
    with torch.no_grad():
        prediction = model(test_input)
    predict = Lab2Rgb(test_input[0], prediction[0])
    gray_np = test_gray.detach().cpu().numpy().transpose((0, 2, 3, 1))
    rgb_np = test_rgb.detach().cpu().numpy().transpose((0, 2, 3, 1))
    display_list = [gray_np[0], rgb_np[0], predict]
    titles = ['Input Image', 'Ground Truth', 'Predicted Image']

    fig, axes = plt.subplots(1, 3, figsize=(15, 15))
    for ax, image, title in zip(axes, display_list, titles):
        ax.set_title(title)
        ax.imshow(image)
        ax.axis('off')
    plt.show()


# Preview the freshly initialised generator on the first validation batch.
l_img, ab_img, gray_img, rgb_img = test_data[0], test_data[1], test_data[2], test_data[3]
generate_images(G, *(t.to(device) for t in (l_img, ab_img, gray_img, rgb_img)))

### Training

In [None]:
import torch,gc
# Free memory left over from earlier cells before training.
gc.collect()
torch.cuda.empty_cache()

# Adam hyper-parameters shared by G and D.
lr = 2e-4
betas = (0.5, 0.999)
g_loss = []   # per-epoch Generator loss
d_loss = []   # per-epoch Discriminator loss
g_epoch_loss = []
d_epoch_loss = []
Epochs = 150

SAVEPATH = '/content/drive/MyDrive/Colab Notebooks/pix2pix/'
# The checkpoint the training and test cells load/save (previously this
# printed whether an unrelated 'model.pth' existed).
CHECKPOINT_PATH = SAVEPATH + 'pix2pix(6).pth'
print(os.path.isfile(CHECKPOINT_PATH))

In [None]:
g_optimizer = optim.Adam(G.parameters(), lr=lr, betas=betas)
d_optimizer = optim.Adam(D.parameters(), lr=lr, betas=betas)

# Resume from the saved checkpoint when it exists; on a first run there is no
# file yet, so train the freshly initialised G/D instead of crashing.
# map_location lets a GPU-written checkpoint load on a CPU-only runtime.
ckpt_path = SAVEPATH + 'pix2pix(6).pth'
start_epoch = 0
if os.path.isfile(ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location=device)
    G.load_state_dict(checkpoint['G_state_dict'])
    D.load_state_dict(checkpoint['D_state_dict'])
    g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
    d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
    # Continue the epoch count so saved checkpoints record the running total.
    start_epoch = checkpoint.get('epoch', -1) + 1
else:
    print("No checkpoint at {}; training from scratch.".format(ckpt_path))

# BCEWithLogitsLoss applies the sigmoid itself, so it expects raw logits from D.
criteria = nn.BCEWithLogitsLoss()
criteriaL1 = torch.nn.L1Loss()
lambda_L1 = 100.0  # weight of the L1 reconstruction term relative to the GAN term

# Training mode for the networks actually being optimised (the old
# `model.train()` targeted the summary-only model, not G or D).
G.train()
D.train()

for i in range(Epochs):
    epoch = start_epoch + i
    t1 = datetime.now()
    discriminator_running_loss = 0.0
    generator_running_loss = 0.0
    size = 0
    n_batches = 0

    for data in train_loader:
        # Batches are (L, ab, gray, rgb): G maps normalised L -> normalised ab,
        # and D judges (ab, L) pairs.
        real_l = data[0].to(device)
        real_ab = data[1].to(device)
        batch_size = real_ab.size(0)
        size += batch_size
        n_batches += 1

        # One generator pass shared by both updates. Previously G ran three
        # times per step, so the GAN and L1 terms saw different dropout samples.
        fake_ab = G(real_l)

        # ---- Discriminator step: real pairs -> 1, generated pairs -> 0 ----
        D.zero_grad()
        real_out = D(real_ab, real_l)
        real_loss = criteria(real_out, torch.ones_like(real_out))
        fake_out = D(fake_ab.detach(), real_l)  # detach: the D step must not backprop into G
        fake_loss = criteria(fake_out, torch.zeros_like(fake_out))
        loss_D = (real_loss + fake_loss) * 0.5
        loss_D.backward()
        d_optimizer.step()

        # ---- Generator step: fool D while staying close to the true ab ----
        G.zero_grad()
        result = D(fake_ab, real_l)
        GAN_loss = criteria(result, torch.ones_like(result))
        L1_loss = criteriaL1(fake_ab, real_ab) * lambda_L1
        loss_G = GAN_loss + L1_loss
        loss_G.backward()
        g_optimizer.step()

        # Losses are batch means; weighting by batch size makes the epoch value a
        # per-sample mean (summing batch means and dividing by `size` under-reported it).
        discriminator_running_loss += loss_D.item() * batch_size
        generator_running_loss += loss_G.item() * batch_size

    epoch_dis_loss = discriminator_running_loss / max(size, 1)
    epoch_gen_loss = generator_running_loss / max(size, 1)
    d_loss.append(epoch_dis_loss)
    g_loss.append(epoch_gen_loss)
    print("===> Epoch[{}]({}/{}): Loss_D: {:.4f} Loss_G: {:.4f} [{}]".format(
        epoch, n_batches, len(train_loader), epoch_dis_loss, epoch_gen_loss, datetime.now() - t1))
    generate_images(G, test_data[0].to(device), test_data[1].to(device),
                    test_data[2].to(device), test_data[3].to(device))
    if (i + 1) % 10 == 0:
        os.makedirs(SAVEPATH, exist_ok=True)
        torch.save({'epoch': epoch,
                    'G_state_dict': G.state_dict(),
                    'D_state_dict': D.state_dict(),
                    'g_optimizer_state_dict': g_optimizer.state_dict(),
                    'd_optimizer_state_dict': d_optimizer.state_dict()}, ckpt_path)
    gc.collect()
    torch.cuda.empty_cache()

plt.plot(g_loss)
plt.plot(d_loss)
plt.xlabel("epoch")
plt.ylabel("mean loss per sample")
plt.legend(["Generator", "Discriminator"])
plt.show()

### 영상 품질 평가 지표 (image-quality metrics)

In [None]:
def _infer_data_range(*arrays):
    """Return 1.0 if every value of every array lies in [-1, 1], else 255.0 (8-bit scale)."""
    peak = max(float(np.max(np.abs(np.asarray(a, dtype=np.float64)))) for a in arrays)
    return 1.0 if peak <= 1.0 else 255.0


def PSNR(gray, color, max_val=None):
    """Peak signal-to-noise ratio (dB) between two equally shaped images.

    Parameters:
        gray, color: array-likes of the same shape (the test cell passes the
            ground-truth RGB image and the predicted RGB image).
        max_val: peak signal value. ``None`` infers it: 1.0 when all values are
            within [-1, 1] (e.g. the [0, 1] images ColorDataSet and lab2rgb
            produce), otherwise 255.0. The previous hard-coded 255 inflated PSNR
            on [0, 1] data by 20*log10(255) ~= 48 dB.
    Returns:
        0-d float64 torch tensor; ``inf`` when the images are identical.
    """
    if max_val is None:
        max_val = _infer_data_range(gray, color)
    gray = torch.as_tensor(np.asarray(gray, dtype=np.float64))
    color = torch.as_tensor(np.asarray(color, dtype=np.float64))
    mse = torch.mean((gray - color) ** 2)
    return 20 * torch.log10(max_val / torch.sqrt(mse))


def SSIM(y_true, y_pred, data_range=None):
    """Global (single-window) structural similarity between two equally shaped images.

    Uses image-wide mean, variance and covariance rather than the sliding
    Gaussian window of the original SSIM paper, so values differ from
    windowed implementations such as skimage's ``structural_similarity``.

    Parameters:
        y_true, y_pred: array-likes of the same shape.
        data_range: dynamic range L in C1 = (0.01 L)^2, C2 = (0.03 L)^2.
            ``None`` infers 1.0 when all values are within [-1, 1], else 255.0
            (previously the constants used a fixed, unexplained L = 7).
    Returns:
        float, at most 1.0; exactly 1.0 for identical inputs.
    """
    if data_range is None:
        data_range = _infer_data_range(y_true, y_pred)
    x = np.asarray(y_true, dtype=np.float64)
    y = np.asarray(y_pred, dtype=np.float64)
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    # The structure term needs the covariance; the previous std_x * std_y
    # product is never negative and ignores whether the images correlate.
    cov_xy = np.mean((x - mu_x) * (y - mu_y))
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator

### Test

In [None]:
def generate_images(model, test_input, tar, test_gray, test_rgb):
    """Colourise a whole batch, print SSIM/PSNR per image and plot each result.

    NOTE: this redefines (shadows) the earlier ``generate_images`` for every
    call made after this cell runs.

    Parameters:
        model: generator mapping a normalised L batch (N, 1, H, W) to normalised ab (N, 2, H, W).
        test_input: normalised L batch.
        tar: ground-truth ab batch (unused; kept so existing calls still work).
        test_gray: gray image batch (N, 3, H, W), as produced by ColorDataSet.
        test_rgb: ground-truth RGB batch (N, 3, H, W), as produced by ColorDataSet.
    """
    # Tensor.to is not in-place; keep the returned tensor.
    test_input = test_input.to(device)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        prediction = model(test_input)
    predict = [Lab2Rgb(l_chan, ab_chan) for l_chan, ab_chan in zip(test_input, prediction)]
    gray_np = test_gray.detach().cpu().numpy().transpose((0, 2, 3, 1))
    rgb_np = test_rgb.detach().cpu().numpy().transpose((0, 2, 3, 1))

    for i in range(len(test_input)):
        ssim = SSIM(rgb_np[i], predict[i])
        psnr = PSNR(rgb_np[i], predict[i]).detach().cpu().numpy()
        print("ssim: ", ssim, "psnr = ", psnr)

        fig, axes = plt.subplots(1, 3, figsize=(20, 20))
        panels = [(gray_np[i], 'BandW Image'),
                  (predict[i], 'GenerateImg'),
                  (rgb_np[i], 'Colored Img')]
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image)
            ax.set_title(title, fontsize=20)
            ax.axis('off')
        # Show each figure right after its metrics so the printed numbers line up with it.
        plt.show()


# Reload the saved generator; map_location lets a checkpoint written on GPU
# be loaded on a CPU-only runtime.
checkpoint = torch.load(SAVEPATH + 'pix2pix(6).pth', map_location=device)
G.load_state_dict(checkpoint['G_state_dict'])

l_img = test_data[0]
ab_img = test_data[1]
gray_img = test_data[2]
rgb_img = test_data[3]
generate_images(G, l_img.to(device), ab_img.to(device), gray_img.to(device), rgb_img.to(device))