In [1]:
from torch import nn
import torch
from torchvision.models import inception_v3
import torch.nn.functional as F

size = (1, 224, 224)

class ImageColorNet(nn.Module):
    def __init__(self, batch_size):
        super(ImageColorNet, self).__init__()

        self.batch_size = batch_size

        # encoder - input
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.conv2 = nn.Conv2d(32, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        # add layers
        
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.pool5 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.conv6 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv9 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool9 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.conv_10 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv_11 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.conv_12 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv_13 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)

        self.model_inceptionv3 = inception_v3(pretrained=True, aux_logits=False)
        self.linear1 = nn.Linear(2048, 128)
        self.linear2 = nn.Linear(128, 1)

        self.layernorm = nn.LayerNorm([28, 28])

        for i, param in self.model_inceptionv3.named_parameters():
            param.requires_grad = False

        num_ftrs = self.model_inceptionv3.fc.in_features
        self.model_inceptionv3.fc = nn.Linear(num_ftrs, num_ftrs*128)

        for name, child in self.model_inceptionv3.named_children():
            for params in child.parameters():
                params.requires_grad = False
                
        self.conv_15 = nn.Conv2d(384, 128, kernel_size=3, stride=1, padding=1)
        self.conv_16 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=1)
        self.conv_17 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv_18 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv_19 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1)
        self.conv_20 = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1)
        self.conv_21 = nn.Conv2d(16, 2, kernel_size=3, stride=1, padding=(1, 1))
        self.conv_22 = nn.ConvTranspose2d(2, 2, kernel_size=3, stride=2, padding=1)

    def forward(self, output, output_inception):
        output = output.view(self.batch_size,1,225,225)
        
        x = F.relu(self.conv1(output))
        x = self.pool1(x)
        conv2 = F.relu(self.conv2(x))
        conv3 = F.relu(self.conv3(conv2))
        conv4 = F.relu(self.conv4(conv3))
        
        layer4 = conv2 + conv4
        
        x = F.relu(self.conv5(layer4))
        x = self.pool5(x)
        conv6 = F.relu(self.conv6(x))
        conv7 = F.relu(self.conv7(conv6))
        conv8 = F.relu(self.conv8(conv7))
        
        layer8 = conv6 + conv8
        
        x = F.relu(self.conv9(layer8))
        x = self.pool9(x)
        conv_10 = F.relu(self.conv_10(x))
        conv_11 = F.relu(self.conv_11(conv_10))
        conv_12 = F.relu(self.conv_12(conv_11))
        
        layer_12 = conv_10 + conv_12
        
        x = F.relu(self.conv_13(layer_12))
        
        inception_output = self.model_inceptionv3.forward(output_inception.view(self.batch_size,3,299,299))
        inception_output = self.linear1(inception_output.view(self.batch_size, 128, 2048))
        inception_output = self.linear2(inception_output.view(self.batch_size, 128, 128))
        
        inception_output = inception_output.view(self.batch_size, 128)
        inception_output = inception_output.repeat(28, 28)
        inception_output = inception_output.view(self.batch_size, 128, 28, 28)
        inception_output = self.layernorm(inception_output)
        
        x = torch.cat([inception_output, x], dim=1)
        
        x = self.conv_15(x)
        x = F.relu(self.conv_16(x))
        x = F.relu(self.conv_17(x))
        x = F.relu(self.conv_18(x))
        x = F.relu(self.conv_19(x))
        x = F.relu(self.conv_20(x))
        x = F.relu(self.conv_21(x))
        x = F.relu(self.conv_22(x))
        
        return x

In [4]:
from pytorch_model_summary import summary

print(summary(ImageColorNet(32), torch.zeros(32,1,225,225), torch.zeros(32,3,299,299), show_input=False, show_hierarchical=False))

------------------------------------------------------------------------------
         Layer (type)            Output Shape         Param #     Tr. Param #
             Conv2d-1      [32, 32, 225, 225]             320             320
          MaxPool2d-2      [32, 32, 112, 112]               0               0
             Conv2d-3     [32, 128, 112, 112]          36,992          36,992
             Conv2d-4     [32, 128, 112, 112]         147,584         147,584
             Conv2d-5     [32, 128, 112, 112]         147,584         147,584
             Conv2d-6     [32, 128, 112, 112]         147,584         147,584
          MaxPool2d-7       [32, 128, 56, 56]               0               0
             Conv2d-8       [32, 256, 56, 56]         295,168         295,168
             Conv2d-9       [32, 256, 56, 56]         590,080         590,080
            Conv2d-10       [32, 256, 56, 56]         590,080         590,080
            Conv2d-11       [32, 256, 56, 56]         590,080  

In [6]:
import torch
from torchvision import datasets, transforms
from glob import glob
import numpy as np
from PIL import Image
import pytorch_colors as colors

batch_size = 32
net = ImageColorNet(32)
net.eval()

transform = transforms.Compose([transforms.Resize((225,225)),
                                transforms.ToTensor()])

