In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere below the Kaggle input directory.
for input_path in (
    os.path.join(folder, name)
    for folder, _, names in os.walk('/kaggle/input')
    for name in names
):
    print(input_path)

# Any results you write to the current directory are saved as output.

In [None]:
!wget http://data.csail.mit.edu/places/places205/testSetPlaces205_resize.tar.gz
!tar -xzf testSetPlaces205_resize.tar.gz

In [None]:
import os

# Split the extracted images into a small held-out set and a training set by moving files.
source_dir = 'testSet_resize'
train_dir = 'Color_Images/Train/color_data/'  # everything after the first 1,000 (~40,000 images)
test_dir = 'Color_Images/Test/color_data/'    # first 1,000 images
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# os.listdir() returns entries in arbitrary order. Sorting makes the split reproducible, so
# re-running after a fresh extraction cannot put the same image in both Train and Test.
for i, file in enumerate(sorted(os.listdir(source_dir))):
  destination = test_dir if i < 1000 else train_dir  # first 1000 are held out, the rest train
  os.rename(os.path.join(source_dir, file), os.path.join(destination, file))

In [None]:
#len(os.listdir('Color_Images/Train/color_data/'))       # 40000 training images

In [None]:
# Make sure the images are there: display one image from the training folder.
import os
from IPython.display import Image, display

train_image_dir = 'Color_Images/Train/color_data/'
sample_path = train_image_dir + 'b13dcc2414fde1747442b4d068148a12.jpg'
if not os.path.exists(sample_path):
    # The preferred sample may have been moved to the Test split (the split depends on the
    # directory listing order), so fall back to the first available training image.
    sample_path = train_image_dir + sorted(os.listdir(train_image_dir))[0]
display(Image(filename=sample_path))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import os, shutil, time

In [None]:
# Stock torchvision ResNet-18 loaded with pretrained weights. In the cells visible here it is
# only used to print its child modules; the colorization network builds its own ResNet-18.
model = models.resnet18(pretrained=True)

In [None]:
# Print every top-level child module of the ResNet together with its position.
resnet_children = list(model.children())
for child_counter, child in enumerate(resnet_children):
    print(" child", child_counter, "is -")
    print(child)
# As before, child_counter ends up holding the total number of children.
child_counter = len(resnet_children)

In [None]:
# Directory that receives the model checkpoints written by torch.save() during training;
# exist_ok=True makes this cell safe to re-run.
os.makedirs('Checkpoints', exist_ok=True)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The model and every batch are moved to this device before use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Base directory of the dataset; the 'Train' and 'Test' sub-folders are appended to it
# when the datasets are built.
root = 'Color_Images/'

In [None]:
class GrayscaleImageFolder(datasets.ImageFolder):
  '''ImageFolder variant that yields grayscale inputs and AB color targets.

  Each item is a 4-tuple ``(img_original, img_gray, img_ab, target)``:

  * ``img_original`` - the (optionally transformed) image as a NumPy array,
  * ``img_gray``     - float tensor of shape (1, H, W) produced by ``rgb2gray``,
  * ``img_ab``       - float tensor of shape (2, H, W) holding the rescaled A and B
                       channels of the Lab image,
  * ``target``       - the class index supplied by ``ImageFolder``.

  ``transform`` must return something ``np.asarray`` turns into a channels-last RGB
  array (e.g. a PIL image); do not include ``ToTensor`` in it.
  '''
  def __getitem__(self, index):
    path, target = self.imgs[index]
    img = self.loader(path)
    if self.transform is not None:
      img = self.transform(img)
    # The conversion runs whether or not a transform was given; previously it lived inside
    # the `if`, so a dataset built without a transform failed with UnboundLocalError.
    img_original = np.asarray(img)
    img_lab = rgb2lab(img_original)
    # Shift and scale so the A & B channels (nominally in the range -128 to 128) land in ~[0, 1].
    img_lab = (img_lab + 128) / 255
    img_ab = img_lab[:, :, 1:3]           # keep only the A and B channels
    img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()   # HWC -> CHW
    # rgb2gray drops the channel dimension, so a singleton channel axis is added back.
    img_gray = rgb2gray(img_original)
    img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
    if self.target_transform is not None:
      target = self.target_transform(target)
    return img_original, img_gray, img_ab, target

