# Deep Colorization
### Deep learning final project for conversion of gray scale images to rgb
### Contributors: Bhumi Bhanushali, Avinash Hemaeshwara Raju, Kathan Nilesh Mehta, Atulay Ravishankar

### Download Dataset

In [None]:
# wget -N images.cocodataset.org/zips/train2017.zip
# wget -N images.cocodataset.org/zips/val2017.zip
# wget -N images.cocodataset.org/zips/test2017.zip
# pip3 install tensorboard
# tensorboard --logdir=runs

### Import Modules

In [94]:
import os
import time
import numpy as np 
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from network_definition import Colorization
from skimage import io, color
from skimage.transform import resize
import cv2

### Configuration

In [95]:
class Configuration:
    model_file_name = 'checkpoint.pt'
    load_model_to_train = False
    load_model_to_test = False
    device = "cuda" if torch.cuda.is_available() else "cpu"
    point_batches = 10

### Hyper Parameters

In [96]:
class HyperParameters:
    epochs = 1
    batch_size = 1
    learning_rate = 0.001
    num_workers = 8
    learning_rate_decay = 0.72

In [97]:
config = Configuration()
hparams = HyperParameters()
print('Device:',config.device)

Device: cpu


### Custom Dataloader

In [98]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, process_type):
        self.root_dir = root_dir
        self.files = [f for f in os.listdir(root_dir)]
        self.process_type = process_type
        print(self.process_type,'[0]\t',self.root_dir,self.files[0])

    def __len__(self):
        if self.process_type == 'train':
            return 100#len(self.files)
        else:
            return len(self.files)

    def __getitem__(self, index):
        self.rgb = io.imread(os.path.join(self.root_dir,self.files[index]))
        rgb_encoder = resize(self.rgb, (224, 224),anti_aliasing=True)
        rgb_inception = resize(self.rgb, (300, 300),anti_aliasing=True)

        self.lab_encoder = color.rgb2lab(rgb_encoder)
        l_encoder = self.lab_encoder[:,:,0]
        l_encoder = np.stack((l_encoder,)*3,axis = -1)
        l_encoder = torchvision.transforms.ToTensor()(l_encoder)
        ab_encoder = self.lab_encoder[:,:,1:3]
        ab_encoder = torchvision.transforms.ToTensor()(ab_encoder)

        self.lab_inception = color.rgb2lab(rgb_inception)
        l_inception = self.lab_inception[:,:,0]
        l_inception = np.stack((l_inception,)*3,axis = -1)
        l_inception = torchvision.transforms.ToTensor()(l_inception)

        return l_encoder, ab_encoder, l_inception, torchvision.transforms.ToTensor()(rgb_encoder)

    def show_rgb(self, index):
        self.__getitem__(index)
        print("RGB image size:", self.rgb.shape)
        cv2.imshow("RGB",self.rgb)
        cv2.waitKey(0)

    def show_lab_encoder(self, index):
        self.__getitem__(index)
        print("Encoder Lab image size:", self.lab_encoder.shape)
        cv2.imshow("Lab Encoder",self.lab_encoder)
        cv2.waitKey(0)

    def show_lab_inception(self, index):
        self.__getitem__(index)
        print("Inception Lab image size:", self.lab_inception.shape)
        cv2.imshow("Lab Inception",self.lab_encoder)
        cv2.waitKey(0)

### Encoder

In [99]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1), 
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        self.model = self.model.float()
        return self.model(x.float())

### Fusion Layer

In [100]:
class FusionLayer(nn.Module):
    def __init__(self):
        super(FusionLayer,self).__init__()

    def forward(self, inputs, mask=None):
        ip, emb = inputs
        emb = torch.stack([torch.stack([emb],dim=2)],dim=3)
        emb = emb.repeat(1,1,ip.shape[2],ip.shape[3])
        fusion = torch.cat((ip,emb),1)
        return fusion

### Decoder

In [101]:
class Decoder(nn.Module):
    def __init__(self, input_depth):
        super(Decoder,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=input_depth, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2.0),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2.0),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
            nn.Upsample(scale_factor=2.0),
        )

    def forward(self, x):
        return self.model(x)

### Network Definition

In [102]:
class Colorization(nn.Module):
    def __init__(self, depth_after_fusion):
        super(Colorization,self).__init__()
        self.encoder = Encoder()
        self.fusion = FusionLayer()
        self.after_fusion = nn.Conv2d(in_channels=1256, out_channels=depth_after_fusion,kernel_size=1, stride=1,padding=0)
        self.decoder = Decoder(depth_after_fusion)

    def forward(self, img_l, img_emb):
        img_enc = self.encoder(img_l)
        fusion = self.fusion([img_enc, img_emb])
        fusion = self.after_fusion(fusion)
        return self.decoder(fusion)