inception_transform = transforms.Compose([transforms.Resize((299,299)),
                                transforms.Grayscale(num_output_channels=3),
                                transforms.ToTensor()])

dataset = datasets.ImageFolder("E:\\Drives-Linux-ubuntu-2020\\home\\aswin\\Documents\\Deep-Learning-Projects\\open_images_dataset\\dataset\\test", transform=transform)
inception_dataset = datasets.ImageFolder("E:\\Drives-Linux-ubuntu-2020\\home\\aswin\\Documents\\Deep-Learning-Projects\\open_images_dataset\\dataset\\test", transform=inception_transform)

dataloader = enumerate(torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False))
inception_dataloader = enumerate(torch.utils.data.DataLoader(inception_dataset, batch_size=batch_size, shuffle=False))
i, (images, labels) = next(iter(dataloader))
lab_images = rgb2lab(images)
i, (inception_images, inception_labels) = next(iter(inception_dataloader))
checkpoint = torch.load('checkpoints/img_color-1_170.checkpoint.pth')
with torch.no_grad():
    net.load_state_dict(checkpoint['model_state_dict'])
    o = net.forward(lab_images[:,0,:,:] / 255, inception_images)

In [12]:
new_images = torch.cat([lab_images[:,0,:,:].unsqueeze(1), o], dim=1)
rgb_images = lab2rgb(new_images)*255
import cv2
cv2.imwrite("image.png", rgb_images[0].permute(1,2,0).int().numpy())

True

In [15]:
cv2.imwrite("image.png", rgb_images[3].permute(1,2,0).int().numpy())

True

In [5]:
# https://github.com/richzhang/colorization-pytorch/blob/master/util/util.py

import torch

# Color conversion code
def rgb2xyz(rgb): # rgb from [0,1]

    mask = (rgb > .04045).type(torch.FloatTensor)
    if(rgb.is_cuda):
        mask = mask.cuda()

    rgb = (((rgb+.055)/1.055)**2.4)*mask + rgb/12.92*(1-mask)

    x = .412453*rgb[:,0,:,:]+.357580*rgb[:,1,:,:]+.180423*rgb[:,2,:,:]
    y = .212671*rgb[:,0,:,:]+.715160*rgb[:,1,:,:]+.072169*rgb[:,2,:,:]
    z = .019334*rgb[:,0,:,:]+.119193*rgb[:,1,:,:]+.950227*rgb[:,2,:,:]
    out = torch.cat((x[:,None,:,:],y[:,None,:,:],z[:,None,:,:]),dim=1)

    return out

def xyz2rgb(xyz):

    r = 3.24048134*xyz[:,0,:,:]-1.53715152*xyz[:,1,:,:]-0.49853633*xyz[:,2,:,:]
    g = -0.96925495*xyz[:,0,:,:]+1.87599*xyz[:,1,:,:]+.04155593*xyz[:,2,:,:]
    b = .05564664*xyz[:,0,:,:]-.20404134*xyz[:,1,:,:]+1.05731107*xyz[:,2,:,:]

    rgb = torch.cat((r[:,None,:,:],g[:,None,:,:],b[:,None,:,:]),dim=1)
    rgb = torch.max(rgb,torch.zeros_like(rgb)) # sometimes reaches a small negative number, which causes NaNs

    mask = (rgb > .0031308).type(torch.FloatTensor)
    if(rgb.is_cuda):
        mask = mask.cuda()

    rgb = (1.055*(rgb**(1./2.4)) - 0.055)*mask + 12.92*rgb*(1-mask)

    return rgb

def xyz2lab(xyz):
    # 0.95047, 1., 1.08883 # white
    sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
    if(xyz.is_cuda):
        sc = sc.cuda()

    xyz_scale = xyz/sc

    mask = (xyz_scale > .008856).type(torch.FloatTensor)
    if(xyz_scale.is_cuda):
        mask = mask.cuda()

    xyz_int = xyz_scale**(1/3.)*mask + (7.787*xyz_scale + 16./116.)*(1-mask)

    L = 116.*xyz_int[:,1,:,:]-16.
    a = 500.*(xyz_int[:,0,:,:]-xyz_int[:,1,:,:])
    b = 200.*(xyz_int[:,1,:,:]-xyz_int[:,2,:,:])
    out = torch.cat((L[:,None,:,:],a[:,None,:,:],b[:,None,:,:]),dim=1)

    return out

def lab2xyz(lab):
    y_int = (lab[:,0,:,:]+16.)/116.
    x_int = (lab[:,1,:,:]/500.) + y_int
    z_int = y_int - (lab[:,2,:,:]/200.)
    if(z_int.is_cuda):
        z_int = torch.max(torch.Tensor((0,)).cuda(), z_int)
    else:
        z_int = torch.max(torch.Tensor((0,)), z_int)

    out = torch.cat((x_int[:,None,:,:],y_int[:,None,:,:],z_int[:,None,:,:]),dim=1)
    mask = (out > .2068966).type(torch.FloatTensor)
    if(out.is_cuda):
        mask = mask.cuda()

    out = (out**3.)*mask + (out - 16./116.)/7.787*(1-mask)

    sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
    sc = sc.to(out.device)

    out = out*sc

    return out

