<a href="https://colab.research.google.com/github/CHOSIYEON/Multimedia_Colorization/blob/main/%EC%B5%9C%EC%A2%85.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Get Dataset

In [None]:
# Mount Google Drive so the dataset archives stored under MyDrive are readable.
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
import os
import zipfile
import tqdm

file_name = "Multimedia_dataset.zip"
zip_path = os.path.join('/content/drive/MyDrive', file_name)

# Extract the training/validation archive from the mounted Drive straight into
# the working directory. The stdlib zipfile module replaces the former
# `cp` / `unzip` / `rm` shell round-trip: nothing is copied locally or left
# behind, and re-running the cell simply overwrites the extracted files.
with zipfile.ZipFile(zip_path) as archive:
  archive.extractall('.')

In [None]:
import os
import zipfile
import tqdm

file_name = "colorization_test_dataset.zip"
zip_path = os.path.join('/content/drive/MyDrive', file_name)

# Extract the test archive (hint/mask images) from the mounted Drive into the
# working directory with the stdlib zipfile module instead of shelling out to
# `cp` / `unzip` / `rm`.
with zipfile.ZipFile(zip_path) as archive:
  archive.extractall('.')

In [None]:
import matplotlib.pyplot as plt

train_root = './train'
val_root = './validation'

# Sort the listings: os.listdir has no defined order, so without sorting the
# preview image below could change from run to run.
train_examples = sorted(os.listdir(train_root))
val_examples = sorted(os.listdir(val_root))

print(len(train_examples))  # expected: 4500
print(len(val_examples))    # expected: 500

# Preview one training image.
img = plt.imread(os.path.join(train_root, train_examples[1]))
plt.imshow(img)
plt.title(train_examples[1])
plt.axis('off')
plt.show()

# Color-hint Transform

In [None]:
import torch
from torch.autograd import Variable
from torchvision import transforms

import cv2
import random
import numpy as np

class ColorHintTransform(object):
  """Resize an image and build L / ab / ab-hint tensors for colorization.

  Modes:
    "training", "validation": hint pixels are sampled at random from the image
      itself; ``__call__(img)`` returns ``(l, ab, ab_hint)``.
    "testing": hint pixels are the ones whose first channel is 255 in
      ``mask_img``; ``__call__(img, mask_img)`` returns ``(l, ab_hint)``.

  Images are converted with ``cv2.COLOR_BGR2LAB``, so they are expected in
  OpenCV BGR order (as produced by ``cv2.imread``). Every returned tensor comes
  from ``transforms.ToTensor()`` on a uint8 array.
  """

  def __init__(self, size=256, mode="training", thresholds=(0.95, 0.97, 0.99)):
    super(ColorHintTransform, self).__init__()
    self.size = size
    self.mode = mode
    # Candidate cut-offs for random hint sampling; one is drawn per image.
    # A higher cut-off keeps fewer hint pixels.
    self.thresholds = tuple(thresholds)
    self.transform = transforms.Compose([transforms.ToTensor()])

  def bgr_to_lab(self, img):
    """Convert a BGR image to LAB and split it into (L, ab).

    Returns:
      l: array of shape (H, W) holding the L channel.
      ab: array of shape (H, W, 2) holding the a and b channels.
    """
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l, ab = lab[:, :, 0], lab[:, :, 1:]
    return l, ab

  def hint_mask(self, bgr, threshold=(0.95, 0.97, 0.99)):
    """Draw a random boolean hint mask of shape (H, W, 1).

    One cut-off is picked from ``threshold``. A pixel is kept as a hint when a
    uniform sample in [0, 1) exceeds that cut-off, so roughly
    ``1 - cut-off`` of the pixels survive.
    """
    h, w = bgr.shape[:2]
    mask_threshold = random.choice(threshold)
    mask = np.random.random([h, w, 1]) > mask_threshold
    return mask

  def img_to_mask(self, mask_img):
    """Boolean (H, W, 1) mask of pixels whose first channel is 255 (>= 255)."""
    mask = mask_img[:, :, 0, np.newaxis] >= 255
    return mask

  def __call__(self, img, mask_img=None):
    """Transform one image according to ``self.mode``.

    Raises:
      ValueError: ``mode == "testing"`` and no ``mask_img`` was given.
      NotImplementedError: ``self.mode`` is not one of the supported modes.
    """
    if self.mode in ("training", "validation"):
      image = cv2.resize(img, (self.size, self.size))
      mask = self.hint_mask(image, self.thresholds)

      # Zero every non-hint pixel; only the surviving pixels carry colour.
      hint_image = image * mask

      l, ab = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab), self.transform(ab_hint)

    if self.mode == "testing":
      if mask_img is None:
        raise ValueError("ColorHintTransform in 'testing' mode needs a mask image")
      image = cv2.resize(img, (self.size, self.size))
      # Bring the mask to the same size as the image. Nearest-neighbour keeps
      # it binary; when it already has this size the resize is a no-op.
      mask_img = cv2.resize(mask_img, (self.size, self.size),
                            interpolation=cv2.INTER_NEAREST)
      hint_image = image * self.img_to_mask(mask_img)

      l, _ = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab_hint)

    # Previously the exception class was *returned*, which callers would then
    # try to unpack; raising surfaces the configuration error immediately.
    raise NotImplementedError("unsupported mode: %r" % (self.mode,))