### Architecture Pipeline

In [103]:
model = Colorization(256).to(config.device) 
inception_model = models.inception_v3(pretrained=True).float().to(config.device)
inception_model = inception_model.float()
inception_model.eval()
loss_criterion = torch.nn.MSELoss(reduction='mean').to(config.device)
optimizer = torch.optim.Adam(model.parameters(),lr=hparams.learning_rate, weight_decay=1e-6)
milestone_list  = list(range(0,hparams.epochs,2))
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestone_list, gamma=hparams.learning_rate_decay)
writer = SummaryWriter()

### Data Loaders

In [104]:
train_dataset = CustomDataset('data/train','train')
validataion_dataset = CustomDataset('data/validation','validation')
test_dataset = CustomDataset('data/test','test')

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams.batch_size, shuffle=True, num_workers=hparams.num_workers)
validation_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams.batch_size, shuffle=False, num_workers=hparams.num_workers)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=hparams.num_workers)

train [0]	 data/train 000000109622.jpg
validation [0]	 data/validation 000000182611.jpg
test [0]	 data/test 000000220208.jpg


### Training & Validation Pipeline

In [105]:
for epoch in range(hparams.epochs):
    print('Starting epoch:',epoch+1)

    #*** Training step ***
    loop_start = time.time()
    avg_loss = 0.0
    batch_loss = 0.0
    main_start = time.time()
    
    for idx,(img_l_encoder, img_ab_encoder, img_l_inception, img_rgb) in enumerate(train_dataloader):
        #*** Move data to GPU if available ***
        img_l_encoder = img_l_encoder.to(config.device)
        img_ab_encoder = img_ab_encoder.to(config.device)
        img_l_inception = img_l_inception.to(config.device)

        #*** Initialize model & Optimizer ***
        model.train()
        optimizer.zero_grad()

        #*** Forward Propagation ***
        img_embs = inception_model(img_l_inception.float())
        output_ab = model(img_l_encoder,img_embs)

        #*** Back propogation ***
        loss = loss_criterion(output_ab, img_ab_encoder.float())
        loss.backward()

        #*** Weight Update ****
        optimizer.step()

        #*** Reduce Learning Rate ***
        scheduler.step()

        #*** Loss Calculation ***
        avg_loss += loss.item()
        batch_loss += loss.item()

        #*** Print stats after every point_batches ***
        if idx%config.point_batches==0: 
            loop_end = time.time()   
            print('Batch:',idx, '| Processing time for',config.point_batches,':',loop_end-loop_start,\
                  '| Batch Loss:', batch_loss/config.point_batches)
            loop_start = time.time()
            batch_loss = 0.0

    #*** Print Training Data Stats ***
    train_loss = avg_loss/len(train_dataloader)*hparams.batch_size
    writer.add_scalar('Loss/train', train_loss, epoch)
    print('Training Loss:',train_loss,'| Processed in ',time.time()-main_start,'s')

    #*** Validation Step ***       
    avg_loss = 0.0
    loop_start = time.time()
    for idx,(img_l_encoder, img_ab_encoder, img_l_inception, img_rgb) in enumerate(validation_dataloader):
        #*** Move data to GPU if available ***
        img_l_encoder = img_l_encoder.to(config.device)
        img_ab_encoder = img_ab_encoder.to(config.device)
        img_l_inception = img_l_inception.to(config.device)

        #*** Intialize Model to Eval Mode ***
        model.eval()

        #*** Forward Propagation ***
        img_embs = inception_model(img_l_inception.float())
        output_ab = model(img_l_encoder,img_embs)

        #*** Loss Calculation ***
        loss = loss_criterion(output_ab, img_ab_encoder.float())
        avg_loss += loss.item()

    val_loss = avg_loss/len(validation_dataloader)*hparams.batch_size
    writer.add_scalar('Loss/validation', val_loss, epoch)
    print('Validation Loss:', val_loss,'| Processed in ',time.time()-loop_start,'s')

    #*** Save the Model to disk ***
    checkpoint = {'model': model,'model_state_dict': model.state_dict(),\
                  'optimizer' : optimizer,'optimizer_state_dict' : optimizer.state_dict(),\
                  'train_loss':train_loss, 'val_loss':val_loss}
    torch.save(checkpoint, config.model_file_name)
    print("Model saved at:",os.getcwd(),'/',config.model_file_name)

