<a href="https://colab.research.google.com/github/CHOSIYEON/Multimedia_Colorization/blob/main/attention_Unet_Colorizatioin.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Get Dataset

In [None]:
# Mount Google Drive so the dataset zip files stored under MyDrive can be copied into the Colab VM.
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
import os
import zipfile
import tqdm

file_name = "Multimedia_dataset.zip"
zip_path = os.path.join('/content/drive/MyDrive/Multimedia_dataset.zip')

!cp "{zip_path}" .
!unzip -q "{file_name}"
!rm "{file_name}"

In [None]:
import os
import zipfile
import tqdm

file_name = "colorization_test_dataset.zip"
zip_path = os.path.join('/content/drive/MyDrive/colorization_test_dataset.zip')

!cp "{zip_path}" .
!unzip -q "{file_name}"
!rm "{file_name}"

In [None]:
import matplotlib.pyplot as plt

# Split directories created by extracting Multimedia_dataset.zip.
train_root = './Multimedia_dataset/train'
val_root = './Multimedia_dataset/validation'

# File names per split, in os.listdir order (arbitrary, not sorted).
train_examples, val_examples = os.listdir(train_root), os.listdir(val_root)

print(len(train_examples))  # expected: 4500
print(len(val_examples))    # expected: 500

# Display the second training image as a quick visual check.
img = plt.imread('{}/{}'.format(train_root, train_examples[1]))
plt.imshow(img)
plt.show()

# Color-hint Transform

In [None]:
import torch
from torch.autograd import Variable
from torchvision import transforms

import cv2
import random
import numpy as np

class ColorHintTransform(object):
  """Turn a BGR image into Lab tensors plus a sparse color-hint tensor.

  Modes:
    "training" / "validation": ``__call__(img)`` resizes ``img`` to size x size,
      samples a random hint mask, and returns ``(L, ab, ab_hint)`` tensors.
    "testing": ``__call__(img, mask_img)`` resizes both images, keeps only the
      hint pixels marked white in ``mask_img``, and returns ``(L, ab_hint)``.

  Every output goes through ``transforms.ToTensor``; assuming ``img`` is an 8-bit
  BGR array (as produced by cv2.imread), the tensors are float, channels-first,
  and scaled to [0, 1].
  """

  def __init__(self, size=256, mode="training"):
    super(ColorHintTransform, self).__init__()
    self.size = size
    self.mode = mode
    self.transform = transforms.Compose([transforms.ToTensor()])

  def bgr_to_lab(self, img):
    """Convert a BGR image to Lab and return ``(L channel, ab channels)``."""
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l, ab = lab[:, :, 0], lab[:, :, 1:]
    return l, ab

  def hint_mask(self, bgr, threshold=(0.95, 0.97, 0.99)):
    """Random (H, W, 1) boolean hint mask for ``bgr``.

    One value is drawn from ``threshold``; each pixel is then kept independently
    with probability ``1 - threshold`` (1-5% of pixels for the default values).
    """
    h, w, _ = bgr.shape
    mask_threshold = random.choice(threshold)
    return np.random.random([h, w, 1]) > mask_threshold

  def img_to_mask(self, mask_img):
    """(H, W, 1) boolean mask of pixels whose first channel is 255 in ``mask_img``."""
    return mask_img[:, :, 0, np.newaxis] >= 255

  def __call__(self, img, mask_img=None):
    threshold = [0.95, 0.97, 0.99]
    if self.mode in ("training", "validation"):
      image = cv2.resize(img, (self.size, self.size))
      mask = self.hint_mask(image, threshold)

      hint_image = image * mask

      l, ab = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab), self.transform(ab_hint)

    if self.mode == "testing":
      if mask_img is None:
        raise ValueError("mask_img is required in testing mode")
      image = cv2.resize(img, (self.size, self.size))
      # The mask must have the same spatial size as the resized image. Nearest-neighbour
      # sampling keeps it strictly 0/255; other interpolations would blend edge values.
      mask = cv2.resize(mask_img, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
      hint_image = image * self.img_to_mask(mask)

      l, _ = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab_hint)

    raise NotImplementedError("unsupported mode: %r" % (self.mode,))