# Dataloader for Colorization Dataset

In [None]:
import torch
import torch.utils.data  as data
import os
import cv2
from google.colab.patches import cv2_imshow

class ColorHintDataset(data.Dataset):
  """Colorization dataset rooted at ``root_path``; call ``set_mode`` before use.

  Modes and their source directories:
    "training"   -> <root_path>/train
    "validation" -> <root_path>/validation
    "testing"    -> <root_path>/hint paired by sorted position with <root_path>/mask

  Items are dicts: {"l", "ab", "hint"} for training/validation and
  {"l", "hint", "file_name"} for testing.
  """

  def __init__(self, root_path, size):
    super(ColorHintDataset, self).__init__()

    self.root_path = root_path
    self.size = size
    self.mode = None  # set by set_mode()
    self.transforms = None
    self.examples = None
    self.hint = None
    self.mask = None

  def _list_files(self, sub_dir):
    """Sorted full paths of the entries in <root_path>/<sub_dir>.

    os.listdir gives no ordering guarantee. Sorting makes the order
    deterministic, and makes the hint and mask listings line up index by index.
    """
    directory = os.path.join(self.root_path, sub_dir)
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]

  @staticmethod
  def _read_image(path):
    """cv2.imread that raises instead of silently returning None."""
    img = cv2.imread(path)
    if img is None:
      raise IOError("could not read image: %s" % path)
    return img

  def set_mode(self, mode):
    """Select the split to serve and build the matching ColorHintTransform.

    Raises:
      NotImplementedError: unknown ``mode``.
      ValueError: hint and mask directories hold different numbers of files.
    """
    if mode == "training":
      self.examples = self._list_files("train")
    elif mode == "validation":
      self.examples = self._list_files("validation")
    elif mode == "testing":
      self.hint = self._list_files("hint")
      self.mask = self._list_files("mask")
      if len(self.hint) != len(self.mask):
        raise ValueError("hint/mask count mismatch: %d vs %d"
                         % (len(self.hint), len(self.mask)))
    else:
      raise NotImplementedError("unsupported mode: %r" % (mode,))
    self.mode = mode
    self.transforms = ColorHintTransform(self.size, mode)

  def __len__(self):
    if self.mode == "testing":
      return len(self.hint)
    return len(self.examples)

  def __getitem__(self, idx):
    if self.mode == "testing":
      hint_file_name = self.hint[idx]
      mask_file_name = self.mask[idx]
      hint_img = self._read_image(hint_file_name)
      mask_img = self._read_image(mask_file_name)

      input_l, input_hint = self.transforms(hint_img, mask_img)
      # The output name comes from the numeric stem of the hint file,
      # e.g. "12.png" -> "image_000012.png".
      stem = os.path.basename(hint_file_name).split('.')[0]
      return {"l": input_l, "hint": input_hint,
              "file_name": "image_%06d.png" % int(stem)}

    img = self._read_image(self.examples[idx])
    l, ab, hint = self.transforms(img)
    return {"l": l, "ab": ab, "hint": hint}

