# Imports

In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
from IPython import embed
from skimage import color
from PIL import Image
import ssl
from torchvision import transforms
from torchvision.datasets import Caltech256
import torch.nn.functional as F
#from torch.utils.data import Dataset
#from torch.utils.data import Dataloader

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Data

In [4]:
#Pre processing of data to index rbg training data

def preProcessing(path):
  if not os.path.exists(path + "/train"):
    os.mkdir(path + "/train")
    os.mkdir(path + "/train/rgb")

  idx=0
  for filename in os.listdir(path):
    if not os.path.isdir(path + "/" + filename):
      
      #get RGB image, resize, and put in rgb folder
      rgb_img=Image.open(path + "/" + filename)
      rgb_img=rgb_img.resize((w, h))
      rgb_img.save(path + "/train/target/" + str(idx) + "_rgb.png")

      idx+=1


In [None]:
#Scafolding for Dataset

def myDataset(h=256, w=256, rgb2lab=True, zhangmodel=True, edmodel=False):
  def __init__(self, img_dir, transform):
    self.img_dir=img_dir
    self.transform=transform
    self.dataset_length = len(os.listdir(img_dir + "/train/rgb"))
    
    #Set up dataset so images are indexed
    preProcessing(img_dir, h, w)

    if zhangmodel:
      print("ZHANG MODEL")

    elif edmodel:
      print("ED MODEL")

    def __len__(self):
      return self.dataset_length

    def __getitem__(self, idx):
      #Get RGB image, resize, and transform it
      rgb_img=Image.open(self.img_dir + "/" + str(idx) + "_rgb.png")
      rgb_img=self.transform(rgb_img.resize((w, h)))

      #Get CIE-Lab equivalent of the RGB image
      clab_array=color.rgb2lab(np.asarray(rgb_img))
      clab_img=Image.fromarray(clab_array, mode="LAB")

      #Get greyscale CIE-Lab equivalent of the RGB image
      clab_grayimg=Image.fromarray(clab_array[:, :, 0], mode="LAB")

      data={'input':clab_grayimg, 'target':clab_img}

      return data
    

# Model 1 ("Colorful Image Colorization" by Zhang et al.)

In [None]:
class BaseModel(nn.Module):
  '''
  A 8-blocks cnn model, each block has multiple cnn layer (22 in total)
  "prediction" in CIELAB space (L, a, b)
  For this model, it takes in "grayscale image" with only L value
  and it outputs a and b values
  '''
  def __init__(self, norm_layer = nn.BatchNorm2d):
    super(BaseModel, self).__init__()
    # layer 1
    self.layer1 = nn.Sequential([
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),
            nn.ReLU(True),
            norm_layer(64)])

    # layer 2
    self.layer2 = nn.Sequential([
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),
        nn.ReLU(True),
        norm_layer(128)])

    # layer 3
    self.layer3 = nn.Sequential([
        nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),
        nn.ReLU(True),
        norm_layer(256)])

    # layer 4
    self.layer4 = nn.Sequential([
        nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        norm_layer(512)])

    # layer 5
    self.layer5 = nn.Sequential([
        nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
        nn.ReLU(True),
        nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
        nn.ReLU(True),
        nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
        nn.ReLU(True),
        norm_layer(512)])

    #layer 6
    self.layer6 = nn.Sequential([
        nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
        nn.ReLU(True),
        nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
        nn.ReLU(True),
        nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
        nn.ReLU(True),
        norm_layer(512),])

    #layer 7
    self.layer7 = nn.Sequential([
        nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        norm_layer(512),])

    #layer 8
    self.layer8 = nn.Sequential([
        nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
        nn.ReLU(True),
        nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),])

    self.softmax = nn.Softmax(dim=1)
    # 2 means (a, b)
    self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
    self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')

  def forward(self, input_l):
    # model
    conv1_2 = self.layer1(self.normalize_l(input_l))
    conv2_2 = self.layer2(conv1_2)
    conv3_3 = self.layer3(conv2_2)
    conv4_3 = self.layer4(conv3_3)
    conv5_3 = self.layer5(conv4_3)
    conv6_3 = self.layer6(conv5_3)
    conv7_3 = self.layer7(conv6_3)
    conv8_3 = self.layer8(conv7_3)

    out_reg = self.model_out(self.softmax(conv8_3))

    # this is deal with nomalization
    # output is in [0,1] (ratio of a, b to L)
    # L ususally has a range [0, 100] (or 110)
    return 100*(self.upsample(out_reg))