# Dataloader for Colorization Dataset

In [None]:
import torch
import torch.utils.data  as data
import os
import cv2
from google.colab.patches import cv2_imshow

class ColorHintDataset(data.Dataset):
  """Colorization dataset over a directory tree rooted at ``root_path``.

  Sub-directories read by ``set_mode``:
    "training"   -> ``train/``
    "validation" -> ``validation/``
    "testing"    -> ``hint/`` and ``mask/`` (paired by position after sorting by name)

  ``set_mode`` must be called before indexing the dataset or taking its length.
  """

  def __init__(self, root_path, size):
    super(ColorHintDataset, self).__init__()

    self.root_path = root_path
    self.size = size
    self.mode = None
    self.transforms = None
    self.examples = None
    self.hint = None
    self.mask = None

  @staticmethod
  def _list_files(directory):
    """Full paths of the entries of ``directory``, sorted by file name.

    os.listdir returns entries in arbitrary order. Sorting makes the order reproducible
    and is what keeps ``hint[i]`` and ``mask[i]`` referring to the same image.
    """
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]

  @staticmethod
  def _read_image(path):
    """Read ``path`` with cv2.imread, raising instead of returning None on failure."""
    img = cv2.imread(path)
    if img is None:
      raise IOError("could not read image: %s" % path)
    return img

  def set_mode(self, mode):
    """Select the split, build its ColorHintTransform, and index its files."""
    self.mode = mode
    self.transforms = ColorHintTransform(self.size, mode)
    if mode == "training":
      self.examples = self._list_files(os.path.join(self.root_path, "train"))
    elif mode == "validation":
      self.examples = self._list_files(os.path.join(self.root_path, "validation"))
    elif mode == "testing":
      self.hint = self._list_files(os.path.join(self.root_path, "hint"))
      self.mask = self._list_files(os.path.join(self.root_path, "mask"))
      if len(self.hint) != len(self.mask):
        raise ValueError("hint/ and mask/ contain different numbers of files: %d vs %d"
                         % (len(self.hint), len(self.mask)))
    else:
      raise NotImplementedError

  def __len__(self):
    if self.mode is None:
      raise RuntimeError("call set_mode() before using the dataset")
    if self.mode != "testing":
      return len(self.examples)
    return len(self.hint)

  def __getitem__(self, idx):
    if self.mode == "testing":
      hint_file_name = self.hint[idx]
      hint_img = self._read_image(hint_file_name)
      mask_img = self._read_image(self.mask[idx])

      input_l, input_hint = self.transforms(hint_img, mask_img)
      # Output name assumes hint files are named by an integer id (e.g. "12.png").
      sample = {"l": input_l, "hint": input_hint,
                "file_name": "image_%06d.png" % int(os.path.basename(hint_file_name).split('.')[0])}
    else:
      img = self._read_image(self.examples[idx])
      l, ab, hint = self.transforms(img)
      sample = {"l": l, "ab": ab, "hint": hint}

    return sample

In [None]:
# Sanity check that the runtime exposes a CUDA GPU; later cells move tensors and the model to 'cuda'.
print(torch.cuda.is_available())

# Example for Loading

In [None]:
import torch
import torch.utils.data  as data
import os
import cv2
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
from torchvision import transforms
import tqdm
from PIL import Image
import numpy as np
from torchvision.transforms import Compose, ToTensor, ToPILImage