In [None]:
# Report whether this runtime exposes a CUDA GPU (later cells move tensors to 'cuda').
print(torch.cuda.is_available())

# Example for Loading

In [None]:
import torch
import torch.utils.data  as data
import os
import cv2
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
from torchvision import transforms
import tqdm
from PIL import Image
import numpy as np
from torchvision.transforms import Compose, ToTensor, ToPILImage

def tensor2im(input_image, imtype=np.uint8):
  """Convert the first image of a batch tensor into an (H, W, C) numpy image.

  Args:
    input_image: torch.Tensor indexed as [batch, channel, H, W]; only batch
      index 0 is used. Values are expected in [0, 1] and are clipped to that
      range. Anything that is not a Tensor is returned unchanged.
    imtype: dtype of the returned array.

  Returns:
    Array of shape (H, W, C) scaled to [0, 255]. A single-channel input is
    repeated to 3 channels.
  """
  if not isinstance(input_image, torch.Tensor):
    return input_image
  image_numpy = input_image.detach()[0].cpu().float().numpy()
  if image_numpy.shape[0] == 1:
    image_numpy = np.tile(image_numpy, (3, 1, 1))
  image_numpy = np.clip(np.transpose(image_numpy, (1, 2, 0)), 0, 1) * 255.0
  # Round rather than truncate. After ToTensor's /255 and this *255 in float32,
  # a value can land just below an integer, and astype alone would drop it to
  # the next lower level.
  return np.round(image_numpy).astype(imtype)

# Change to your data root directory
root_path = "/content/"
# Side length (pixels) every image is resized to.
image_size = 128
# Use the GPU only when this runtime actually has one.
use_cuda = torch.cuda.is_available()


def make_dataloader(mode, batch_size, shuffle):
  """Create a ColorHintDataset for `mode` under root_path and wrap it in a DataLoader.

  Returns:
    (dataset, dataloader)
  """
  dataset = ColorHintDataset(root_path, image_size)
  dataset.set_mode(mode)
  loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
  return dataset, loader


train_dataset, train_dataloader = make_dataloader("training", batch_size=4, shuffle=True)
val_dataset, val_dataloader = make_dataloader("validation", batch_size=4, shuffle=True)
test_dataset, test_dataloader = make_dataloader("testing", batch_size=1, shuffle=False)

# Optional visual check of training samples (uncomment to run). The loop
# variable is `batch`, not `data`, so the torch.utils.data alias is not shadowed.
# device = 'cuda' if use_cuda else 'cpu'
# for i, batch in enumerate(tqdm.tqdm(train_dataloader)):
#   l = batch["l"].to(device)
#   ab = batch["ab"].to(device)
#   hint = batch["hint"].to(device)
#
#   gt_image = torch.cat((l, ab), dim=1)
#   hint_image = torch.cat((l, hint), dim=1)
#
#   gt_bgr = cv2.cvtColor(tensor2im(gt_image), cv2.COLOR_LAB2BGR)
#   hint_bgr = cv2.cvtColor(tensor2im(hint_image), cv2.COLOR_LAB2BGR)
#   cv2_imshow(gt_bgr)
#   cv2_imshow(hint_bgr)