In [None]:
# The transforms deliberately stop at PIL images (no ToTensor / Normalize):
# GrayscaleImageFolder converts the transformed image to a NumPy array itself.
train_transform = transforms.Compose([transforms.Resize(300),
                                transforms.CenterCrop(224),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomVerticalFlip(),
                               ])
test_transform = transforms.Compose([transforms.Resize(300),
                                transforms.CenterCrop(224),
                                ])

# Each dataset yields (original image, gray input, AB target, class index)
train_data = GrayscaleImageFolder(root+'Train', transform=train_transform)
test_data = GrayscaleImageFolder(root+'Test', transform=test_transform)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True)
# Evaluation data is not shuffled: the loss does not depend on order, and a fixed order means
# the image previewed after every epoch is the same one, so progress is comparable.
testloader = torch.utils.data.DataLoader(test_data, batch_size=8, shuffle=False)

In [None]:
def show_img(color_original, inp_gray, reconstructed_color):
    """Show the original, grayscale and re-colored versions of an image side by side.

    Args:
        color_original: image accepted by ``imshow`` (e.g. an H x W x 3 array).
        inp_gray: grayscale image, either (1, H, W) with a leading singleton channel
            axis or already 2-D (H, W).
        reconstructed_color: image accepted by ``imshow`` (e.g. the output of ``to_rgb``).

    Raises:
        ValueError: if ``inp_gray`` is 3-D but its first axis is not of size 1.
    """
    gray = np.asarray(inp_gray)
    if gray.ndim == 3:
        gray = np.squeeze(gray, axis=0)   # drop the channel axis -> (H, W)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        ('Original color', color_original, None),
        ('Gray', gray, 'gray'),
        ('Reconstructed color', reconstructed_color, None),
    )
    for ax, (title, image, cmap) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(image, cmap=cmap)

    fig.subplots_adjust(wspace=0.5)
    plt.show()
    # This runs once per epoch; close the figure explicitly so figures cannot pile up
    # in memory on backends that do not close them after show().
    plt.close(fig)

In [None]:
def to_rgb(ab_img, gray_img):
    """Combine a grayscale channel and predicted AB channels into an RGB image.

    Args:
        ab_img: tensor of shape (2, H, W) with the AB channels scaled as ``(ab + 128) / 255``.
        gray_img: tensor of shape (1, H, W) used as the lightness channel after
            multiplication by 100.

    Returns:
        The array returned by ``lab2rgb`` for the reassembled H x W x 3 Lab image.
    """
    # detach()/cpu() let callers pass GPU or grad-tracking tensors directly.
    lab = torch.cat((gray_img, ab_img), 0).detach().cpu().numpy()   # (3, H, W)
    lab = lab.transpose((1, 2, 0)).astype(np.float64)               # channels last, fresh copy
    # Undo the scaling applied when the training targets were built.
    lab[:, :, 0:1] = lab[:, :, 0:1] * 100
    lab[:, :, 1:3] = lab[:, :, 1:3] * 255 - 128
    # No plt.clf() here: a conversion helper should not touch matplotlib's global state
    # (the call used to leave an empty figure in the cell output).
    return lab2rgb(lab)