def tensor2im(input_image, imtype=np.uint8):
  """Convert the first image of a channels-first tensor batch into an H x W x C array.

  Non-tensor inputs are returned unchanged. A single-channel image is repeated to three
  channels. Values are clipped to [0, 1], scaled to [0, 255], and cast to ``imtype``
  (the cast truncates rather than rounds).
  """
  if not isinstance(input_image, torch.Tensor):
    return input_image

  first = input_image.data[0].cpu().float().numpy()
  if first.shape[0] == 1:
    first = np.tile(first, (3, 1, 1))
  hwc = first.transpose(1, 2, 0)
  scaled = np.clip(hwc, 0, 1) * 255.0
  return scaled.astype(imtype)

# Change to your data root directory
root_path = "/content/Multimedia_dataset/"
test_path = "/content/colorization_test_dataset/"
# Use the GPU when the runtime provides one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

train_dataset = ColorHintDataset(root_path, 128)
train_dataset.set_mode("training")
train_dataloader = data.DataLoader(train_dataset, batch_size=16, shuffle=True)

val_dataset = ColorHintDataset(root_path, 128)
val_dataset.set_mode("validation")
val_dataloader = data.DataLoader(val_dataset, batch_size=16, shuffle=True)

test_dataset = ColorHintDataset(test_path, 128)
test_dataset.set_mode("testing")
test_dataloader = data.DataLoader(test_dataset, batch_size=1, shuffle=True)


def lab_to_bgr(lab_tensor):
  """Convert the first image of a Lab tensor batch (values in [0, 1]) to a uint8 BGR array."""
  return cv2.cvtColor(tensor2im(lab_tensor), cv2.COLOR_LAB2BGR)


# Sanity pass over every split: decode each batch back to BGR so that any unreadable file
# or shape problem shows up now. The loop variable is `batch`, not `data`: `data` is the
# torch.utils.data module imported above, and rebinding it would shadow the module for any
# later code that calls data.DataLoader without re-importing.
for batch in tqdm.tqdm(train_dataloader):
  l = batch["l"].to(device)
  ab = batch["ab"].to(device)
  hint = batch["hint"].to(device)

  gt_bgr = lab_to_bgr(torch.cat((l, ab), dim=1))
  hint_bgr = lab_to_bgr(torch.cat((l, hint), dim=1))
  # cv2_imshow(gt_bgr); cv2_imshow(hint_bgr)

for batch in tqdm.tqdm(val_dataloader):
  l = batch["l"].to(device)
  ab = batch["ab"].to(device)
  hint = batch["hint"].to(device)

  gt_bgr = lab_to_bgr(torch.cat((l, ab), dim=1))
  hint_bgr = lab_to_bgr(torch.cat((l, hint), dim=1))
  # cv2_imshow(gt_bgr); cv2_imshow(hint_bgr)

for batch in tqdm.tqdm(test_dataloader):
  l = batch["l"].to(device)
  hint = batch["hint"].to(device)

  hint_bgr = lab_to_bgr(torch.cat((l, hint), dim=1))
  # cv2_imshow(hint_bgr)


# Network Construction

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

class up_conv(nn.Module):
    """Decoder upsampling step: 2x nearest-neighbour upsample, then 3x3 conv + BatchNorm + ReLU."""

    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        layers = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*layers)

    def forward(self, x):
        return self.up(x)

class Recurrent_block(nn.Module):
    """Recurrent convolution unit (R2U-Net style).

    One shared 3x3 conv + BatchNorm + ReLU stack is applied ``t + 1`` times: first to
    ``x`` alone, then ``t`` more times to ``x`` plus the previous output. Input and
    output both have ``ch_out`` channels.
    """

    def __init__(self,ch_out,t=2):
        super(Recurrent_block,self).__init__()
        if t < 1:
            # forward() needs at least one recurrent step; reject t < 1 here rather than
            # failing with an unbound local inside forward().
            raise ValueError("t must be >= 1, got %r" % (t,))
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x1 = self.conv(x)
        for _ in range(self.t):
            x1 = self.conv(x + x1)
        return x1