# Network Construction

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
  def __init__(self):
    super(UNet, self).__init__()

    def conv(in_channel, out_channel):
      layers = []
      layers += [nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = 3, stride = 1, padding = 1,
                           bias = True)]
      layers += [nn.BatchNorm2d(num_features = out_channel)]
      layers += [nn.ReLU()]

      model = nn.Sequential(*layers)
      return model

    def pool():
      model = nn.MaxPool2d(kernel_size=2)
      return model 

    def unpool(in_channel, out_channel):
      model = nn.ConvTranspose2d(in_channels = in_channel, out_channels = out_channel, kernel_size = 2, stride = 2, 
                                 padding = 0, bias = True)
      return model

    # Contracting path 
    self.enc1 = conv(3, 32)
    self.enc1_ = conv(32, 32)
    self.pool1 = pool()

    self.enc2 = conv(32, 64)
    self.enc2_ = conv(64, 64)
    self.pool2 = pool()

    self.enc3 = conv(64, 128)
    self.enc3_ = conv(128, 128)
    self.pool3 = pool()

    self.enc4 = conv(128, 256)
    self.enc4_ = conv(256, 256)
    self.pool4 = pool()

    self.enc5 = conv(256, 512)
    self.enc5_ = conv(512, 512)
    self.pool5 = pool()

    self.enc6 = conv(512, 1024)

    # Expanding path
    self.dec6 = conv(1024, 512)
    self.unpool5 = unpool(512, 512)

    self.dec5_ = conv(1024, 512)
    self.dec5 = conv(512, 256)
    self.unpool4 = unpool(256, 256)

    self.dec4_ = conv(512, 256)
    self.dec4 = conv(256, 128)
    self.unpool3 = unpool(128, 128)

    self.dec3_ = conv(256, 128)
    self.dec3 = conv(128, 64)
    self.unpool2 = unpool(64, 64)

    self.dec2_ = conv(128, 64)
    self.dec2 = conv(64, 32)
    self.unpool1 = unpool(32, 32)

    self.dec1_ = conv(64, 32)
    self.dec1 = conv(32, 32)
    
    self.fc = nn.Conv2d(in_channels=32, out_channels=2, kernel_size=1, stride=1, padding=0, bias=True)

  def forward(self, x):
    enc1 = F.relu(self.enc1(x))
    enc1 = F.relu(self.enc1_(enc1))
    pool1 = F.relu(self.pool1(enc1))

    enc2 = F.relu(self.enc2(pool1))
    enc2 = F.relu(self.enc2_(enc2))
    pool2 = F.relu(self.pool2(enc2))

    enc3 = F.relu(self.enc3(pool2))
    enc3 = F.relu(self.enc3_(enc3))
    pool3 = F.relu(self.pool3(enc3))

    enc4 = F.relu(self.enc4(pool3))
    enc4 = F.relu(self.enc4_(enc4))
    pool4 = F.relu(self.pool4(enc4))

    enc5 = F.relu(self.enc5(pool4))
    enc5 = F.relu(self.enc5_(enc5))
    pool5 = F.relu(self.pool5(enc5))

    enc6 = F.relu(self.enc6(pool5))

    dec6 = F.relu(self.dec6(enc6))

    unpool5 = F.relu(self.unpool5(dec6))
    merge5 = torch.cat((unpool5, enc5), dim = 1)
    dec5 = F.relu(self.dec5_(merge5))
    dec5 = F.relu(self.dec5(dec5))

    unpool4 = F.relu(self.unpool4(dec5))
    merge4 = torch.cat((unpool4, enc4), dim = 1)
    dec4 = F.relu(self.dec4_(merge4))
    dec4 = F.relu(self.dec4(dec4))

    unpool3 = F.relu(self.unpool3(dec4))
    merge3 = torch.cat((unpool3, enc3), dim = 1)
    dec3 = F.relu(self.dec3_(merge3))
    dec3 = F.relu(self.dec3(dec3))

    unpool2 = F.relu(self.unpool2(dec3))
    merge2 = torch.cat((unpool2, enc2), dim = 1)
    dec2 = F.relu(self.dec2_(merge2))
    dec2 = F.relu(self.dec2(dec2))

    unpool1 = F.relu(self.unpool1(dec2))
    merge1 = torch.cat((unpool1, enc1), dim = 1)
    dec1 = F.relu(self.dec1_(merge1))
    dec1 = F.relu(self.dec1(dec1))

    out = self.fc(dec1)

    return out

