# CS470 - Introduction to Artificial Intelligence
## Project : Coloring black & white images and video

Authors: Ayoub Mellah 20196411, Quentin Nieloud 20196414, Malek Neila Rostom 20196507, Pablo Chabance 20196417



---



####Connection to Drive

In [0]:
from google.colab import drive

drive.mount('/gdrive')
gdrive_root = '/gdrive/My Drive'
gdrive_data = '/gdrive/My Drive/IA - Colorize'

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /gdrive


####Import librairies

In [0]:
import torch
import torchvision.models as models
from PIL import Image as image_pil
import torchvision.transforms as transforms 
import torch.nn as nn 
from skimage import io, color
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.io import imsave
from torchvision import datasets
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import copy
import glob
from torchvision.models import resnet152
from torchvision import transforms
from PIL import Image
from torch.utils.data import DataLoader

#### Hyper-Parameters


In [0]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

max_epoch = 50
batch_size = 10
learning_rate = 0.001

####Construct Data Pipeline


In [0]:
class GrayscaleImageFolder(datasets.ImageFolder):
  def __getitem__(self, index):
    path, target = self.imgs[index]
    img = self.loader(path)
    if self.transform is not None:
      img_l = self.transform(img)
      img_embed = img.resize((299, 299))
      img_embed = np.asarray(img_embed)
      img_embed = gray2rgb(rgb2gray(img_embed))
      img_embed = torch.from_numpy(img_embed).unsqueeze(0).float()
      img_l = np.asarray(img_l)
      img_lab = rgb2lab(img_l)
      img_lab = (img_lab + 128) / 255 #a voir
      img_label = img_lab[:, :, 1:]
      img_label = torch.from_numpy(img_label.transpose((2, 0, 1))).float()
      img_l = rgb2gray(img_l)
      img_l = torch.from_numpy(img_l).unsqueeze(0).float()
    if self.target_transform is not None:
      target = self.target_transform(target)
    return img_l, img_embed, img_label, target

# Training 
traindir = os.path.join(gdrive_data, 'train_data')

train_transforms = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()])
train_imagefolder = GrayscaleImageFolder(traindir, train_transforms)
train_loader = DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)

# Validation
testdir = os.path.join(gdrive_data, 'test_data')

test_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
test_imagefolder = GrayscaleImageFolder(testdir , test_transforms)
test_loader = torch.utils.data.DataLoader(test_imagefolder, batch_size=10, shuffle=False)

# Create generators for the Color Network
resnet = resnet152(pretrained=True, progress=True)

def training_generator():
  for batch_l, batch_emb, labels, _ in train_loader:
    batch_emb = batch_emb.permute(0,4,2,3,1).squeeze(4)
    embed = resnet(batch_emb)
    yield([batch_l, embed], labels)

def testing_generator():
  for batch_l, batch_emb, labels, _ in test_loader:
    batch_emb = batch_emb.permute(0,4,2,3,1).squeeze(4)
    embed = resnet(batch_emb)
    yield([batch_l, embed], labels)


Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to /root/.cache/torch/checkpoints/resnet152-b121ed2d.pth
100%|██████████| 230M/230M [00:03<00:00, 64.5MB/s]


#### ColorNet Model Architecture
##### Composed of one Encoder, one Decoder and one Fusion Network

In [0]:
# ENCODER
class Encoder(nn.Module):
  def __init__(self):
    super(Encoder, self).__init__()

    def block(input_size, output_size, stride=False):
      layers = [nn.Conv2d(in_channels=input_size, out_channels=output_size, kernel_size=3, stride=2, padding=1)] if stride \
      else [nn.Conv2d(in_channels=input_size, out_channels=output_size, kernel_size=3, padding=1)]
      
      layers.append(nn.ReLU())
      return layers

    self.model = nn.Sequential(
      *block(1, 64, True),
      *block(64, 128),
      *block(128, 128, True),
      *block(128, 256),
      *block(256, 256, True),
      *block(256, 512),
      *block(512, 512),
      *block(512, 256)
    )

  def forward(self, x):
    return self.model(x)

In [0]:
# FUSION
class Fusion(nn.Module):
  def __init__(self):
    super(Fusion, self).__init__()

    self.model = nn.Sequential(
      nn.Conv2d(1256, 256, 1),
      nn.ReLU()
    )

  def forward(self, encoder_output, embed):
    base = torch.zeros(10, 1000, 1, 1)
    output = embed.unsqueeze(2).unsqueeze(3)
    output[:, :, 0:, :] = base
    output[:, :, :, 0:] = base
    
    #output = embed.unsqueeze(2).unsqueeze(3)
    output = output.repeat(1, 1, 28, 28) 
    output = torch.cat((encoder_output, output), 1)
    return self.model(output)