Starting epoch: 1
Batch: 0 | Processing time for 10 : 2.805928945541382 | Batch Loss: 33.74257202148438
Batch: 10 | Processing time for 10 : 17.890161752700806 | Batch Loss: 305.5652221679687
Batch: 20 | Processing time for 10 : 18.239235162734985 | Batch Loss: 217.08550338745118
Batch: 30 | Processing time for 10 : 18.205628871917725 | Batch Loss: 187.57070410251617
Batch: 40 | Processing time for 10 : 17.732488870620728 | Batch Loss: 135.05602016448975
Batch: 50 | Processing time for 10 : 17.673827171325684 | Batch Loss: 228.61508407592774
Batch: 60 | Processing time for 10 : 17.780488967895508 | Batch Loss: 213.291748046875
Batch: 70 | Processing time for 10 : 17.75102710723877 | Batch Loss: 145.469234085083
Batch: 80 | Processing time for 10 : 17.676505088806152 | Batch Loss: 197.42254810333253
Batch: 90 | Processing time for 10 : 17.57918405532837 | Batch Loss: 278.00376052856444
Training Loss: 208.44994093179702 | Processed in  179.03973412513733 s
Validation Loss: 208.2398022294

NameError: name 'train_acc' is not defined

### Inference

##### Convert Tensor Image -> Numpy Image -> Color  Image -> Tensor Image

In [116]:
print(model)

Colorization(
  (encoder): Encoder(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (5): ReLU(inplace=True)
      (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
    )
  )
  (fusion): FusionLayer()
  (after_fusion): Conv2d(1256

In [123]:
def concatente_and_colorize(im_lab, img_ab):
    # Assumption is that im_lab is of size [1,3,224,224]
    print(im_lab.size(),img_ab.size())
    np_img = im_lab[0].detach().numpy().transpose(1,2,0)
    lab = np.empty([*np_img.shape[0:2], 3])
    lab[:, :, 0] = np.squeeze(((np_img + 1) * 50))
    lab[:, :, 1:] = img_ab[0].detach().numpy().transpose(1,2,0) * 127
    np_img = color.lab2rgb(lab)
    color_im = torch.stack([torchvision.transforms.ToTensor()(np_img)],dim=0)
    return color_im

In [None]:
def colorize(im_lab):
    # Assumption is that im_lab is of size [1,3,224,224]
    np_img = im_lab[0].detach().numpy().transpose(1,2,0)
    np_img = color.lab2rgb(np_img)
    color_im = torch.stack([torchvision.transforms.ToTensor()(np_img)],dim=0)
    return color_im

In [125]:
#*** Inference Step ***
avg_loss = 0.0
loop_start = time.time()
for idx,(img_l_encoder, img_ab_encoder, img_l_inception, img_rgb) in enumerate(test_dataloader):
        #*** Move data to GPU if available ***
        img_l_encoder = img_l_encoder.to(config.device)
        img_ab_encoder = img_ab_encoder.to(config.device)
        img_l_inception = img_l_inception.to(config.device)
        
        #*** Intialize Model to Eval Mode ***
        model.eval()
        
        #*** Forward Propagation ***
        img_embs = inception_model(img_l_inception.float())
        output_ab = model(img_l_encoder,img_embs)
        
        #*** Adding l channel to ab channels ***
        img_lab = concatente_and_colorize(torch.stack([img_l_encoder[:,0,:,:]],dim=1),output_ab)
#         img_lab = torch.cat((torch.stack([img_l_encoder[:,0,:,:]],dim=1).float(),output_ab),1)
#        img_lab = colorize(img_lab)
        
        #*** Printing to Tensor Board ***
        grid = torchvision.utils.make_grid(img_lab)
        writer.add_image('Output Lab Images', grid, 0)
        
        #*** Loss Calculation ***
        loss = loss_criterion(output_ab, img_ab_encoder.float())
        avg_loss += loss.item()
        
test_loss = avg_loss/len(test_dataloader)
writer.add_scalar('Loss/test', test_loss, epoch)
print('Test Loss:',avg_loss/len(test_dataloader),'| Processed in ',time.time()-loop_start,'s')

torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1,

  return xyz2rgb(lab2xyz(lab, illuminant, observer))


torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])


  return xyz2rgb(lab2xyz(lab, illuminant, observer))


torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])


  return xyz2rgb(lab2xyz(lab, illuminant, observer))


torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])


  return xyz2rgb(lab2xyz(lab, illuminant, observer))


torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])


  return xyz2rgb(lab2xyz(lab, illuminant, observer))


torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])


  return xyz2rgb(lab2xyz(lab, illuminant, observer))


torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])


  return xyz2rgb(lab2xyz(lab, illuminant, observer))


torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])


  return xyz2rgb(lab2xyz(lab, illuminant, observer))


torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])
torch.Size([1, 1, 224, 224]) torch.Size([1, 2, 224, 224])


KeyboardInterrupt: 

In [None]:
writer.close()