def rgb2lab(rgb):
    lab = xyz2lab(rgb2xyz(rgb))
    l_rs = (lab[:,[0],:,:])/100
    ab_rs = (lab[:,1:,:,:]+128)/255
    out = torch.cat((l_rs,ab_rs),dim=1)
    return out

def lab2rgb(lab_rs):
    l = lab_rs[:,[0],:,:]*100
    ab = lab_rs[:,1:,:,:]*255 - 128
    lab = torch.cat((l,ab),dim=1)
    out = xyz2rgb(lab2xyz(lab))
    return out

In [5]:
lab_images[:,1:,:,:]

tensor([[[[ 1.2142e-01,  1.2142e-01,  1.2142e-01,  ...,  2.3181e-01,
            2.5050e-01,  2.2902e-01],
          [ 1.2142e-01,  1.2142e-01,  1.2142e-01,  ...,  2.0750e-01,
            2.3300e-01,  2.2099e-01],
          [ 1.2142e-01,  1.2142e-01,  1.2142e-01,  ...,  1.8631e-01,
            2.0323e-01,  2.2191e-01],
          ...,
          [ 9.0152e-02,  1.0846e-01,  9.0712e-02,  ...,  2.5054e-01,
            2.7492e-01,  2.8065e-01],
          [ 8.0519e-02,  9.6463e-02,  8.8521e-02,  ...,  2.7767e-01,
            2.9780e-01,  3.0633e-01],
          [ 7.9336e-02,  7.5061e-02,  7.5681e-02,  ...,  2.9588e-01,
            3.0299e-01,  3.0178e-01]],

         [[-3.4652e-02, -3.4652e-02, -3.4652e-02,  ...,  2.2017e-01,
            1.5568e-01,  1.0370e-01],
          [-3.4652e-02, -3.4652e-02, -3.4652e-02,  ...,  2.1241e-01,
            2.1569e-01,  1.9296e-01],
          [-3.4652e-02, -3.4652e-02, -3.4652e-02,  ...,  1.1260e-01,
            1.9418e-01,  2.0991e-01],
          ...,
     

In [2]:
import cv2

color_transform = transforms.Compose([transforms.Resize((224,224)),
                                transforms.ToTensor()])
color_dataset = datasets.ImageFolder(
  "E:\\Drives-Linux-ubuntu-2020\\home\\aswin\\Documents\\Deep-Learning-Projects\\open_images_dataset\\dataset\\test", 
  transform=color_transform)
color_dataloader = enumerate(torch.utils.data.DataLoader(color_dataset, batch_size=batch_size, shuffle=False))

In [3]:
i, (color_images, color_labels) = next(iter(color_dataloader))

In [11]:
cv2.imwrite("image.png", cv2.cvtColor(color_images[0].permute(1, 2, 0).numpy()*255, cv2.COLOR_BGR2RGB))

True

In [15]:
net = ImageColorNet(32)
for i, param in net.named_parameters():
    print(i)

conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
conv5.weight
conv5.bias
conv6.weight
conv6.bias
conv7.weight
conv7.bias
conv8.weight
conv8.bias
conv9.weight
conv9.bias
conv_10.weight
conv_10.bias
conv_11.weight
conv_11.bias
conv_12.weight
conv_12.bias
conv_13.weight
conv_13.bias
model_inceptionv3.Conv2d_1a_3x3.conv.weight
model_inceptionv3.Conv2d_1a_3x3.bn.weight
model_inceptionv3.Conv2d_1a_3x3.bn.bias
model_inceptionv3.Conv2d_2a_3x3.conv.weight
model_inceptionv3.Conv2d_2a_3x3.bn.weight
model_inceptionv3.Conv2d_2a_3x3.bn.bias
model_inceptionv3.Conv2d_2b_3x3.conv.weight
model_inceptionv3.Conv2d_2b_3x3.bn.weight
model_inceptionv3.Conv2d_2b_3x3.bn.bias
model_inceptionv3.Conv2d_3b_1x1.conv.weight
model_inceptionv3.Conv2d_3b_1x1.bn.weight
model_inceptionv3.Conv2d_3b_1x1.bn.bias
model_inceptionv3.Conv2d_4a_3x3.conv.weight
model_inceptionv3.Conv2d_4a_3x3.bn.weight
model_inceptionv3.Conv2d_4a_3x3.bn.bias
model_inceptionv3.Mixed_5b.branch1x1.conv.

In [None]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
images, labels = next(iter(dataloader))

In [None]:
list1=[1,2,3,4,5]
list1=iter(list1)
# It will return first value of iterator i.e. '1'
print(next(list1))
# It will return second value of iterator i.e. '2'
print(next(list1))
# It will return third value of iterator i.e. '3'
print(next(list1))
# It will return fourth value of iterator i.e. '4'
print(next(list1))