<a href="https://colab.research.google.com/github/Hyojinko/2022_CV_Project/blob/main/Colorization_ResUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive  
# Mount Google Drive; the dataset archives are copied from /content/drive/MyDrive below.
drive.mount(mountpoint='/content/drive/')


Mounted at /content/drive/


In [2]:
import zipfile
import os
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.utils import data
from torch.autograd import Variable
from torchvision import transforms
from torchvision.transforms import Compose, ToTensor, ToPILImage
import cv2
import random
import numpy as np
from google.colab.patches import cv2_imshow
import tqdm
from PIL import Image


In [3]:
import shutil
import zipfile

zip_path = '/content/drive/MyDrive/ComputerVision/colorization_dataset.zip'
file_name = 'colorization_dataset.zip'

# Copy the archive from Drive to local disk, extract it into the working directory,
# then delete the local copy. zipfile overwrites existing files, so (unlike
# `!unzip -q`, which prompts before replacing) this cell can be re-run unattended.
shutil.copy(zip_path, file_name)
with zipfile.ZipFile(file_name) as archive:
  archive.extractall('.')
os.remove(file_name)

In [4]:
import torch
from torch.autograd import Variable
from torchvision import transforms

import cv2
import random
import numpy as np

class ColorHintTransform(object):
  """Resize a BGR image and split it into Lab tensors plus a sparse color-hint mask.

  Modes:
    * "train" / "val": a random hint mask is sampled per call.
    * "test": the hint mask is taken from a provided mask image (pixels whose first
      channel is 255).
  All tensors are produced with torchvision's ``ToTensor``.
  """
  def __init__(self, size=256, mode="train"):
    super(ColorHintTransform, self).__init__()
    self.size = size
    self.mode = mode
    self.transform = transforms.Compose([transforms.ToTensor()])

  def bgr_to_lab(self, img):
    """Convert a BGR image to OpenCV Lab and return (L channel, ab channels)."""
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l, ab = lab[:, :, 0], lab[:, :, 1:]
    return l, ab

  def hint_mask(self, bgr, threshold=(0.95, 0.97, 0.99)):
    """Sample a random (h, w, 1) boolean hint mask for ``bgr``.

    One value is drawn from ``threshold``; each pixel becomes a hint with
    probability ``1 - threshold`` (e.g. 5%, 3% or 1% of pixels).
    """
    h, w, _ = bgr.shape
    mask_threshold = random.choice(threshold)
    mask = np.random.random([h, w, 1]) > mask_threshold
    return mask

  def img_to_mask(self, mask_img):
    """Turn a mask image into an (h, w, 1) boolean mask: True where channel 0 is 255."""
    mask = mask_img[:, :, 0, np.newaxis] >= 255
    return mask

  def __call__(self, img, mask_img=None):
    """Transform one image.

    Returns:
      train/val: (l, ab, ab_hint, mask) tensors.
      test: (l, ab_hint, mask) tensors; ``mask_img`` is required.
    The mask is returned as float32 (0./1.) so it can be concatenated with the
    other float tensors.

    Raises:
      ValueError: in test mode when ``mask_img`` is None.
      NotImplementedError: for an unknown mode.
    """
    threshold = [0.95, 0.97, 0.99]
    if self.mode in ("train", "val"):
      image = cv2.resize(img, (self.size, self.size))
      mask = self.hint_mask(image, threshold)

      # Keep colors only at hint pixels; everything else becomes black.
      hint_image = image * mask

      l, ab = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)
      return (self.transform(l), self.transform(ab),
              self.transform(ab_hint), self.transform(mask.astype(np.float32)))

    elif self.mode == "test":
      if mask_img is None:
        raise ValueError("test mode requires a mask image")
      image = cv2.resize(img, (self.size, self.size))
      # Bring the mask to the same size as the image; nearest-neighbour keeps it binary.
      mask_img = cv2.resize(mask_img, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
      mask = self.img_to_mask(mask_img)
      hint_image = image * mask
      l, _ = self.bgr_to_lab(image)
      _, ab_hint = self.bgr_to_lab(hint_image)

      return self.transform(l), self.transform(ab_hint), self.transform(mask.astype(np.float32))

    raise NotImplementedError("unknown mode: %r" % (self.mode,))



In [5]:
import torch
import torch.utils.data  as data
import os
import cv2
from google.colab.patches import cv2_imshow

class ColorHintDataset(data.Dataset):
  """Colorization dataset rooted at ``root_path``; call ``set_mode`` before use.

  * "train" / "val": every file in ``root_path/<mode>`` yields a dict with keys
    "l", "ab", "hint", "mask" (random hints from ColorHintTransform).
  * "test": files in ``root_path/hint`` are paired with ``root_path/mask`` by sorted
    file name and yield "l", "hint", "mask", "file_name".
  """
  def __init__(self, root_path, size):
    super(ColorHintDataset, self).__init__()

    self.root_path = root_path
    self.size = size
    self.mode = None  # set by set_mode(); __len__/__getitem__ depend on it
    self.transforms = None
    self.examples = None
    self.hint = None
    self.mask = None

  def _list_files(self, subdir):
    """Full paths of the entries of ``root_path/subdir``, sorted by name.

    os.listdir order is arbitrary; sorting makes iteration reproducible and keeps the
    hint/ and mask/ listings aligned index-by-index.
    """
    folder = os.path.join(self.root_path, subdir)
    return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]

  def set_mode(self, mode):
    """Select the split ("train", "val" or "test") and build its ColorHintTransform.

    Raises:
      NotImplementedError: for any other mode.
      ValueError: in test mode, if hint/ and mask/ contain different numbers of files.
    """
    self.mode = mode
    self.transforms = ColorHintTransform(self.size, mode)
    if mode in ("train", "val"):
      self.examples = self._list_files(mode)
    elif mode == "test":
      self.hint = self._list_files("hint")
      self.mask = self._list_files("mask")
      if len(self.hint) != len(self.mask):
        raise ValueError("hint/ has %d files but mask/ has %d"
                         % (len(self.hint), len(self.mask)))
    else:
      raise NotImplementedError

  def __len__(self):
    if self.mode is None:
      raise RuntimeError("call set_mode() before using the dataset")
    return len(self.hint) if self.mode == "test" else len(self.examples)

  @staticmethod
  def _read(path):
    """cv2.imread that raises instead of silently returning None."""
    img = cv2.imread(path)
    if img is None:
      raise FileNotFoundError("could not read image: %s" % path)
    return img

  def __getitem__(self, idx):
    if self.mode == "test":
      hint_file_name = self.hint[idx]
      hint_img = self._read(hint_file_name)
      mask_img = self._read(self.mask[idx])

      input_l, input_hint, input_mask = self.transforms(hint_img, mask_img)
      # Output name is derived from the numeric stem of the hint file (e.g. "12.png").
      return {"l": input_l, "hint": input_hint, "mask": input_mask,
              "file_name": "image_%06d.png" % int(os.path.basename(hint_file_name).split('.')[0])}

    img = self._read(self.examples[idx])
    l, ab, hint, mask = self.transforms(img)
    return {"l": l, "ab": ab, "hint": hint, "mask": mask}