In [None]:
import time

# Smoke test: push a few random batches through a fresh (CPU) UNet to check
# the output shape and get a rough timing.
net = UNet()
start = time.perf_counter()

with torch.no_grad():  # forward only; don't record an autograd graph for 50x128x128 inputs
  for _ in range(3):
    dummy_input = torch.rand(50, 3, 128, 128)
    out = net(dummy_input)
    print(out.shape)

elapsed = time.perf_counter() - start
print(elapsed)

# Network Training

In [None]:
import torch 
import torch.nn as nn
import torch.utils.data as data
import os
import matplotlib.pyplot as plt

# Dataset sizes in images, plus the number of mini-batches each loader yields per epoch.
print('train dataset length: ', len(train_dataset), '| batches:', len(train_dataloader))
print('validation dataset length: ', len(val_dataset), '| batches:', len(val_dataloader))

# 1. Network setting: put the model on the device the batches will be moved to.
device = torch.device('cuda' if use_cuda else 'cpu')
net = UNet().to(device)

# 2. Loss and optimizer setting
import torch.optim as optim
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# 3. Other training state
train_info = []     # per-epoch {'loss', 'psnr'} records for the training set
val_info = []       # the same for the validation set
object_epoch = 175  # number of training epochs

# Destination for the saved checkpoint.
save_path = './ColorizationNetwork'
os.makedirs(save_path, exist_ok=True)
output_path = os.path.join(save_path, 'colorization_model.tar')

In [None]:
# PSNR is computed directly with numpy. skimage's `compare_psnr` (imported from
# skimage.measure.simple_metrics before) was removed in newer scikit-image
# releases in favour of skimage.metrics.peak_signal_noise_ratio, so the old
# import fails there.
def batch_PSNR(img, imclean, data_range):
    """Mean peak signal-to-noise ratio (dB) over a batch of image pairs.

    Args:
        img, imclean: tensors whose first dimension is the batch; item i of
            one is compared with item i of the other. PSNR depends only on
            the MSE, so argument order does not matter.
        data_range: peak signal value (1. for images scaled to [0, 1]).

    Returns:
        Average of the per-image PSNR values. A pair with zero error
        contributes +inf, matching skimage's behaviour.
    """
    Img = img.detach().cpu().numpy().astype(np.float32)
    Iclean = imclean.detach().cpu().numpy().astype(np.float32)
    PSNR = 0.
    for i in range(Img.shape[0]):
        mse = np.mean((Iclean[i] - Img[i]) ** 2, dtype=np.float64)
        PSNR += np.inf if mse == 0 else 10 * np.log10((data_range ** 2) / mse)
    return PSNR / Img.shape[0]

In [None]:
import tqdm 