In [None]:
class Colarization(nn.Module):
    """Encoder/decoder network predicting two color channels from one grayscale channel.

    The encoder is the first six child modules of a ResNet-18 whose first convolution has
    been collapsed to a single input channel. The decoder is a stack of 3x3 convolutions
    with three 2x upsampling steps (8x overall) ending in a 2-channel output.

    ``input_size`` is accepted but not used anywhere in the class.
    """

    def __init__(self, input_size=128):
        super().__init__()

        # --- Encoder -------------------------------------------------------------------
        backbone = models.resnet18(num_classes=365)
        # Sum the RGB kernels of the first conv layer so it accepts a 1-channel input.
        gray_kernels = backbone.conv1.weight.sum(dim=1).unsqueeze(1)
        backbone.conv1.weight = nn.Parameter(gray_kernels)
        # Keep only the early ResNet stages as the mid-level feature extractor.
        encoder_stages = list(backbone.children())[:6]
        self.midlevel_resnet = nn.Sequential(*encoder_stages)

        # --- Decoder -------------------------------------------------------------------
        decoder_layers = [
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2),
        ]
        self.upsample = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode the grayscale batch ``x`` and decode it into a 2-channel prediction."""
        features = self.midlevel_resnet(x)
        return self.upsample(features)

In [None]:
# Build the colorization network, the loss and the optimizer used by the training loop.
autoencoder = Colarization()
# Mean squared error between the predicted and the target AB channels.
criterion = nn.MSELoss()
# Adam over all model parameters; weight decay is disabled (4e-3 was tried previously).
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001, weight_decay=0.0)    #weight_decay=4e-3

In [None]:
# Move the model's parameters to the selected device (GPU if available, else CPU).
# As the last expression of the cell, this also displays the module structure.
autoencoder.to(device)

In [None]:
train_loss = []   # mean training loss of every epoch
valid_loss = []   # mean validation loss of every epoch

# Training the model

epochs = 120

for e in range(epochs):
    print('****************************************')
    print('Starting epoch:',e+1)

    # ---------------- training pass ----------------
    # Set train mode at the start of every epoch, so an interrupted and re-run cell
    # never trains with the model still in eval mode.
    autoencoder.train()
    running_loss = 0
    running_iter = 0
    for color_img, input_gray, input_ab, target in trainloader:
        # Only the gray input and the AB target are needed on the device;
        # the class index (target) is not used for colorization.
        input_gray = input_gray.to(device)
        input_ab = input_ab.to(device)

        optimizer.zero_grad()
        output_ab = autoencoder(input_gray)
        loss = criterion(output_ab, input_ab)
        loss.backward()              # compute gradients
        optimizer.step()             # update the model weights

        running_iter += 1
        running_loss += loss.item()

    # ---------------- validation pass ----------------
    autoencoder.eval()
    test_loss = 0
    test_iter = 0
    with torch.no_grad():
        for color_img, input_gray, input_ab, target in testloader:
            input_gray = input_gray.to(device)
            input_ab = input_ab.to(device)

            output_ab = autoencoder(input_gray)
            loss = criterion(output_ab, input_ab)

            test_iter += 1
            test_loss += loss.item()

    # Report per-batch means: the two loaders have very different numbers of batches, so the
    # raw sums printed before were not comparable. max(..., 1) guards against an empty loader.
    epoch_train_loss = running_loss / max(running_iter, 1)
    epoch_valid_loss = test_loss / max(test_iter, 1)
    train_loss.append(epoch_train_loss)
    valid_loss.append(epoch_valid_loss)

    print("Epoch:",e+1)
    print('Train loss:',epoch_train_loss)
    print('Test loss:',epoch_valid_loss)

    # Visualizing the first image of the last batch in the validation set
    if test_iter > 0:
        inp_gray = input_gray[0].cpu()
        out_ab = output_ab[0].cpu()
        reconstructed_color = to_rgb(out_ab, inp_gray)
        show_img(color_img[0].numpy(), inp_gray.numpy(), reconstructed_color)

    if((e+1)%15 == 0):                    # Saving model every 15 epochs
        # Use the 1-based epoch number in the message and the file name,
        # consistent with the condition above and the epoch printed in the log.
        print('Saving model at epoch:',e+1)
        torch.save(autoencoder.state_dict(), 'Checkpoints/checkpoint_colorize_'+str(e+1)+'.pth')

    # Leave the model in train mode after each epoch, as before.
    autoencoder.train()
    print('****************************************')