In [6]:
import torch
import torch.utils.data  as data
import os
import cv2
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
from torchvision import transforms
import tqdm
from PIL import Image
import numpy as np

def tensor2im(input_image, imtype=np.uint8):
  """Convert the first image of a (N, C, H, W) tensor to an (H, W, 3) array.

  Values are clipped to [0, 1] and scaled to [0, 255] before casting to ``imtype``.
  Single-channel images are repeated to 3 channels. Non-tensor inputs are returned
  unchanged.
  """
  if not isinstance(input_image, torch.Tensor):
    return input_image
  first = input_image.detach()[0].cpu().float().numpy()
  if first.shape[0] == 1:
    first = np.tile(first, (3, 1, 1))
  hwc = first.transpose(1, 2, 0)
  scaled = np.clip(hwc, 0, 1) * 255.0
  return scaled.astype(imtype)

root_path = "/content/cv_project"  # change to your data root directory
use_cuda = True                    # depends on the runtime setting
IMG_SIZE = 256                     # side length every image is resized to
BATCH_SIZE = 4

# Training split (loader reshuffles every epoch).
train_dataset = ColorHintDataset(root_path, IMG_SIZE)
train_dataset.set_mode("train")
print('Train length : ', len(train_dataset))

# Validation split (loader keeps a fixed order).
val_dataset = ColorHintDataset(root_path, IMG_SIZE)
val_dataset.set_mode("val")
print('Validation length : ', len(val_dataset))

train_dataloader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