class RRCNN_block(nn.Module):
    """Recurrent residual block.

    A 1x1 conv projects the input to ``ch_out`` channels. Two stacked Recurrent_blocks
    refine the projection, and the projection is added back as a residual.
    """

    def __init__(self,ch_in,ch_out,t=2):
        super(RRCNN_block,self).__init__()
        self.RCNN = nn.Sequential(*[Recurrent_block(ch_out, t=t) for _ in range(2)])
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        projected = self.Conv_1x1(x)
        return projected + self.RCNN(projected)

class Attention_block(nn.Module):
    """Additive attention gate (Attention U-Net).

    The gating signal ``g`` (F_g channels) and skip features ``x`` (F_l channels) are each
    projected to F_int channels with a 1x1 conv + BatchNorm. The projections are summed,
    passed through ReLU, and mapped to a single-channel sigmoid map ``psi``. The block
    returns ``x * psi``.
    """

    def __init__(self,F_g,F_l,F_int):
        super(Attention_block,self).__init__()

        def _proj(ch_in, ch_out):
            # 1x1 conv followed by BatchNorm; shared shape of all three branches.
            return [nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=True),
                    nn.BatchNorm2d(ch_out)]

        self.W_g = nn.Sequential(*_proj(F_g, F_int))
        self.W_x = nn.Sequential(*_proj(F_l, F_int))
        self.psi = nn.Sequential(*_proj(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self,g,x):
        gate = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(gate)

class UNet(nn.Module):
    """Recurrent-residual U-Net with attention-gated skip connections (R2AttU-Net).

    Args:
        img_ch: input channels (3 in this notebook: L plus the two ab-hint channels).
        output_ch: output channels (2: the predicted ab channels).
        t: recurrence steps used by every RRCNN_block.

    Four 2x max-pools are undone by four 2x upsamples. Input height and width should
    therefore be divisible by 16 for the skip connections to line up.
    """

    def __init__(self,img_ch=3,output_ch=2,t=2):
        super(UNet,self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
        # Not used by forward(); has no parameters.
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder: five RRCNN stages, channel width doubling at each stage.
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder: each level upsamples, gates the matching encoder skip, concatenates, fuses.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        # Final projection to the output channels.
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        # Encoder: stage 1 at full resolution, then max-pool before every deeper stage.
        skips = [self.RRCNN1(x)]
        for stage in (self.RRCNN2, self.RRCNN3, self.RRCNN4, self.RRCNN5):
            skips.append(stage(self.Maxpool(skips[-1])))

        # The deepest feature map starts the decoder; the rest are consumed deepest-first.
        out = skips.pop()
        decoder_levels = (
            (self.Up5, self.Att5, self.Up_RRCNN5),
            (self.Up4, self.Att4, self.Up_RRCNN4),
            (self.Up3, self.Att3, self.Up_RRCNN3),
            (self.Up2, self.Att2, self.Up_RRCNN2),
        )
        for upsample, attention, fuse in decoder_levels:
            out = upsample(out)
            gated_skip = attention(g=out, x=skips.pop())
            out = fuse(torch.cat((gated_skip, out), dim=1))

        return self.Conv_1x1(out)

# Network Training

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

# len(dataloader) is the number of batches; the dataset length is len(dataloader.dataset).
print('train dataset length: ', len(train_dataloader.dataset))
print('train batches per epoch: ', len(train_dataloader))
print('validation dataset length: ', len(val_dataloader.dataset))
print('validation batches per epoch: ', len(val_dataloader))

# 0. Reproducibility: seed every RNG the pipeline draws from. Python `random` and NumPy
#    drive hint-mask sampling in ColorHintTransform; torch drives weight initialisation
#    and DataLoader shuffling.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 1. Network setting
net = UNet().to('cuda' if use_cuda else 'cpu')

# 2. Loss and optimizer setting
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# 3. Miscellaneous bookkeeping
train_info = []
val_info = []
object_epoch = 30

save_path = './ColorizationNetwork'
os.makedirs(save_path, exist_ok=True)
output_path = os.path.join(save_path, 'colorization_model.tar')


In [None]:
def batch_PSNR(img, imclean, data_range):
    """Mean PSNR in dB over a batch of images.

    Args:
        img, imclean: tensors of identical shape; dimension 0 indexes images and all
            remaining dimensions belong to one image.
        data_range: range of possible pixel values (1. for images scaled to [0, 1]).

    Returns:
        Mean of the per-image values 10 * log10(data_range**2 / MSE). An image pair with
        zero MSE contributes +inf.

    Implemented directly in NumPy because ``skimage.measure.simple_metrics.compare_psnr``
    no longer exists in current scikit-image releases. As with that function, the
    difference is taken in float32 and the MSE is accumulated in float64.
    """
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    if Img.shape[0] == 0:
        raise ValueError("batch_PSNR needs at least one image")
    per_image_axes = tuple(range(1, Img.ndim))
    mse = np.mean(np.square(Iclean - Img), axis=per_image_axes, dtype=np.float64)
    with np.errstate(divide='ignore'):
        psnr = 10 * np.log10((data_range ** 2) / mse)
    return float(np.mean(psnr))

In [None]:
import tqdm 

def train_1epoch(net, dataloader):
  """Run one training epoch and return ``(mean_psnr, mean_loss)``.

  Uses the notebook-level ``optimizer``, ``criterion`` and ``use_cuda``. The network
  receives L + ab-hint channels and regresses the ab channels. PSNR compares the
  reassembled (L, predicted ab) image with (L, true ab).

  Returns:
    mean_psnr: mean of the per-batch ``batch_PSNR`` values.
    mean_loss: Python float, the loss averaged over the batches processed.
  """
  from tqdm.auto import tqdm as progress_bar

  device = 'cuda' if use_cuda else 'cpu'
  net.train()  # training mode
  total_loss = 0.0
  num_batches = 0
  psnr_train = []

  for batch in progress_bar(dataloader):
    # 1. Prepare the data on the target device.
    l = batch['l'].to(device)
    ab = batch['ab'].to(device)
    hint = batch['hint'].to(device)
    hint_image = torch.cat((l, hint), dim=1)

    # 2. Reset gradients.
    optimizer.zero_grad()
    # 3. Forward pass.
    output = net(hint_image)
    # 4. Loss.
    loss = criterion(output, ab)
    # 5. Backward pass and 6. parameter update.
    loss.backward()
    optimizer.step()

    # .item() copies the scalar to the host: the running total is a plain float
    # (directly plottable) and keeps no reference to the autograd graph.
    total_loss += loss.item()
    num_batches += 1

    # PSNR is bookkeeping only; no gradients needed.
    with torch.no_grad():
      gt_image = torch.cat((l, ab), dim=1)
      output_image = torch.cat((l, output), dim=1)
      psnr_train.append(batch_PSNR(gt_image, output_image, 1.))

  mean_psnr = np.mean(psnr_train)
  # Average over the batches actually processed.
  mean_loss = total_loss / max(num_batches, 1)

  return mean_psnr, mean_loss

In [None]:
def validation_1epoch(net, dataloader):
  """Evaluate ``net`` on ``dataloader`` without updating it; return ``(mean_psnr, mean_loss)``.

  Uses the notebook-level ``criterion`` and ``use_cuda``. Gradients are disabled inside,
  so callers need no ``torch.no_grad()`` of their own (an outer one is harmless).

  Returns:
    mean_psnr: mean of the per-batch ``batch_PSNR`` values.
    mean_loss: Python float, the loss averaged over the batches processed.
  """
  from tqdm.auto import tqdm as progress_bar

  device = 'cuda' if use_cuda else 'cpu'
  net.eval()  # validation mode: BatchNorm uses its running statistics
  total_loss = 0.0
  num_batches = 0
  psnr_val = []

  with torch.no_grad():
    for batch in progress_bar(dataloader):
      # 1. Prepare the data on the target device.
      l = batch['l'].to(device)
      ab = batch['ab'].to(device)
      hint = batch['hint'].to(device)

      # 2. Forward pass and loss.
      output = net(torch.cat((l, hint), dim=1))
      total_loss += criterion(output, ab).item()
      num_batches += 1

      # 3. PSNR of the reassembled Lab image.
      gt_image = torch.cat((l, ab), dim=1)
      output_image = torch.cat((l, output), dim=1)
      psnr_val.append(batch_PSNR(gt_image, output_image, 1.))

  mean_psnr = np.mean(psnr_val)
  # Average over the batches actually processed.
  mean_loss = total_loss / max(num_batches, 1)

  return mean_psnr, mean_loss

In [None]:
best_psnr = float('-inf')

for epoch in range(object_epoch):
  train_psnr, train_loss = train_1epoch(net, train_dataloader)
  print('[Training] Epoch {}: PSNR: {}, loss: {}'.format(epoch, train_psnr, train_loss))
  train_info.append({
      'loss': train_loss,
      'psnr': train_psnr
  })

  # Validation runs every epoch: it selects the checkpoint written to output_path,
  # which the model-testing cells load.
  with torch.no_grad():  # no gradients needed for evaluation
    val_psnr, val_loss = validation_1epoch(net, val_dataloader)
  print('[Validation] Epoch {}: psnr: {}, loss: {}'.format(epoch, val_psnr, val_loss))
  val_info.append({
      'loss': val_loss,
      'psnr': val_psnr
  })

  if val_psnr > best_psnr:
    best_psnr = val_psnr
    # Store plain Python floats: NumPy scalars in a checkpoint are rejected by
    # torch.load when weights_only=True (the default since PyTorch 2.6).
    torch.save({
      'memo': 'Colorization Model',
      'psnr': float(val_psnr),
      'loss': float(val_loss),
      'model_weight': net.state_dict()
    }, output_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def plot_history(key, title):
  """Plot metric ``key`` per epoch from train_info, plus val_info when it has entries.

  Metrics are passed through float() so Python/NumPy scalars and 0-d tensors (even on
  the GPU) can all be plotted. The x-axis uses the number of recorded epochs, so an
  interrupted run still plots.
  """
  fig, ax = plt.subplots()
  ax.plot(np.arange(len(train_info)), [float(info[key]) for info in train_info], label='Train')
  if val_info:
    ax.plot(np.arange(len(val_info)), [float(info[key]) for info in val_info], 'r-', label='Validation')
  ax.set(title=title, xlabel='epoch', ylabel=key)
  ax.legend()
  plt.show()


plot_history('psnr', 'PSNR')

print("\n")

plot_history('loss', 'Loss')

# Model Testing

In [None]:
import torch
import os

model_path = os.path.join(save_path, 'colorization_model.tar')
device = 'cuda' if use_cuda else 'cpu'
# map_location lets a checkpoint written on a GPU runtime load on whatever device is available now.
state_dict = torch.load(model_path, map_location=device)

print(state_dict['memo'])
print(state_dict.keys())
print(state_dict['loss'])

net = UNet().to(device)
net.load_state_dict(state_dict['model_weight'], strict= True)
# Inference mode: BatchNorm uses its stored running statistics and is not updated.
net.eval()

In [None]:
from itertools import islice

# Number of training batches to visualise. Each batch shows ground truth, hint input
# and prediction for its first image. Raise it (up to len(train_dataloader)) to see more.
NUM_PREVIEW_BATCHES = 3

device = next(net.parameters()).device
# eval(): forward passes here must not update BatchNorm running statistics.
net.eval()
with torch.no_grad():
  for batch in islice(train_dataloader, NUM_PREVIEW_BATCHES):
    l = batch["l"].to(device)
    ab = batch["ab"].to(device)
    hint = batch["hint"].to(device)

    gt_image = torch.cat((l, ab), dim=1)
    hint_image = torch.cat((l, hint), dim=1)
    output_image = torch.cat((l, net(hint_image)), dim=1)

    for lab_image in (gt_image, hint_image, output_image):
      cv2_imshow(cv2.cvtColor(tensor2im(lab_image), cv2.COLOR_LAB2BGR))

In [None]:
import tqdm

device = 'cuda' if use_cuda else 'cpu'
net = UNet().to(device)
net.load_state_dict(state_dict['model_weight'], strict= True)

os.makedirs('/content/drive/MyDrive/output', exist_ok=True)
result_path = '/content/drive/MyDrive/output/'


def test_1epoch(net, dataloader, result_dir=None, show=True):
  """Colorize every test sample and optionally display and/or save the results.

  Args:
    net: trained colorization network.
    dataloader: yields dicts with 'l', 'hint' and 'file_name'. Only the first image of
      each batch is shown or saved, so a batch size of 1 is assumed.
    result_dir: if given, each colorized BGR image is written there under its 'file_name'.
    show: display the hint and the colorized output with cv2_imshow.

  Returns:
    Number of batches processed.
  """
  from tqdm.auto import tqdm as progress_bar

  net.eval()
  device = next(net.parameters()).device
  count = 0
  with torch.no_grad():
    for sample in progress_bar(dataloader):
      l = sample["l"].to(device)
      hint = sample["hint"].to(device)
      file_name = sample["file_name"]

      hint_image = torch.cat((l, hint), dim=1)
      output_image = torch.cat((l, net(hint_image)), dim=1)

      hint_bgr = cv2.cvtColor(tensor2im(hint_image), cv2.COLOR_LAB2BGR)
      output_bgr = cv2.cvtColor(tensor2im(output_image), cv2.COLOR_LAB2BGR)

      if show:
        cv2_imshow(hint_bgr)
        cv2_imshow(output_bgr)

      if result_dir is not None:
        out_file = os.path.join(result_dir, file_name[0])
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(out_file, output_bgr):
          raise IOError("could not write %s" % out_file)
      count += 1
  return count


# Preview only; pass result_dir=result_path to also write the colorized images.
res = test_1epoch(net, test_dataloader)

# Model Testing - Test Set

In [None]:
import tqdm

device = 'cuda' if use_cuda else 'cpu'
net = UNet().to(device)
net.load_state_dict(state_dict['model_weight'], strict= True)

os.makedirs('/content/drive/MyDrive/output', exist_ok=True)
result_path = '/content/drive/MyDrive/output/'


def test_1epoch(net, dataloader, result_dir=None, show=True):
  """Colorize every test sample and optionally display and/or save the results.

  Args:
    net: trained colorization network.
    dataloader: yields dicts with 'l', 'hint' and 'file_name'. Only the first image of
      each batch is shown or saved, so a batch size of 1 is assumed.
    result_dir: if given, each colorized BGR image is written there under its 'file_name'.
    show: display the hint and the colorized output with cv2_imshow.

  Returns:
    Number of batches processed.
  """
  from tqdm.auto import tqdm as progress_bar

  net.eval()
  device = next(net.parameters()).device
  count = 0
  with torch.no_grad():
    for sample in progress_bar(dataloader):
      l = sample["l"].to(device)
      hint = sample["hint"].to(device)
      file_name = sample["file_name"]

      hint_image = torch.cat((l, hint), dim=1)
      output_image = torch.cat((l, net(hint_image)), dim=1)

      hint_bgr = cv2.cvtColor(tensor2im(hint_image), cv2.COLOR_LAB2BGR)
      output_bgr = cv2.cvtColor(tensor2im(output_image), cv2.COLOR_LAB2BGR)

      if show:
        cv2_imshow(hint_bgr)
        cv2_imshow(output_bgr)

      if result_dir is not None:
        out_file = os.path.join(result_dir, file_name[0])
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(out_file, output_bgr):
          raise IOError("could not write %s" % out_file)
      count += 1
  return count


# Colorize the test set and write every result to Google Drive.
res = test_1epoch(net, test_dataloader, result_dir=result_path)