def train_1epoch(net, dataloader):
  """Run one optimisation pass of `net` over `dataloader`.

  Uses the notebook-level `optimizer` and `criterion`. The network receives
  cat(L, ab-hint) and is trained to predict the ground-truth ab channels.

  Returns:
    (mean_psnr, mean_loss):
      mean_psnr: average of the per-batch PSNR of cat(L, predicted ab) vs
        cat(L, ground-truth ab).
      mean_loss: average per-batch loss as a Python float, so it can be stored,
        printed and plotted without holding GPU tensors.
    Both are NaN if the loader yields no batches.
  """
  from tqdm.auto import tqdm as progress_bar  # `import tqdm` alone does not load tqdm.auto

  net.train()  # training mode
  device = next(net.parameters()).device  # keep batches on the network's device
  total_loss = 0.0
  num_batches = 0  # counts processed batches (previously started at 1: off-by-one)
  psnr_train = []

  for batch in progress_bar(dataloader):
    # 1. Prepare data
    l = batch['l'].to(device)
    ab = batch['ab'].to(device)
    hint = batch['hint'].to(device)

    gt_image = torch.cat((l, ab), dim=1)
    hint_image = torch.cat((l, hint), dim=1)

    # 2. Reset gradients
    optimizer.zero_grad()

    # 3. Forward pass
    output = net(hint_image)

    # 4. Loss against the ground-truth ab channels
    loss = criterion(output, ab)

    # 5. Backward pass
    loss.backward()

    # 6. Parameter update
    optimizer.step()

    total_loss += loss.item()
    num_batches += 1

    # PSNR on the full LAB image (ground-truth L + predicted ab)
    output_image = torch.cat((l, output.detach()), dim=1)
    psnr_train.append(batch_PSNR(gt_image, output_image, 1.))

  if num_batches == 0:
    return float('nan'), float('nan')
  return np.mean(psnr_train), total_loss / num_batches

In [None]:
def validation_1epoch(net, dataloader):
  """Evaluate `net` on `dataloader` without updating its weights.

  Uses the notebook-level `criterion`. The input is cat(L, ab-hint) and the
  prediction is compared with the ground-truth ab channels.

  Returns:
    (mean_psnr, mean_loss): average per-batch PSNR of cat(L, predicted ab) vs
    cat(L, ground-truth ab), and average per-batch loss as a Python float.
    Both are NaN if the loader yields no batches.
  """
  from tqdm.auto import tqdm as progress_bar  # `import tqdm` alone does not load tqdm.auto

  net.eval()  # evaluation mode
  device = next(net.parameters()).device  # keep batches on the network's device
  total_loss = 0.0
  num_batches = 0  # counts processed batches (previously started at 1: off-by-one)
  psnr_val = []

  # no_grad here too, so the function is safe to call outside a no_grad block.
  with torch.no_grad():
    for batch in progress_bar(dataloader):
      # 1. Prepare data
      l = batch['l'].to(device)
      ab = batch['ab'].to(device)
      hint = batch['hint'].to(device)

      gt_image = torch.cat((l, ab), dim=1)
      hint_image = torch.cat((l, hint), dim=1)

      # 2. Forward pass
      output = net(hint_image)

      # 3. Loss against the ground-truth ab channels
      loss = criterion(output, ab)

      # PSNR on the full LAB image (ground-truth L + predicted ab)
      output_image = torch.cat((l, output), dim=1)
      psnr_val.append(batch_PSNR(gt_image, output_image, 1.))

      total_loss += loss.item()
      num_batches += 1

  if num_batches == 0:
    return float('nan'), float('nan')
  return np.mean(psnr_val), total_loss / num_batches

In [None]:
# Best validation PSNR so far; starting at -inf means the first finite score is always saved.
best_psnr = float('-inf')