# Model 1 Loss

In [None]:
def RebalanceLoss

In [None]:
def GetClassWeights

# Model 2: Encoder Decoder


In [3]:
%%capture
!pip install pretrainedmodels

In [4]:
import pretrainedmodels
from pretrainedmodels import utils

In [5]:
class Encoder(nn.Module):
  def __init__(self):
    super(Encoder, self).__init__()
      
    #what is the padding?
    self.layer1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1)

    self.layer2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)

    self.layer3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1)

    self.layer4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)

    self.layer5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)

    self.layer6 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)

    self.layer7 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)

    self.layer8 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1)

  def forward(self, x):
    x=F.relu(self.layer1(x))
    x=F.relu(self.layer2(x))
    x=F.relu(self.layer3(x))
    x=F.relu(self.layer4(x))
    x=F.relu(self.layer5(x))
    x=F.relu(self.layer6(x))
    x=F.relu(self.layer7(x))
    x=F.relu(self.layer8(x))

    return x

In [6]:
class Decoder(nn.Module):
  def __init__(self):
    super(Decoder, self).__init__()

    self.layer1=nn.Conv2d(in_channels=1257, out_channels=256, kernel_size=1)
    self.layer2=nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
    self.layer3=nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
    self.layer4=nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
    self.layer5=nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
    self.layer6=nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, padding=1)


  def forward(self, x):
    x=F.relu(self.layer1(x))
    x=F.relu(self.layer2(x))
    x=F.interpolate(x, scale_factor=2)
    
    x=F.relu(self.layer3(x))
    x=F.relu(self.layer4(x))
    x=F.interpolate(x, scale_factor=2)

    x=F.relu(self.layer5(x))
    x=torch.tanh(self.layer6(x))
    x=F.interpolate(x, scale_factor=2)

    return x

In [7]:
#feature extractor
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
inception = pretrainedmodels.__dict__["inceptionresnetv2"](
            num_classes=1001, 
            pretrained="imagenet+background")
inception.eval()

Downloading: "http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth" to /root/.cache/torch/hub/checkpoints/inceptionresnetv2-520b38e4.pth
100%|██████████| 213M/213M [10:28<00:00, 356kB/s]


InceptionResNetV2(
  (conv2d_1a): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2a): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2b): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (maxpool_3a): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2d_3b): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_4a): 

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [28]:
#https://github.com/lauradang/automatic-image-colorization/blob/master/notebooks/inception_resnet.ipynb
class EDModel(nn.Module):
    def __init__(self):
      super(EDModel, self).__init__()

      self.encoder=Encoder()
      self.decoder=Decoder()

    def forward(self, x, feature):
      enout=self.encoder(x)

      extract_feat=(inception(feature)).view(-1, 1001, 1, 1)
      rows = torch.cat([extract_feat] * 28, dim=3)
      embedding_block = torch.cat([rows] * 28, dim=2)
      fusion_block = torch.cat([enout, embedding_block], dim=1)

      return self.decoder(fusion_block)

In [9]:
%%capture
!unzip /content/drive/MyDrive/Senior_Year/flickr.zip

In [10]:
#Pre processing of data to index rbg training data

def preProcessing(path, w=256, h=256):
  if not os.path.exists(path + "/../train"):
    os.mkdir(path + "/../train")

  idx=0
  for filename in os.listdir(path):
    if not os.path.isdir(path + "/" + filename) and ".csv" not in filename:
      
      #get RGB image, resize, and put in rgb folder
      rgb_img=Image.open(path + "/" + filename)
      rgb_img=rgb_img.resize((w, h))
      rgb_img.save(path + "/../train/" + str(idx) + "_rgb.png")

      idx+=1

