In [2]:
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import cv2
import torchvision
from torchvision import models
import torch.nn.functional as F

import scipy
import torch
import torchvision.datasets as dset
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
np.random.seed(2021)
torch.manual_seed(2021)
import copy

# **Defining the Dataset**

In [2]:
data_root = "../input/flickrfaceshq-dataset-ffhq"


class FaceData(torch.utils.data.Dataset):
    """FFHQ face images served as ``(L, ab)`` pairs in normalized CIE L*a*b* space.

    Each item is ``(L, ab)``:
      * ``L``  - (1, H, W) float32 tensor, the L* channel mapped to [-1, 1] via ``L / 50 - 1``.
      * ``ab`` - (2, H, W) float32 tensor, the a*/b* channels divided by 100.

    Args:
        transformations: optional callable applied to the resized RGB PIL image
            before colour conversion. It must return something ``np.array`` turns
            into an (H, W, 3) RGB array. ``None`` (default) means no transform.
        root: directory holding the images. Every regular file directly inside
            it is used, in sorted order, so index-based splits are reproducible.
        size: ``(width, height)`` passed to ``PIL.Image.resize``.
    """

    def __init__(self, transformations=None, root=data_root, size=(224, 224)):
        self.root = root
        self.transformations = transformations
        self.size = size
        # Skip sub-directories: they would crash Image.open in __getitem__.
        self.img_list = sorted(
            path for path in glob.glob(os.path.join(root, '*')) if os.path.isfile(path)
        )

    def __len__(self):
        # Use the list built once in __init__. Re-globbing here would ignore
        # `root` and rescan the directory every time the DataLoader asks.
        return len(self.img_list)

    def __getitem__(self, idx):
        # convert("RGB") guards against grayscale / RGBA files, which rgb2lab
        # cannot handle. The context manager closes the underlying file.
        with Image.open(self.img_list[idx]) as img:
            ip_img = img.convert("RGB").resize(self.size)
        if self.transformations is not None:
            ip_img = self.transformations(ip_img)

        ip_img = np.array(ip_img)
        img_lab = rgb2lab(ip_img).astype("float32")  # RGB -> L*a*b*
        img_lab = transforms.ToTensor()(img_lab)  # HWC -> CHW; float input is not rescaled

        L = img_lab[[0]] / 50. - 1.  # L* lies in [0, 100] -> [-1, 1]
        # a*/b* are signed; /100 keeps them near [-1, 1] but does not clip them.
        # The inference code multiplies by 100 to undo this, so keep the divisor in sync.
        ab = img_lab[[1, 2]] / 100.
        return L, ab
        

In [3]:
data_set = FaceData()

# Unshuffled loader, only consumed by the commented-out batch-visualization code below.
sample_loader = DataLoader(data_set, batch_size=32)

# sample_batch=next(iter(sample_loader))
# len(sample_batch)

# type(sample_batch[0]),sample_batch[0].shape,sample_batch[1].shape
# concat_lab=torch.cat((sample_batch[0],sample_batch[1]),dim=1)
# concat_lab.shape
# concat_rgb = lab2rgb(torch.cat(((sample_batch[0]+1)*50.,sample_batch[1]*100.),dim=1).permute(0,2,3,1))

# concat_rgb = torch.tensor(concat_rgb).permute(0,3,1,2)

# grid1 = torchvision.utils.make_grid(sample_batch[0],nrow=16)
# grid2 = torchvision.utils.make_grid(concat_lab,nrow=16)
# grid3 =  torchvision.utils.make_grid(concat_rgb,nrow=16)

# fig, axs = plt.subplots(3,1,figsize=(50,50),sharex=True)#2 rows, 1 column

# axs[0].imshow(np.transpose(grid1,(1,2,0)))
# axs[1].imshow(np.transpose(grid2,(1,2,0)))
# axs[2].imshow(np.transpose(grid3,(1,2,0)))

# plt.subplots_adjust(top=0.5)

In [4]:
# Fixed, index-based split over the first 30,000 images (sorted file order):
# 20k train / 5k validation / 5k test.
N_TRAIN, N_VAL, N_TEST = 20000, 5000, 5000
n_split = N_TRAIN + N_VAL + N_TEST
# Fail fast here instead of with an IndexError inside a DataLoader worker.
assert len(data_set) >= n_split, (
    f'FaceData found {len(data_set)} images under {data_set.root!r}; need at least {n_split}')
indices = list(range(n_split))

train_indices = indices[:N_TRAIN]
val_indices = indices[N_TRAIN:N_TRAIN + N_VAL]
test_indices = indices[N_TRAIN + N_VAL:]

# SubsetRandomSampler draws a fresh random order of its indices on every pass.
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices)
test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_indices)

train_data = torch.utils.data.DataLoader(data_set, batch_size=64, sampler=train_sampler,
                                         num_workers=2)
val_data = torch.utils.data.DataLoader(data_set, batch_size=32, sampler=val_sampler,
                                       num_workers=2)
test_data = torch.utils.data.DataLoader(data_set, batch_size=32, sampler=test_sampler,
                                        num_workers=2)
# len() of a DataLoader is its number of batches, not its number of images.
print(f'Train: {len(train_indices)} images / {len(train_data)} batches; '
      f'Val: {len(val_indices)} images / {len(val_data)} batches; '
      f'Test: {len(test_indices)} images / {len(test_data)} batches')

# **Defining the Model**

In [5]:
class Image_Colorization_Model(nn.Module):
    """U-Net-style colorizer predicting the two a*/b* channels from the L channel.

    Encoder: two 3x3 convs lift the 1-channel input to 3 channels, which feed the
    stem (conv1/bn1/relu/maxpool), layer1 and layer2 of a torchvision ResNet-50.
    Decoder: three nearest-neighbour 2x upsampling stages. Each stage is
    concatenated with the matching encoder activation (skip connection) and
    followed by 3x3 convs, and the decoder ends in a 1x1 conv to 2 channels.

    Input:  (N, 1, H, W). Sizes that are multiples of 8 (e.g. 224) keep the
            upsampled decoder maps the same size as their skip connections.
    Output: (N, 2, H, W).

    Args:
        pretrained: load ImageNet weights for the ResNet-50 backbone (default
            True, as before). Pass False when ``load_state_dict`` will overwrite
            them anyway, to avoid the download.
        final_relu: apply ReLU to the output (default True, as before, so
            checkpoints trained that way reproduce). NOTE: FaceData's ``ab``
            targets are signed, and a ReLU output can never predict negative
            a*/b* (greens/blues). Use ``final_relu=False`` when training anew.
    """

    def __init__(self, pretrained=True, final_relu=True):
        super().__init__()
        # Plain attribute (not a buffer), so state_dict keys are unchanged.
        self.final_relu = final_relu

        # ---- Encoder ----
        # Two convs turn the single L channel into the 3 channels the ResNet stem expects.
        self.conv_preprocess1 = nn.Conv2d(1, 3, kernel_size=3, padding=1)
        self.conv_preprocess2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        model_resnet = models.resnet50(pretrained=pretrained)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2

        # ---- Decoder ----
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        # 768 = 256 (layer1 skip) + 512 (upsampled layer2 output).
        self.conv_decode2_1 = nn.Conv2d(768, 128, kernel_size=3, padding=1)
        self.conv_decode2_2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        # 128 = 64 (stem skip) + 64 (upsampled decoder).
        self.conv_decode1_1 = nn.Conv2d(128, 16, kernel_size=3, padding=1)
        self.conv_decode1_2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        # 11 = 3 (preprocessed-input skip) + 8 (upsampled decoder).
        self.conv_decode0_1 = nn.Conv2d(11, 4, kernel_size=3, padding=1)
        self.conv_decode0_2 = nn.Conv2d(4, 2, kernel_size=1)

    def forward(self, x):
        """Map an (N, 1, H, W) L batch to an (N, 2, H, W) ab prediction."""
        # ---- Encoder ----
        x_pre = F.relu(self.conv_preprocess1(x))
        x_pre = F.relu(self.conv_preprocess2(x_pre))           # (N, 3, H, W)

        encode_1 = self.relu(self.bn1(self.conv1(x_pre)))       # (N, 64, H/2, W/2)
        encode_2 = self.layer1(self.maxpool(encode_1))          # (N, 256, H/4, W/4)
        bottle_neck = self.layer2(encode_2)                      # (N, 512, H/8, W/8)

        # ---- Decoder: upsample, concatenate the skip, two convs per stage ----
        decode_2 = torch.cat((encode_2, self.upsample(bottle_neck)), dim=1)
        decode_2 = F.relu(self.conv_decode2_1(decode_2))
        decode_2 = F.relu(self.conv_decode2_2(decode_2))         # (N, 64, H/4, W/4)

        decode_1 = torch.cat((encode_1, self.upsample(decode_2)), dim=1)
        decode_1 = F.relu(self.conv_decode1_1(decode_1))
        decode_1 = F.relu(self.conv_decode1_2(decode_1))         # (N, 8, H/2, W/2)

        decode_pre = torch.cat((x_pre, self.upsample(decode_1)), dim=1)
        decode_pre = F.relu(self.conv_decode0_1(decode_pre))
        out = self.conv_decode0_2(decode_pre)                    # (N, 2, H, W)

        return F.relu(out) if self.final_relu else out

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-3

# Move to `device` rather than calling .cuda(), so the notebook also runs on CPU-only machines.
model = Image_Colorization_Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
L1loss = nn.L1Loss()

# Checkpoint directory used by the (commented-out) training loop.
# exist_ok=True keeps the cell re-runnable.
os.makedirs('./weights', exist_ok=True)

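# **Training the Colorization Model**
The training loop below is commented out; the next section loads the trained weights from a saved epoch-70 checkpoint instead.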

In [7]:
# os.mkdir('./weights')
# def plot_loss_L1(train_list):
#     plt.figure(figsize=(20,20))
#     plt.plot(train_list,label='Train loss L1')
#     plt.xlabel('Iteration')
#     plt.ylabel('L1 Loss')
#     plt.legend(loc='upper right')
#     plt.savefig(os.path.join('./', 'Loss_graph_L1'))
#     plt.close()

# def plot_loss_mse(train_list):
#     plt.figure(figsize=(20,20))
#     plt.plot(train_list,label='Train loss MSE')
#     plt.xlabel('Iteration')
#     plt.ylabel('MSE Loss')
#     plt.legend(loc='upper right')
#     plt.savefig(os.path.join(os.path.join('./', 'Loss_graph_mse')))
#     plt.close()

# model.train()

# train_loss_l1 = []
# train_loss_mse = []

# for i in range(70):#Total=70 epoch
# #     b=0
#     for data in train_data:
# #         print(len(data))
#         gray = data[0].to(device)  
# #         plt.imshow(gray)
#         true_ab_space = data[1].to(device)
#         optimizer.zero_grad()
# #         print(f'x:{gray.shape}')
# #         print(f'y:{true_ab_space.shape}')
              
#         predicted_ab_space = model(gray)
# #         print(predicted_ab_space.shape)
# #         print(true_ab_space.shape)
#         loss = L1loss(predicted_ab_space, true_ab_space)
    
#         train_loss_l1.append(loss.item())
#         train_loss_mse.append((F.mse_loss(predicted_ab_space, true_ab_space)).item())
# #         loss = F.mse_loss(predicted_ab_space, true_ab_space)
        
#         loss.backward()
# #         print(loss)
# #         b+=1
# #         print(b)
#         optimizer.step()
        
# #         plot_loss_L1(train_loss_l1)
# #         plot_loss_mse(train_loss_mse)
        
#     if (i+1) % 10 == 0:
#         print('%d iterations' % (i+1))
#         print('L1_Loss %.3f' % np.mean(train_loss_l1[-100:]))
#         print('MSE_Loss: %.3f' % np.mean(train_loss_mse[-100:]))
#         print()

#     if (i+1)%5 == 0:
#         torch.save({'epoch':i,
#                     'model_state_dict':model.state_dict(),
#                     'optimizer_state_dict':optimizer.state_dict(),
#                     'loss':{'L1_Loss':train_loss_l1.copy(),'MSE_Loss':train_loss_mse.copy()},
#                    },os.path.join('./weights',f'_epoch_{i+1}'))

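# **Evaluating the Trained Model**
Loads the epoch-70 checkpoint and colorizes one random batch from the validation split; the test split is not used in this section.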

In [8]:
# torch.cuda.memory_allocated(), torch.cuda.current_device()

In [9]:
# NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load files from a trusted source.
chkpt = torch.load("../input/color-model-epoch-70/_epoch_70 (1)", map_location=torch.device('cpu'))
model.load_state_dict(chkpt['model_state_dict'])  # copies values into the model's existing parameters
model.eval()


def lab_batch_to_rgb(L_nhwc, ab_nhwc):
    """Convert denormalized L*a*b* channels into an RGB batch for display.

    Args:
        L_nhwc: (N, H, W, 1) tensor, L* on its natural [0, 100] scale.
        ab_nhwc: (N, H, W, 2) tensor, a*/b* on their natural (un-divided) scale.

    Returns:
        (N, 3, H, W) CPU tensor holding skimage's ``lab2rgb`` output.
    """
    lab = torch.cat((L_nhwc, ab_nhwc), dim=3).detach().cpu().numpy()
    return torch.from_numpy(lab2rgb(lab)).permute(0, 3, 1, 2)


with torch.no_grad():
    # One batch from the validation loader; its SubsetRandomSampler picks a random batch.
    gray_val, true_ab_val = next(iter(val_data))
    # CPU references to the normalized inputs; .to(device) below does not modify them.
    gray_val_orig, true_ab_val_orig = gray_val, true_ab_val
    gray_val = gray_val.to(device)
    pred_ab = model(gray_val)

# Undo FaceData's normalization (L = (L_norm + 1) * 50, ab = ab_norm * 100) and
# switch NCHW -> NHWC for lab2rgb. `gray_val` is kept on `device` in this
# rescaled form because the grayscale display cell below reads it.
pred_ab = (pred_ab * 100.).permute(0, 2, 3, 1).contiguous()
gray_val = ((gray_val + 1) * 50.).permute(0, 2, 3, 1).contiguous()
ab_rgb = (true_ab_val_orig * 100.).permute(0, 2, 3, 1).contiguous()
gray_rgb = ((gray_val_orig + 1) * 50.).permute(0, 2, 3, 1).contiguous()

true_rgb = lab_batch_to_rgb(gray_rgb, ab_rgb)
pred_rgb = lab_batch_to_rgb(gray_val, pred_ab)

In [10]:
from torchvision.utils import make_grid


def show(img, save_path="./output"):
    """Display a (C, H, W) image, e.g. a ``make_grid`` result, and save the figure.

    Args:
        img: torch tensor or array-like shaped (C, H, W). Tensors are detached
            and moved to the CPU first, so GPU / autograd tensors are accepted.
        save_path: path passed to ``plt.savefig``. The default "./output" is
            overwritten by every call; pass a distinct path to keep each figure.
    """
    if torch.is_tensor(img):
        img = img.detach().cpu().numpy()
    hwc = np.transpose(np.asarray(img), (1, 2, 0))  # CHW -> HWC for imshow
    _, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(hwc)
    plt.savefig(save_path)


show(make_grid(pred_rgb.cpu().data))
# show(pred_rgb[0])

In [11]:
show(make_grid(gray_val.div(100.).permute(0, 3, 1, 2).data.cpu()))  # L* in [0, 100] -> [0, 1], NHWC -> NCHW

In [12]:
show(make_grid(true_rgb.data.cpu()))  # ground-truth colours of the same validation batch

In [13]:
pip install torch-summary

In [14]:
from torchsummary import summary

input_size = (1, 224, 224)  # one L channel at FaceData's 224x224 resolution
summary(model, input_size)

In [15]:
# L1 loss recorded once per training batch (see the commented-out training loop).
loss = chkpt['loss']['L1_Loss']
print(len(loss))
# A single axes that fills the figure (a 1x2 grid with one panel left half of it blank).
f, ax = plt.subplots(figsize=(10, 8))
ax.plot(loss)
# ax.set_yscale('log')
ax.set_title('L1 Loss')
ax.set_xlabel('Iteration')
ax.set_ylabel('L1 loss (normalized ab)')
f.savefig('./loss')

In [16]:
# MSE recorded once per training batch for monitoring; the commented-out loop optimizes L1.
loss = chkpt['loss']['MSE_Loss']
f = plt.figure(figsize=(20,8))
ax = f.add_subplot(1,2,1)
ax.plot(loss)
# ax.set_yscale('log')
ax.set_title('MSE Loss')
ax.set_xlabel('Iteration')
ax.set_ylabel('MSE (normalized ab)')