for epoch in range(object_epoch):
  train_psnr, train_loss = train_1epoch(net, train_dataloader)
  # float() accepts a Python number or a 0-d (possibly CUDA) tensor, so the
  # history holds plain numbers that matplotlib can plot directly.
  train_loss = float(train_loss)
  print('[Training] Epoch {}: PSNR: {}, loss: {}'.format(epoch, train_psnr, train_loss))
  train_info.append({'loss': train_loss, 'psnr': train_psnr})

  with torch.no_grad():  # no gradients are needed for validation
    val_psnr, val_loss = validation_1epoch(net, val_dataloader)
  val_loss = float(val_loss)
  print('[Validation] Epoch {}: psnr: {}, loss: {}'.format(epoch, val_psnr, val_loss))
  val_info.append({'loss': val_loss, 'psnr': val_psnr})

  # Keep only the checkpoint with the best validation PSNR.
  if val_psnr > best_psnr:
    best_psnr = val_psnr
    torch.save({
        'memo': 'Colorization Model',
        'epoch': epoch,
        'psnr': val_psnr,
        'loss': val_loss,
        'model_weight': net.state_dict(),
    }, output_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_history(metric, title):
  """Plot per-epoch training vs. validation `metric` from train_info / val_info.

  float() turns 0-d tensors (including CUDA ones) into plain numbers that
  matplotlib can draw. The x-axis follows the number of recorded epochs, so an
  interrupted or re-run training loop still plots.
  """
  train_values = [float(info[metric]) for info in train_info]
  val_values = [float(info[metric]) for info in val_info]
  fig, ax = plt.subplots()
  ax.plot(np.arange(len(train_values)), train_values, label='Train')
  ax.plot(np.arange(len(val_values)), val_values, 'r-', label='Validation')
  ax.set(title=title, xlabel='Epoch')
  ax.legend()
  plt.show()


plot_history('psnr', 'PSNR')

print("\n")

plot_history('loss', 'Loss')

# Model Testing

In [None]:
import torch
import os

model_path = os.path.join(save_path, 'colorization_model.tar')
# map_location lets a checkpoint saved from the GPU also load on a CPU-only runtime.
device = torch.device('cuda' if use_cuda else 'cpu')
state_dict = torch.load(model_path, map_location=device)

print(state_dict['memo'])
print(state_dict.keys())
print(state_dict['loss'])

net = UNet().to(device)
net.load_state_dict(state_dict['model_weight'], strict=True)

In [None]:
import tqdm

# Rebuild the network from the checkpoint dict loaded in the previous cell,
# on the device selected by use_cuda.
device = torch.device('cuda' if use_cuda else 'cpu')
net = UNet().to(device)
net.load_state_dict(state_dict['model_weight'], strict=True)

# os.makedirs('/content/output', exist_ok=True)
result_path = '/content/drive/MyDrive/output/'
os.makedirs(result_path, exist_ok=True)

def test_1epoch(net, dataloader):
  """Colorize every test sample and write the result as a BGR image into result_path.

  For each batch, the first sample is saved under its "file_name" entry
  (tensor2im takes batch index 0; the test loader uses batch_size=1).
  Returns None.

  Raises:
    IOError: cv2.imwrite reports that an output file could not be written.
  """
  from tqdm.auto import tqdm as progress_bar  # `import tqdm` alone does not load tqdm.auto

  net.eval()
  # Move inputs to wherever the network lives. Previously the inputs were only
  # assigned when use_cuda was True, so `l`/`hint` were undefined otherwise.
  device = next(net.parameters()).device

  with torch.no_grad():  # inference only
    for sample in progress_bar(dataloader):
      l = sample["l"].to(device)
      hint = sample["hint"].to(device)
      file_name = sample["file_name"]

      hint_image = torch.cat((l, hint), dim=1)
      output_ab = net(hint_image)
      output_image = torch.cat((l, output_ab), dim=1)

      output_np = tensor2im(output_image)
      output_bgr = cv2.cvtColor(output_np, cv2.COLOR_LAB2BGR)
      # To preview: cv2_imshow(output_bgr)

      out_file = os.path.join(result_path, file_name[0])
      if not cv2.imwrite(out_file, output_bgr):
        raise IOError("failed to write %s" % out_file)


# Run inference over the whole test set; test_1epoch writes the images to disk and returns None.
res = test_1epoch(net, test_dataloader)

In [None]:
root = '/content/drive/MyDrive/output/'
all_files = os.listdir(root)
# test_1epoch writes files named "image_%06d.png"; sort those by numeric index.
# Other files are skipped instead of crashing the sort key.
output_files = sorted(
    (name for name in all_files
     if name.startswith('image_') and name[len('image_'):].split('.')[0].isdigit()),
    key=lambda name: int(name[len('image_'):].split('.')[0]),
)
print(len(all_files))
output_files[:5]