In [11]:
#Set up dataset so images are indexed
preProcessing("/content/flickr30k_images/flickr30k_images", 256, 256)

In [12]:
from skimage import io
class myEDDataset():
  def __init__(self, img_dir):
    self.img_dir=img_dir
    self.dataset_length = len(os.listdir(img_dir + "/../train"))

  def __len__(self):
    return int((self.dataset_length)/10)

  def __getitem__(self, idx):
    rgb=Image.open(self.img_dir + "/../train/" + str(idx) + "_rgb.png")
    rgb_resh=rgb.resize((224, 224))

    #Get CIE-Lab equivalent of the RGB image
    clab_array=color.rgb2lab(np.asarray(rgb_resh, dtype = np.float32))

    #Get greyscale CIE-Lab equivalent of the RGB image
    clab_grayarray = clab_array[:, :, 0]
    clab_grayarray = clab_grayarray[np.newaxis,:, :]

    clab_array=clab_array[:, :, 1:3]
    clab_array = clab_array.transpose((2, 0, 1)) # (2, 224, 224)

    #get incep input
    target_img=rgb.resize((299, 299))
    incep_array=color.rgb2lab(np.asarray(target_img, dtype = np.float32))
    incep_array = incep_array.transpose((2, 0, 1)) # (3, 299, 299)

    #Transform
    input=clab_grayarray
    target=clab_array
    incep=incep_array

    data={'input': input, 'target': target, 'incep': incep}

    return data

In [36]:
def train(model, train_loader, criterion, optimizer, device):
  model.train()
  
  total_loss=0.0
  num=0
  for data in enumerate(train_loader):
    input=data[1]['input'].to(device)
    feature=data[1]['incep'].to(device)
    
    optimizer.zero_grad()
    output=model(input, feature)
    loss = criterion(output, data[1]['target'].to(device))

    loss.backward()
    optimizer.step()

    total_loss+=loss.item()
    num+=1
  
  return total_loss/len(train_loader)

def val(model, val_loader, criterion, device):
  model.eval()
  
  total_loss=0.0
  for data in enumerate(train_loader):
    input=data[1]['input'].to(device)
    feature=data[1]['incep'].to(device)

    output=model(input, feature)
    loss = criterion(output, data[1]['target'].to(device))

    total_loss+=loss
    
  return total_loss

In [38]:
inception.to(device)

#EDModel
model = EDModel()
model.to(device)

#Dataset
img_dir="/content/flickr30k_images/flickr30k_images"
dataset = myEDDataset(img_dir)
subsets=torch.utils.data.random_split(dataset, [0.9, 0.1])

#Dataloaders:
train_loader=torch.utils.data.DataLoader(subsets[0])
val_loader=torch.utils.data.DataLoader(subsets[1])

#optimizer and criterion
optimizer=torch.optim.Adam(model.parameters(), lr=0.001)
criterion=nn.MSELoss()

#keep track of loss
training_loss=[]
validation_loss=[]

MAX_EPOCHS=20
latest_loss=10000000000000
for epoch in range(MAX_EPOCHS):
  print("Epoch " + str(epoch) + "/" + str(MAX_EPOCHS))

  train_loss=train(model, train_loader, criterion, optimizer, device)
  val_loss=val(model, val_loader, criterion, device)

  print("-----------------------------------------")
  print("Train loss: " + str(train_loss))
  print("Val loss: " + str(val_loss))
  print("-----------------------------------------")

  training_loss.append(train_loss)
  validation_loss.append(val_loss)

  if val_loss<latest_loss:
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
            }, '/content/drive/MyDrive/Senior_Year/checkpoint.pt')

OutOfMemoryError: ignored

Citations: \\
https://github.com/lauradang/automatic-image-colorization/blob/master/notebooks/inception_resnet.ipynb \\
https://arxiv.org/pdf/1712.03400.pdf 