Train length :  10000
Validation length :  2000


# **Define ResUNet** (residual U-Net with attention-gated skip connections)


---



In [7]:
import torch
import torch.nn as nn

class upsample(nn.Module):
    """Conv3x3-BN-ReLU, then a stride-2 transposed conv (BN-ReLU) that doubles H and W.

    The layers live in ``self.up`` (an nn.Sequential); keep that name and layer order,
    since they define the state_dict keys of saved checkpoints.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            # k=3, s=2, p=1, output_padding=1 -> output size is exactly 2x the input size.
            nn.ConvTranspose2d(ch_out, ch_out, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*layers)

    def forward(self, x):
        return self.up(x)



In [8]:
class ResidualBlock(nn.Module):
    """1x1 projection to ``ch_out`` followed by a two-conv residual branch added to it.

    Output = P + RCNN(P), where P = Conv_1x1(x). There is no activation after the sum.
    Attribute names ``RCNN`` and ``Conv_1x1`` are checkpoint state_dict keys.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True)

        self.RCNN = nn.Sequential(
            conv3x3(), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True),
            conv3x3(), nn.BatchNorm2d(ch_out),
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        projected = self.Conv_1x1(x)
        return projected + self.RCNN(projected)

In [9]:
class AttentionBlock(nn.Module):
    """Additive attention gate: scales the skip features ``x`` by a per-pixel weight in (0, 1).

    Args:
        upsample: channel count of the gating signal ``g`` (fed to ``self.skip``).
        downsample: channel count of ``x`` (fed to ``self.up``).
        ch_result: width of the intermediate attention features.
    ``g`` and ``x`` must share spatial size. Attribute names are checkpoint keys.
    """

    def __init__(self, upsample, downsample, ch_result):
        super().__init__()

        def project(ch_in):
            # 1x1 conv + BN to bring an input to ch_result channels.
            return nn.Sequential(
                nn.Conv2d(ch_in, ch_result, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(ch_result),
            )

        self.skip = project(upsample)
        self.up = project(downsample)
        self.concat = nn.Sequential(
            nn.Conv2d(ch_result, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate = self.relu(self.skip(g) + self.up(x))
        return x * self.concat(gate)

In [10]:

class ResUNet(nn.Module):
    """Residual U-Net with attention-gated skip connections.

    Input has ``img_ch`` channels; the training loops in this notebook feed 4:
    L (1) + ab hint (2) + hint mask (1). Output has ``output_ch`` channels at the input
    resolution (3 here, compared against concatenated L and ab). The encoder has five
    ResidualBlock levels (64..1024 channels) separated by 2x2 max-pooling, so H and W
    must be divisible by 16. The decoder mirrors levels 5..2 and gates each skip tensor
    with an AttentionBlock before concatenation.

    Submodule names (Maxpool, Dsample*, Up*, Attention*, Usample*, Conv_1x1) and their
    registration order are unchanged, so existing checkpoints load as before.
    """

    def __init__(self, img_ch=4, output_ch=3):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Channel width after each encoder level; index 0 is the network input.
        widths = [img_ch, 64, 128, 256, 512, 1024]

        # Encoder: two residual blocks per level.
        for level in range(1, 6):
            c_in, c_out = widths[level - 1], widths[level]
            setattr(self, "Dsample%d" % level, ResidualBlock(ch_in=c_in, ch_out=c_out))
            setattr(self, "Dsample%d_1" % level, ResidualBlock(ch_in=c_out, ch_out=c_out))

        # Decoder: upsample from this level's width to the next-shallower width, gate the
        # skip tensor, then fuse the concatenation (2 * narrow == wide channels).
        for level in (5, 4, 3, 2):
            wide, narrow = widths[level], widths[level - 1]
            setattr(self, "Up%d" % level, upsample(ch_in=wide, ch_out=narrow))
            setattr(self, "Attention%d" % level,
                    AttentionBlock(upsample=narrow, downsample=narrow, ch_result=narrow // 2))
            setattr(self, "Usample%d" % level, ResidualBlock(ch_in=wide, ch_out=narrow))
            setattr(self, "Usample%d_1" % level, ResidualBlock(ch_in=narrow, ch_out=narrow))

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder: keep every level's output; the deepest one is the bottleneck.
        skips = []
        feat = x
        for level in range(1, 6):
            if level > 1:
                feat = self.Maxpool(feat)
            feat = getattr(self, "Dsample%d" % level)(feat)
            feat = getattr(self, "Dsample%d_1" % level)(feat)
            skips.append(feat)

        out = skips.pop()
        # Decoder: consume skip tensors from deepest to shallowest.
        for level in (5, 4, 3, 2):
            out = getattr(self, "Up%d" % level)(out)
            gated = getattr(self, "Attention%d" % level)(g=out, x=skips.pop())
            out = torch.cat((gated, out), dim=1)
            out = getattr(self, "Usample%d" % level)(out)
            out = getattr(self, "Usample%d_1" % level)(out)

        return self.Conv_1x1(out)

In [11]:
def train_1epoch(net, train_dataloader, loss_fn=None, opt=None, device='cuda', progress=True):
  """Run one training epoch and return the mean per-batch loss.

  Each batch must provide "l", "ab", "hint" and "mask" tensors. The network input is
  cat(l, hint, mask) and the target is cat(l, ab), both along the channel axis.

  Args:
    net: model to train; it is switched to training mode first (val_1epoch leaves it
      in eval mode, which would otherwise freeze BatchNorm statistics).
    train_dataloader: iterable of batch dicts.
    loss_fn: loss module; defaults to the notebook-global ``criterion``.
    opt: optimizer; defaults to the notebook-global ``optimizer``.
    device: device the batch tensors are moved to.
    progress: wrap the loader in a tqdm progress bar.

  Returns:
    0-d numpy array with the mean loss over batches.

  Raises:
    ValueError: if the loader yields no batches.
  """
  loss_fn = criterion if loss_fn is None else loss_fn
  opt = optimizer if opt is None else opt

  net.train()
  batches = tqdm.auto.tqdm(train_dataloader) if progress else train_dataloader
  total_loss = 0
  n_batches = 0
  for batch in batches:
    l = batch['l'].to(device)
    ab = batch['ab'].to(device)
    hint = batch['hint'].to(device)
    # Mask may arrive as bool; make it float so it concatenates cleanly.
    mask = batch['mask'].to(device).float()

    gt_img = torch.cat((l, ab), dim=1)
    hint_img = torch.cat((l, hint, mask), dim=1)
    output_hint = net(hint_img)
    loss = loss_fn(output_hint, gt_img)
    opt.zero_grad()
    loss.backward()
    opt.step()
    total_loss += loss.detach()
    n_batches += 1

  if n_batches == 0:
    raise ValueError('train_dataloader produced no batches')
  return (total_loss / n_batches).cpu().numpy()


In [12]:
def val_1epoch(net, val_dataloader, loss_fn=None, device='cuda',
               out_dir='/content/cv_project', show=True):
  """Evaluate ``net`` on the validation loader and dump sample images.

  For every batch, the first image of the target (cat(l, ab)) and of the prediction
  are converted from Lab to BGR. They are written to ``out_dir/label/gt_<k>.jpg`` and
  ``out_dir/predictions/pred_<k>.jpg``, with k counting batches from 1. The last pair
  is displayed when ``show`` is True.

  Args:
    net: model to evaluate; its train/eval mode is restored afterwards.
    val_dataloader: iterable of dicts with "l", "ab", "hint", "mask" tensors.
    loss_fn: loss module; defaults to the notebook-global ``criterion``.
    device: device the batch tensors are moved to.
    out_dir: root directory for the label/ and predictions/ image folders.
    show: display the last ground-truth/prediction pair with cv2_imshow.

  Returns:
    0-d numpy array with the mean loss over batches.

  Raises:
    ValueError: if the loader yields no batches.
  """
  loss_fn = criterion if loss_fn is None else loss_fn
  label_dir = os.path.join(out_dir, 'label')
  pred_dir = os.path.join(out_dir, 'predictions')
  os.makedirs(label_dir, exist_ok=True)
  os.makedirs(pred_dir, exist_ok=True)

  was_training = net.training
  net.eval()
  total_loss = 0
  n_batches = 0
  gt_bgr = hint_bgr = None
  with torch.no_grad():
    for batch in tqdm.auto.tqdm(val_dataloader):
      l = batch['l'].to(device)
      ab = batch['ab'].to(device)
      hint = batch['hint'].to(device)
      mask = batch['mask'].to(device).float()

      gt_img = torch.cat((l, ab), dim=1)
      hint_img = torch.cat((l, hint, mask), dim=1)
      output_hint = net(hint_img)
      total_loss += loss_fn(output_hint, gt_img).detach()

      # tensor2im keeps only the first image of the batch. The arrays are in OpenCV
      # Lab; cv2.imwrite and cv2_imshow expect BGR, so convert to BGR (not RGB).
      gt_bgr = cv2.cvtColor(tensor2im(gt_img), cv2.COLOR_LAB2BGR)
      hint_bgr = cv2.cvtColor(tensor2im(output_hint), cv2.COLOR_LAB2BGR)

      n_batches += 1
      cv2.imwrite(os.path.join(label_dir, 'gt_' + str(n_batches) + '.jpg'), gt_bgr)
      cv2.imwrite(os.path.join(pred_dir, 'pred_' + str(n_batches) + '.jpg'), hint_bgr)
  net.train(was_training)

  if n_batches == 0:
    raise ValueError('val_dataloader produced no batches')
  total_loss = (total_loss / n_batches).cpu().numpy()
  if show:
    cv2_imshow(gt_bgr)
    cv2_imshow(hint_bgr)
  return total_loss

# Best validation loss so far; the training cell resets it before its loop.
best_losses = 10
# Initial model instance; the training and test cells each build their own ResUNet.
net = ResUNet()
net.cuda()

 

In [13]:
save_path = './Result'  # checkpoints from the training loop are written here
# NOTE(review): output_path is not referenced by the training/test cells shown.
output_path = os.path.join(save_path, 'validation_model.tar')
os.makedirs(save_path, exist_ok=True)


# **Training ResUNet**


---



In [None]:
from torch.nn.functional import mse_loss as mse  # used by psnr() below

# Fresh model for this training run (replaces any `net` built in earlier cells).
net = ResUNet().cuda()
# Number of epochs the training loop in this cell runs.
object_epoch = 10
# The loss (PSNRLoss) and the Adam optimizer (lr=0.00025) are created further down in
# this cell, after PSNRLoss is defined. The MSELoss / Adam(lr=1e-4) that used to be
# built here were always overwritten before use, so they were removed.

def psnr(input: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
  """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE(input, target)).

  Args:
    input: predicted tensor.
    target: reference tensor with the same shape as ``input``.
    max_val: peak value of the signal range (e.g. 1.0 for data in [0, 1]).

  Returns:
    0-d tensor; +inf when ``input`` equals ``target`` (zero MSE).

  Raises:
    TypeError: if either argument is not a tensor, or if their shapes differ.
  """
  if not isinstance(input, torch.Tensor):
    raise TypeError(f"Expected torch.Tensor but got {type(input)}.")

  if not isinstance(target, torch.Tensor):
    raise TypeError(f"Expected torch.Tensor but got {type(target)}.")

  if input.shape != target.shape:
    # Kept as TypeError (rather than ValueError) for compatibility with callers.
    raise TypeError(f"Expected tensors of equal shapes, but got {input.shape} and {target.shape}")

  return 10. * torch.log10(max_val ** 2 / mse(input, target, reduction='mean'))

def psnr_loss(input: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
  """Negated PSNR, so that minimizing this loss maximizes PSNR."""
  score = psnr(input, target, max_val)
  return -score

# ================== class PSNRLoss ================== 

class PSNRLoss(nn.Module):
  """nn.Module wrapper around ``psnr_loss`` with a fixed peak value ``max_val``."""

  def __init__(self, max_val: float) -> None:
    super().__init__()
    self.max_val: float = max_val

  def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    peak = self.max_val
    return psnr_loss(input, target, peak)

# ====================================================

# Loss: negative PSNR (PSNRLoss, defined earlier in this cell). max_val=2. only shifts
# the loss by the constant -20*log10(2), so gradients equal those with max_val=1.
# Alternatives listed previously: nn.BCELoss, nn.MSELoss, nn.BCEWithLogitsLoss,
# nn.CrossEntropyLoss.
criterion = PSNRLoss(2.)

optimizer = torch.optim.Adam(net.parameters(), lr=0.00025)  # other values noted: 1e-2, 5e-4, 2e-4
# Best validation loss so far. PSNR-based losses are normally negative, so the first
# epoch is expected to be checkpointed.
best_losses = 10

train_info = []
val_info = []

for epoch in range(object_epoch):
  # val_1epoch() calls net.eval(); switch back so BatchNorm uses batch statistics and
  # keeps updating its running averages while training.
  net.train()
  train_loss = train_1epoch(net, train_dataloader)
  print('Epoch: {} loss: {}'.format(epoch, train_loss))
  train_info.append({'loss': train_loss})

  with torch.no_grad():
    val_loss = val_1epoch(net, val_dataloader)
  print('[VALIDATION] Epoch: {} loss: {}'.format(epoch, val_loss))
  val_info.append({'loss': val_loss})

  # Checkpoint only on improvement; the file name records the epoch and the loss.
  if best_losses > val_loss:
    best_losses = val_loss
    torch.save(net.state_dict(), os.path.join(save_path, 'PSNR-epoch-{}-losses-{:.5f}.pth'.format(epoch + 1, best_losses)))
    





  0%|          | 0/2500 [00:00<?, ?it/s]

# **Testing**


---



In [None]:
import shutil
import zipfile

file_name = 'test_dataset.zip'
zip_path = '/content/drive/MyDrive/ComputerVision/test_dataset.zip'

# Copy the archive to local disk and extract it into the working directory.
# zipfile overwrites existing files, so (unlike `!unzip -q`, which prompts before
# replacing) this cell can be re-run unattended.
shutil.copy(zip_path, file_name)
with zipfile.ZipFile(file_name) as archive:
  archive.extractall('.')

In [None]:
from google.colab import drive
from PIL import Image
def image_save(img, path):
  """Write an OpenCV-Lab image to ``path`` as BGR.

  Tensors are first converted with torchvision's ToPILImage and then to a numpy array.
  Any other input is assumed to already be an (H, W, 3) uint8 Lab array (TODO confirm).
  """
  array = np.asarray(transforms.ToPILImage()(img)) if isinstance(img, torch.Tensor) else img
  cv2.imwrite(path, cv2.cvtColor(array, cv2.COLOR_LAB2BGR))

# Path to trained weights. Leave empty to use the checkpoint with the highest epoch in
# save_path. Within one training run, checkpoints are only written when the validation
# loss improves, so that is the best one.
check_point = ''
use_cuda = True

test_dataset = ColorHintDataset(root_path, 256)
test_dataset.set_mode('test')
print('Test length :', len(test_dataset))
test_dataloader = data.DataLoader(test_dataset, batch_size=1, shuffle=False)

if not check_point:
  saved = [f for f in os.listdir(save_path)
           if f.startswith('PSNR-epoch-') and f.endswith('.pth')]
  if not saved:
    raise FileNotFoundError('no PSNR-epoch-*.pth checkpoint in %s; set check_point' % save_path)
  # 'PSNR-epoch-<e>-losses-<v>.pth' -> field 2 is the epoch number.
  check_point = os.path.join(save_path, max(saved, key=lambda f: int(f.split('-')[2])))

net = ResUNet().cuda()
net.load_state_dict(torch.load(check_point))
os.makedirs('outputs/predict/test', exist_ok=True)
def test(model, test_dataloader, out_dir='outputs/predict/test', device='cuda'):
  """Colorize every test sample with ``model`` and save the results as BGR images.

  Each batch must provide "l", "hint", "mask" tensors and a "file_name" entry. With the
  default collation, "file_name" is a list of str with one entry per sample. Each
  prediction is converted from OpenCV Lab to BGR and written to
  ``out_dir/<file_name>``.

  Args:
    model: network to run (switched to eval mode).
    test_dataloader: iterable of batch dicts.
    out_dir: directory for the output images (created if missing).
    device: device the batch tensors are moved to.
  """
  os.makedirs(out_dir, exist_ok=True)
  model.eval()
  with torch.no_grad():
    for batch in test_dataloader:
      l = batch['l'].to(device)
      hint = batch['hint'].to(device)
      mask = batch['mask'].to(device).float()
      output_hint = model(torch.cat((l, hint, mask), dim=1))
      for j, fname in enumerate(batch['file_name']):
        # tensor2im converts only the first image of what it is given, so slice one sample.
        out_hint_np = tensor2im(output_hint[j:j + 1])
        output_bgr = cv2.cvtColor(out_hint_np, cv2.COLOR_LAB2BGR)
        cv2.imwrite(os.path.join(out_dir, fname), output_bgr)

test(model=net, test_dataloader=test_dataloader)  # writes predictions under outputs/predict/test

Test length : 1000