In [0]:
# DECODER
class Decoder(nn.Module):
  def __init__(self):
    super(Decoder, self).__init__()

    def block(input_size, output_size, tanh=False):
      layers = [nn.Conv2d(in_channels=input_size, out_channels=output_size, kernel_size=3, padding=1)]
      layers.append(nn.Tanh()) if tanh else layers.append(nn.ReLU())
      return layers

    self.model = nn.Sequential(
      *block(256, 128),
      nn.Upsample(scale_factor=(2,2)),
      *block(128, 64),
      nn.Upsample(scale_factor=(2,2)),
      *block(64, 32),
      *block(32, 16),
      *block(16, 2, tanh=True),
      nn.Upsample(scale_factor=(2,2))
    )

  def forward(self, x):
    return self.model(x)

In [0]:
# COLORNET
class ColorNet(nn.Module):
  def __init__(self):
    super(ColorNet, self).__init__()

    self.encoder = Encoder()
    self.fusion = Fusion()
    self.decoder = Decoder()

  def forward(self, x, embed):
    x = self.encoder(x)
    x = self.fusion(x, embed)
    x = self.decoder(x)
    return x

#### Training ColorNet

In [0]:
ckpt_file = os.path.join(gdrive_data, 'ckpt')

net = ColorNet().to(device)
optim = optim.Adam(net.parameters(), learning_rate, weight_decay=0)

train_losses = []

for epoch in range(max_epoch):
  net.train()
  for inputs, labels in training_generator():

    enc_inputs = inputs[0].to(device)
    embed_outputs = inputs[1].to(device)
    labels = labels.to(device)

    prediction = net(enc_inputs, embed_outputs)

    loss = F.mse_loss(prediction, labels)

    optim.zero_grad()
    loss.backward()
    optim.step()

    print('[Epoch:{}/{}] Train Loss:{:.4f}'.format(epoch, max_epoch, loss.item()))

  train_losses.append(loss)
  torch.save(net.state_dict(), ckpt_file + '/latest.pt')
  if epoch % 10:
    torch.save(net.state_dict(), ckpt_file + '/latest.pt')

[Epoch:0/50] Train Loss:0.2554
[Epoch:1/50] Train Loss:0.2386
[Epoch:2/50] Train Loss:0.1816
[Epoch:3/50] Train Loss:0.1730
[Epoch:4/50] Train Loss:0.0303
[Epoch:5/50] Train Loss:0.0766
[Epoch:6/50] Train Loss:0.0274
[Epoch:7/50] Train Loss:0.0243
[Epoch:8/50] Train Loss:0.0444
[Epoch:9/50] Train Loss:0.0190
[Epoch:10/50] Train Loss:0.0476
[Epoch:11/50] Train Loss:0.0230
[Epoch:12/50] Train Loss:0.0328
[Epoch:13/50] Train Loss:0.0170
[Epoch:14/50] Train Loss:0.0230
[Epoch:15/50] Train Loss:0.0188
[Epoch:16/50] Train Loss:0.0195
[Epoch:17/50] Train Loss:0.0176
[Epoch:18/50] Train Loss:0.0151
[Epoch:19/50] Train Loss:0.0168
[Epoch:20/50] Train Loss:0.0140
[Epoch:21/50] Train Loss:0.0154
[Epoch:22/50] Train Loss:0.0132
[Epoch:23/50] Train Loss:0.0137
[Epoch:24/50] Train Loss:0.0127
[Epoch:25/50] Train Loss:0.0151
[Epoch:26/50] Train Loss:0.0123
[Epoch:27/50] Train Loss:0.0146
[Epoch:28/50] Train Loss:0.0118
[Epoch:29/50] Train Loss:0.0136
[Epoch:30/50] Train Loss:0.0115
[Epoch:31/50] Trai

#### Testing ColorNet

In [0]:
ckpt_file = os.path.join(gdrive_data, 'ckpt')

net = ColorNet().to(device)
net.load_state_dict(torch.load(ckpt_file + '/latest.pt'))

for inputs, labels in testing_generator():
  enc_inputs = inputs[0].to(device)
  embed_outputs = inputs[1].to(device)
  labels = labels.to(device)

  prediction = net(enc_inputs, embed_outputs)
  
  prediction = prediction

  
  for i in range(len(prediction)):
    cur = np.zeros((224, 224, 3))
    tmp = inputs[0][i].permute(1,2,0)
    tmp = tmp * 100
    cur[:,:,0] = tmp[:,:,0]
    cur[:,:,1:] = prediction[i].cpu().detach().permute(1,2,0) * 255 - 128
    imsave(gdrive_data + "/result/img_"+str(i)+".png", lab2rgb(cur))

