# [INFO8010] Colorful image colorization of cat images

##### This program trains a CNN to recolour greyscale images with vibrant and realistic colours.

Required packages:

- Pytorch
- scikit-image
- numpy
- matplotlib
- PIL
- scipy

In [None]:
import os
import os.path
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from PIL import Image
from scipy.ndimage import gaussian_filter
from skimage import color

## Data loading
The image folder dataset class is subclassed to override `__getitem__` so that each sample is returned together with its dataset index. With this modification, images become uniquely identifiable.

In [None]:
#Data loading
from torchvision import datasets, transforms, utils
from torchvision.datasets import folder, ImageFolder

# Preprocessing applied to every image: resize, centre-crop to 224x224,
# then convert the PIL image to a tensor.
_preprocessing_steps = [
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.PILToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

class CustomImageDataset(ImageFolder):
    """
        ImageFolder variant whose items are (image, index) pairs instead of
        (image, class) pairs, so that every image can be identified uniquely.
        The dataset length is inherited unchanged from ImageFolder.
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def __getitem__(self, index: int):
        """
            Gets an image from the dataset and its index
        Inputs :
            index (int): Index

        Outputs :
            tuple: (sample, index) sample is the image and index is the unique identifier of the image in the dataset
        """
        # The class label stored next to the path is not needed: the index is
        # returned in its place.
        path, _ = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, index

# Build the training dataset and a shuffled loader over it.
batch_size = 32
dataset = CustomImageDataset("cat_dataset/training_set", transform=transform)
loader = data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

# Fetch one batch and preview its first image.
images, indexes = next(iter(loader))
img = images[0]
preview = transforms.functional.to_pil_image(img)
plt.imshow(preview)
plt.show()

# NumPy version of the previewed image, used below to try the colour helpers.
img2 = img.numpy()


#### Functions to convert between RGB and LAB colour spaces.

In [None]:
def rgb2lab(img):
    '''
        Converts a channels-first RGB image to the LAB colour space and
        returns it in channels-last layout.
            Inputs:
                - img : RGB image of shape (3,H,W)
            Outputs :
                - LAB image of shape (H,W,3)
    '''
    # The conversion keeps the channel axis in front: result is (3,H,W).
    channels_first_lab = color.rgb2lab(img, channel_axis=0)

    # Move the channel axis to the end: (3,H,W) -> (H,W,3).
    return channels_first_lab.transpose(1, 2, 0)

def separateChannels(imageLAB):
    '''
        Splits a channels-last LAB image into its three channel planes.
            Inputs:
                - imageLAB : LAB image with format (H,W,3)
            Outputs :
                - l, a, b = 3 matrices of shape (H,W)
    '''
    return tuple(imageLAB[:, :, channel] for channel in range(3))

# Split the previewed image into its L, a and b planes.
(l, a, b) = separateChannels(rgb2lab(img2))

### Computing colour statistics and discretisation

Compute statistics about the probabilities of occurrence of (a,b) pairs in the cat dataset and display the associated distribution. Select the in-gamut (a,b) pairs from the distribution. Compute a weight for each in-gamut (a,b) pair (for class rebalancing).

##### Auxiliary functions

Define a set of functions that are useful for working with the matrix that represents the 2D distribution of (a,b) pairs.

In [None]:

def getDiscretisedColor(a,b,gridSize):
    '''
        Snaps an (a,b) value of the lab colourspace to the nearest multiple
        of gridSize on each axis.
            Inputs:
                - a : real number between [-110, 110]
                - b : real number between [-110, 110]
                - gridSize : integer that describes the span of a discretized box
            Outputs :
                - (a,b) : tuple of values in range [-110, 110]
    '''
    snapped = tuple(np.round(value / gridSize) * gridSize for value in (a, b))
    return snapped

def getMatrixIndex(a,b,gridSize):
    '''
        Maps an (a,b) value of the lab colourspace to the (row, column) cell
        of the distribution matrix. The values are shifted by 500 (a) and
        200 (b) before being divided by gridSize and truncated.
            Inputs:
                - a : real number between [-110, 110]
                - b : real number between [-110, 110]
                - gridSize : integer that describes the span of a discretized box
            Outputs :
                - i,j : tuple of positive integers
    '''
    row = int((a + 500) / gridSize)
    column = int((b + 200) / gridSize)
    return (row, column)

def getColorValue(i,j,gridSize):
    '''
        Inverse of the matrix indexing: converts a (row, column) cell of the
        distribution matrix back to the discretised (a,b) colour value.
            Inputs:
                - i : positive integer
                - j : positive integer
                - gridSize : integer that describes the span of a discretized box
            Outputs :
                - (a,b) : tuple of values in range [-110, 110]
    '''
    return (i * gridSize - 500, j * gridSize - 200)

##### Probability distribution of (a,b) pairs

In [None]:

# Initialise the proba distribution of ab pairs in the images dataset (discretised).
gridSize = 10
colorProbabilities = np.zeros((1000 // gridSize, 400 // gridSize))

# Compute the proba distribution of the ab pairs in the images dataset (discretised).
nbOfAnalysedPixels = 0
# NOTE: the batch is unpacked directly instead of being stored in a variable
# called `data`, which would overwrite the `torch.utils.data` alias used to
# build the data loaders.
for images, _ in loader:
    for image in images:
        # Keep only the a and b channels: (H,W,3) -> (H,W,2).
        ab_channels = rgb2lab(image.numpy())[:, :, 1:3]
        for h in range(ab_channels.shape[0]):
            for w in range(ab_channels.shape[1]):
                (a, b) = getDiscretisedColor(ab_channels[h, w, 0], ab_channels[h, w, 1], gridSize)
                (i, j) = getMatrixIndex(a, b, gridSize)
                colorProbabilities[i, j] += 1
        nbOfAnalysedPixels += ab_channels.shape[0] * ab_channels.shape[1]
colorProbabilities = colorProbabilities / nbOfAnalysedPixels

# Smooth the proba distribution of the ab pairs in the images dataset.
sigma = 3 # gaussian kernel parameter
colorProbabilities_smooth = gaussian_filter(colorProbabilities, sigma=sigma)
treshold = 0.001 # add treshold for in-gamut color selection


In [None]:

# Display the raw and then the smoothed (a,b) distribution as 2d plots,
# both zoomed on the same window of the matrix.
for distribution in (colorProbabilities, colorProbabilities_smooth):
    plt.imshow(distribution, interpolation='none')
    plt.xlim(10, 30) # need to be dynamic...
    plt.ylim(40, 70)
    plt.show()


##### Select in-gamut colours from (a,b) pairs

In [None]:

# Collect the (a,b) pairs whose smoothed probability exceeds the treshold
# ("in gamut" colours), their probabilities, and a colour -> position lookup.
inGamutColors = []
inGamutColorsProbas = []
inGamutIndex = {}
for (i, j), probability in np.ndenumerate(colorProbabilities_smooth):
    if not (probability > treshold):
        continue
    colour = getColorValue(i, j, gridSize)
    # The position of a colour is the number of colours kept before it.
    inGamutIndex[colour] = len(inGamutColors)
    inGamutColors.append(colour)
    inGamutColorsProbas.append(probability)
currentColorIndex = len(inGamutColors)

Q = len(inGamutColors)
p_smooth = torch.tensor(inGamutColorsProbas)
print(Q)
print(inGamutColors)

##### Class rebalancing

In [None]:

# Mixing coefficient between the empirical and the uniform distribution
# (from paper, need empirical value).
lambda_uniform = 1/2

# Weight of each colour: inverse of the mix of p_smooth and the uniform distribution.
mixedDistribution = (1 - lambda_uniform) * p_smooth + lambda_uniform / Q
pixelsWeights = torch.reciprocal(mixedDistribution)

# Rescale the weights so that their expectation under p_smooth is 1 (E[W] = 1).
E_W = torch.sum(p_smooth * pixelsWeights)
scale_factor = 1 / E_W
pixelsWeights = scale_factor * pixelsWeights

# Print the weight vector.
print(pixelsWeights)

## Define and instantiate CNN and forward pass
#### The CNN is consistent with the description of the reference paper.

In [None]:
class ColorizationCNN(nn.Module):
    '''
        Fully convolutional network mapping a luminosity image to a
        probability distribution over discretised colour bins for every
        pixel of a 4x smaller grid.

        Blocks 1 to 3 each halve the spatial resolution (stride-2
        convolution), blocks 4 to 7 keep it, and block 8 doubles it with a
        transposed convolution, hence the overall factor of 4.
    '''

    def __init__(self, nb_colour_bins = 313):
        '''
            Initialization function to define the neural network depending on the
            number of discretised colour bins.

                Inputs:
                    - nb_colour_bins : positive integer describing the number of colours in gamut.
                Outputs :
                    - ColorizationCNN  object
        '''
        super().__init__()

        # Normalisation constants. Only l_cent and l_norm are used by forward().
        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.

        channels_block_1 = 64
        channels_block_2 = 128
        channels_block_3 = 256
        channels_block_4 = 512
        channels_block_5 = 512 #dilated
        channels_block_6 = 512 #dilated
        channels_block_7 = 512
        channels_block_8 = 128 # transpose convolution necessary

        # NOTE: the order and position of the layers inside each Sequential
        # define the state_dict keys; keep them unchanged so that saved
        # checkpoints can still be loaded.

        # first conv block : 2 convs. from luminosity image to 64 features map from 3x3 kernels. 50% downsampling and normalization at the end.
        self.convBlock1 = nn.Sequential(
            nn.Conv2d(1, channels_block_1, (3, 3), padding=1),
            nn.ReLU(True), #inplace for memory efficiency can be used as no skip connections are used.
            nn.Conv2d(channels_block_1, channels_block_1, (3, 3), padding=1, stride=2), #50% downsampling achieved with a 2 stride.
            nn.ReLU(True),
            nn.BatchNorm2d(channels_block_1), #normalization over the 64 channels created
        )

        # second conv block. 2 convs. from 64 to 128 features map from 3x3 kernels. 50% downsampling and normalization at the end.
        self.convBlock2 = nn.Sequential(
            nn.Conv2d(channels_block_1, channels_block_2, (3, 3), padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_2, channels_block_2, (3, 3), padding=1, stride=2), #50% downsampling achieved with a 2 stride.
            nn.ReLU(True),
            nn.BatchNorm2d(channels_block_2),
        )

        # third conv block. 3 convs. from 128 to 256 features map from 3x3 kernels. 50% downsampling and normalization at the end.
        self.convBlock3 = nn.Sequential(
            nn.Conv2d(channels_block_2, channels_block_3, (3, 3), padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_3, channels_block_3, (3, 3), padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_3, channels_block_3, (3, 3), padding=1, stride=2), #50% downsampling achieved with a 2 stride.
            nn.ReLU(True),
            nn.BatchNorm2d(channels_block_3),
        )

        # fourth conv block. 3 convs. from 256 to 512 features map from 3x3 kernels. No downsampling (all strides are 1). normalization at the end.
        self.convBlock4 = nn.Sequential(
            nn.Conv2d(channels_block_3, channels_block_4, (3, 3), padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_4, channels_block_4, (3, 3), padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_4, channels_block_4, (3, 3), padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(channels_block_4),
        )

        # fifth conv block. 3 convs. no change in nb feature maps. 3x3 kernels with 2 dilation and 2 padding to not downscale. normalization at the end.
        self.convBlock5 = nn.Sequential(
            nn.Conv2d(channels_block_4, channels_block_5, (3, 3), dilation=2, padding=2),
            nn.ReLU(True),
            nn.Conv2d(channels_block_5, channels_block_5, (3, 3), dilation=2, padding=2),
            nn.ReLU(True),
            nn.Conv2d(channels_block_5, channels_block_5, (3, 3), dilation=2, padding=2),
            nn.ReLU(True),
            nn.BatchNorm2d(channels_block_5),
        )

        # sixth conv block. same as 5
        self.convBlock6 = nn.Sequential(
            nn.Conv2d(channels_block_5, channels_block_6, (3, 3), dilation=2, padding=2),
            nn.ReLU(True),
            nn.Conv2d(channels_block_6, channels_block_6, (3, 3), dilation=2, padding=2),
            nn.ReLU(True),
            nn.Conv2d(channels_block_6, channels_block_6, (3, 3), dilation=2, padding=2),
            nn.ReLU(True),
            nn.BatchNorm2d(channels_block_6),
        )

        # seventh conv block : 3 convs with 3x3 kernels, no change in resolution. normalization at the end.
        self.convBlock7 = nn.Sequential(
            nn.Conv2d(channels_block_6, channels_block_7, (3, 3), padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_7, channels_block_7, (3, 3), padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_7, channels_block_7, (3, 3), padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(channels_block_7),
        )

        # eighth conv block : 1 transposed conv (4x4 kernel, stride 2) that doubles the resolution, then 2 convs with 3x3 kernels.
        # Final convolution with 1x1 for classification into a colour bin.
        self.convBlock8 = nn.Sequential(
            nn.ConvTranspose2d(channels_block_7, channels_block_8, (4, 4), stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_8, channels_block_8, (3, 3), padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_8, channels_block_8, (3, 3), padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels_block_8, nb_colour_bins, kernel_size=1), #1x1 kernel for classification in each colour bin (value will be soft maxed for probability)
        )

        # 1x1 kernel to get 2 channel values of a and b respectively.
        # Not used by forward(); kept because it is part of the state_dict.
        self.outputLayer = nn.Conv2d(nb_colour_bins, 2, kernel_size=1, dilation=1, bias=False)

    def forward(self, luminosity_image):
        '''
            Forward pass of the convolutional neural network.
                Input:
                    - Tensor of size (batch_size, 1, H, W)
                Output :
                    - Tensor of size (batch_size, Q, H/4, W/4) holding, for
                      every output pixel, a probability distribution over the
                      Q colour bins (sums to 1 along dimension 1).
        '''
        # Centre the luminosity on l_cent and scale it by l_norm.
        normalised_luminosity = (luminosity_image - self.l_cent) / self.l_norm

        h1 = self.convBlock1(normalised_luminosity)
        h2 = self.convBlock2(h1)
        h3 = self.convBlock3(h2)
        h4 = self.convBlock4(h3)
        h5 = self.convBlock5(h4)
        h6 = self.convBlock6(h5)
        h7 = self.convBlock7(h6)
        h8 = self.convBlock8(h7)

        # Softmax over the colour-bin dimension.
        colour_bin_proba = torch.softmax(h8, dim=1)

        return colour_bin_proba
        
        

## Define Loss

### Soft encoding

In [None]:

# Define a distance measure between two colors (a,b).
def colorDistance(c1, c2):
    '''
        Euclidean distance between two (a,b) colours.
            Inputs:
                - c1, c2 : indexable pairs (a,b)
            Outputs :
                - distance between c1 and c2
    '''
    delta_a = c1[0] - c2[0]
    delta_b = c1[1] - c2[1]
    return np.sqrt(delta_a**2 + delta_b**2)


# Define a Gaussian kernel.
def gaussianKernel(distances, sigma=5):
    '''
        Unnormalised Gaussian kernel exp(-d^2 / (2 * sigma^2)).
            Inputs:
                - distances : distance or numpy array of distances
                - sigma : standard deviation of the kernel (default 5)
            Outputs :
                - weights with the same shape as distances, in (0, 1],
                  equal to 1 for a distance of 0
    '''
    weights = np.exp(-(distances**2) / (2 * sigma**2))
    return weights


def getColorDistribution(Y, nbOfNeighbors, colors=None, kernel=None):
    '''
        Convert a true batch of images Y[BxHxWx2] to pixels color distributions Z[BxHxWxQ] with the soft encoding scheme.
        For every pixel, the nbOfNeighbors nearest colour bins receive a
        weight given by the kernel of their distance to the true colour; the
        weights of a pixel are normalised to sum to 1 and all other bins stay at 0.
            Input :
                Y = true batch of images Y[BxHxWx2]
                nbOfNeighbors = the number of Neighbors that we keep for the soft encoding.
                colors = optional list of (a,b) colour bins; defaults to the
                         module-level inGamutColors.
                kernel = optional function mapping an array of distances to
                         an array of weights; defaults to gaussianKernel.
            Output :
                Z = pixels color distributions Z[BxHxWxQ] (for the whole batch)

    '''
    if colors is None:
        colors = inGamutColors
    if kernel is None:
        kernel = gaussianKernel

    gamut = np.asarray(colors, dtype=np.float64).reshape(-1, 2)   # (Q, 2)
    nb_colors = gamut.shape[0]

    Y_np = Y.detach().cpu().numpy() if torch.is_tensor(Y) else np.asarray(Y)
    Y_np = Y_np.astype(np.float64, copy=False)
    batch, height, width = Y_np.shape[0], Y_np.shape[1], Y_np.shape[2]

    # Initiate a tensor to store the distributions produced from Y.
    Z = torch.zeros(batch, height, width, nb_colors)

    nb_kept = max(0, min(nbOfNeighbors, nb_colors))
    if nb_kept == 0:
        return Z

    # Produce a color distribution for each pixel, one image row at a time
    # (vectorised over the pixels of the row and over the colour bins).
    for i in range(batch):
        for h in range(height):
            row = Y_np[i, h]   # (W, 2)
            distances = np.hypot(row[:, None, 0] - gamut[None, :, 0],
                                 row[:, None, 1] - gamut[None, :, 1])   # (W, Q)

            # Ascending sort: the first columns are the NEAREST bins.
            nearest = np.argsort(distances, axis=1, kind="stable")[:, :nb_kept]
            nearest_distances = np.take_along_axis(distances, nearest, axis=1)
            weights = np.array(kernel(nearest_distances), dtype=np.float64)

            # Normalise so that each pixel holds a probability distribution.
            totals = weights.sum(axis=1, keepdims=True)
            degenerate = totals[:, 0] <= 0
            if degenerate.any():
                # All kernel weights underflowed to 0: put the whole mass on
                # the nearest bin instead of dividing by zero.
                weights[degenerate] = 0.0
                weights[degenerate, 0] = 1.0
                totals[degenerate] = 1.0

            encoded = np.zeros((width, nb_colors))
            np.put_along_axis(encoded, nearest, weights / totals, axis=1)
            Z[i, h] = torch.from_numpy(encoded).to(Z.dtype)

    # Return the produced distributions.
    return Z
    

def getColorDistribution_1hot(Y):
    '''
        Convert a true batch of images Y[BxHxWx2] to pixels color distributions Z[BxHxWxQ] with the 1-hot encoding scheme.
        A pixel whose discretised colour is not an in-gamut colour gets the
        uniform distribution 1/Q instead of a 1-hot vector.
            Input :
                Y = true batch of images Y[BxHxWx2]
            Output :
                Z = pixels color distributions Z[BxHxWxQ] (for the whole batch)

    '''
    batch, height, width = Y.shape[0], Y.shape[1], Y.shape[2]

    # Tensor receiving the distribution of every pixel.
    Z = torch.zeros(batch, height, width, Q)

    for i in range(batch):
        for h in range(height):
            for w in range(width):
                pixel_ab = Y[i, h, w]
                a, b = getDiscretisedColor(pixel_ab[0], pixel_ab[1], gridSize)
                bin_key = (int(a), int(b))

                if bin_key not in inGamutIndex:
                    # Not always in gamut (possibly due to the pooling of Y):
                    # fall back to a uniform distribution.
                    Z[i, h, w, :] = Z[i, h, w, :] + 1/Q
                    continue

                Z[i, h, w, inGamutIndex[bin_key]] = 1

    return Z


### Point estimate

In [None]:

def getPictureEstimate(Z, T, colors=None):
    '''
        Convert the pixel color distributions in Z[HxWxQ] to true picture estimate Y[HxWx2] (point estimate)
        using the annealed mean: each distribution is re-adjusted with
        temperature T, exp(log(z) / T) / sum(exp(log(z) / T)), and the
        estimate is the expectation of the colour bins under it.
            Input :
                Z = tensor representation of a picture as pixel color probability distributions [HxWxQ]
                T = scalar representing a temperature (must be > 0; T = 1
                    gives the mean of the distribution, smaller values move
                    the estimate towards its mode)
                colors = optional list of the Q (a,b) colour bins; defaults
                         to the module-level inGamutColors.

            Output :
                true picture estimate Y[HxWx2] where the channels are a and b
                (same dtype and device as Z)

    '''
    if colors is None:
        colors = inGamutColors
    gamut = torch.as_tensor(colors, dtype=Z.dtype, device=Z.device).reshape(-1, 2)   # (Q, 2)

    # Re-adjust the temperature of every distribution. The softmax of
    # log(z)/T is exactly exp(log(z)/T) normalised to sum to 1; bins with a
    # probability of 0 (log = -inf) keep a probability of 0.
    reajustedDistributions = torch.softmax(torch.log(Z) / T, dim=-1)

    # Annealed mean: expectation of (a,b) for every pixel, (H,W,Q) @ (Q,2) -> (H,W,2).
    return reajustedDistributions @ gamut


### Loss function

In [None]:
# Take Z[BxHxWxQ] as input where B is the batch size

#v(Z_hw) weight in paper (section 2)

def getPixelsWeights(Z_batch, weights=None):
    '''
        Weight of every pixel: the class-rebalancing weight of its most
        probable colour bin.
            Input :
                Z_batch = pixels color distributions Z[BxHxWxQ]
                weights = optional 1-D tensor of Q weights (one per colour
                          bin); defaults to the module-level pixelsWeights.
            Output :
                W = floating point tensor [BxHxW] of pixel weights, on the
                    same device as Z_batch
    '''
    if weights is None:
        weights = pixelsWeights

    # Index of the most probable bin of each pixel (integer tensor).
    dominant_bin = torch.argmax(Z_batch, dim=3)

    # Look the weights up in one indexing operation. The result keeps the
    # floating point dtype of the weights; writing them back into the integer
    # argmax tensor would truncate them.
    return weights.to(dominant_bin.device)[dominant_bin]

In [None]:
# Take Z[BxHxWxQ] as input where B is the batch size

#loss function 

def multinomialCrossEntropyLoss(Z_estimate_batch, Z_batch, weights=None):
    '''
        Weighted multinomial cross entropy between predicted and true pixel
        colour distributions, summed over all pixels of the batch.
            Input :
                Z_estimate_batch = predicted distributions [BxHxWxQ]
                Z_batch = true distributions [BxHxWxQ]
                weights = optional per-pixel weights [BxHxW]; when omitted
                          they are computed with getPixelsWeights(Z_batch).
            Output :
                L = scalar tensor, - sum_{b,h,w} W * sum_q Z * log10(Z_estimate)
    '''
    W = getPixelsWeights(Z_batch) if weights is None else weights

    # A tiny epsilon keeps the logarithm finite where the prediction is exactly 0.
    # NOTE: the base-10 logarithm only rescales the loss by a constant factor
    # compared to the natural logarithm.
    log_estimate = torch.log10(Z_estimate_batch + sys.float_info.epsilon)

    # Cross entropy term of each pixel, summed over the colour bins.
    per_pixel = torch.sum(Z_batch * log_estimate, dim=3)

    L = - torch.sum(W * per_pixel)

    return L

#Loss = multinomialCrossEntropyLoss(torch.rand(4,224,224,Q), torch.rand(4,224,224,Q)) # test
#print(Loss)

## Main training loop and additional function TBD

### Soft encoding (pre computation of the distributions in a dictionary)

In [None]:
# Dictionary: image index in the dataset -> colour distributions of the image
# as a numpy array (quarter resolution, Q bins per pixel).
soft_encoded_images = {}

# Average pooling that reduces an image to a quarter of its size; built once
# and reused for every batch.
downsampler_to_quarter_size = nn.AvgPool2d(4, stride=4)

# NOTE: the batch is unpacked directly instead of being stored in a variable
# called `data`, which would overwrite the `torch.utils.data` alias used to
# build the data loaders.
for images, indexes in loader:
    # Convert the batch to LAB: tensor of shape (B,H,W,3).
    lab_images = torch.tensor(np.array([rgb2lab(image.numpy()) for image in images]))

    # The pooling layer expects channels first: (B,H,W,3) -> (B,3,H,W) and back.
    downsampled_lab = downsampler_to_quarter_size(torch.permute(lab_images, (0, 3, 1, 2)))
    downsampled_lab = torch.permute(downsampled_lab, (0, 2, 3, 1))

    # Encode the a and b channels only.
    imagesDistributions = getColorDistribution_1hot(downsampled_lab[:, :, :, 1:3])

    for position, image_index in enumerate(indexes):
        soft_encoded_images[image_index.item()] = imagesDistributions[position].numpy()

In [None]:

# Save the pre-computed distributions; the context manager closes the file
# even if pickling fails.
with open("soft_encoded_images_1hot_cat.pkl", "wb") as a_file:
    pickle.dump(soft_encoded_images, a_file)

#a_file = open("soft_encoded_images_1hot_cat.pkl", "rb")
#soft_encoded_images = pickle.load(a_file)
#print(output)
#a_file.close()


Define a training step: for each image in a batch : prediction, loss, backprop loss

In [None]:
def training_step(optim,loss_fct,network,dataloader):
    '''
        Function to execute a single training epoch over the dataset
            Inputs :
                optim      : pytorch optimize used for the training step
                loss_fct   : loss function used to compare the predicted output to ground truth
                network    : pytorch neural network to be trained
                dataloader : pytorch data loader to provide the training data used in the training step
                             (yields (images, indexes) batches; its dataset must have a length)

            Outputs :
                average loss over the epoch
    '''
    # Run on whatever device the network lives on.
    device = next(network.parameters()).device
    dataset_size = len(dataloader.dataset)

    iteration_losses = []
    analysed_images = 0

    for images, indexes in dataloader:
        lab_images_distributions = []
        l_images = []

        #convert RGB images to lab. Separate channels and retrieve soft encodings of images.
        for image, image_index in zip(images, indexes):
            lab_im = rgb2lab(image.numpy())
            lab_images_distributions.append(soft_encoded_images[image_index.item()]) # get pre-computed distributions (soft encoding)
            l_images.append(lab_im[:,:,0])

        lab_images_distributions = torch.tensor(np.array(lab_images_distributions)).to(device)

        # (B,H,W) -> (B,1,H,W) float tensor on the network device.
        l_images = torch.tensor(np.array(l_images))
        l_images = torch.unsqueeze(l_images,dim=1).float().to(device)

        #forward pass of network, yielding (batch_size, Q, H/4, W/4) tensor
        predicted_colour_probability = network(l_images)
        # NOTE(review): (0,3,2,1) moves Q last but also swaps the two spatial
        # axes, giving (B,W/4,H/4,Q) while the targets are (B,H/4,W/4,Q).
        # show_col_vs_truth swaps the axes back and existing checkpoints were
        # trained this way, so the permutation is left unchanged here.
        predicted_colour_probability = torch.permute(predicted_colour_probability,(0,3,2,1))

        # compute losses
        loss = loss_fct(predicted_colour_probability, lab_images_distributions)

        # save losses
        iteration_losses.append(loss.item())

        # Back propagate
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Count the images actually seen (the last batch may be smaller).
        analysed_images += images.shape[0]
        print(f"fraction of the dataset analysed: {round(analysed_images/dataset_size * 100, 2)}%")

    return sum(iteration_losses) / len(iteration_losses) # average loss of training iteration

In [None]:
#colCNN = ColorizationCNN(nb_colour_bins = Q).to("cuda")

# to retrain a trained model
PATH = os.getcwd() + "/network_cat_370_epochs_4_8.pth"
network = ColorizationCNN(nb_colour_bins = Q)
network.load_state_dict(torch.load(PATH))
# The network is about to be trained further: use training mode so that the
# batch normalisation layers normalise with batch statistics and keep
# updating their running statistics (eval mode would freeze them).
network.train()
colCNN = network.to("cuda")
#initial_lr = 3e-5 Put current learning rate ? save it

initial_lr = 0.00001
optimizer = torch.optim.Adam(colCNN.parameters(),lr = initial_lr , weight_decay=0)
#loss_plateau_scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, threshold = 20, factor = 0.3, verbose=True, patience= 2)
#loss_plateau_scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, threshold = 2,  factor = 0.3, verbose=True, patience = 2)
nb_epochs = 100
losses = np.zeros(nb_epochs)

for epoch in range(nb_epochs):
    print('epoch :',epoch)
    losses[epoch] = training_step(optimizer,multinomialCrossEntropyLoss,colCNN,loader)
    # Print the loss of the epoch that just finished (the full history is
    # printed and plotted after training).
    print('current loss :',losses[epoch])
    #if nb_epochs in {10,20,30,40,50,60,70,80,90,100}:
        #print("network saved")
        #PATH = os.getcwd() + f"/network_cat_{nb_epochs}_epochs_1.pth"
        #torch.save(colCNN.to('cpu').state_dict(), PATH) 
    #loss_plateau_scheduler1.step(losses[epoch])
    #loss_plateau_scheduler2.step(losses[epoch])

Plot losses

In [None]:
# Dump the raw loss values, then draw them as a curve (one point per epoch).
print(losses)

x_axis = range(1, nb_epochs + 1)
plt.plot(x_axis, losses, label="train loss")

# Figure decorations.
plt.title('multinomial cross entropy loss')
plt.xlabel('number of training steps')
plt.ylabel('avg loss')
plt.legend()
plt.show()


Show the neural network colorization of a specific image vs the real colors of that image

In [None]:
def show_images(img):
    '''
        Displays an image tensor with matplotlib.
            Inputs:
                - img : image tensor, converted to a PIL image for display
    '''
    pil_image = transforms.functional.to_pil_image(img)
    plt.imshow(pil_image)
    plt.show()

    
def show_col_vs_truth(network,data_loader,batch_size):
    '''
        Function to output figure with one batch of colored images vs its ground truth
        Inputs:
            -network : Neural network for a*b* channel value predictions
            -data_loader : pytorch data loader to load a batch of images
            -batch_size : maximum number of images to display (also the number of images per row)

        Outputs:
            - Figure of colored images vs ground truth
    '''
    # Predict in evaluation mode, then restore the mode the network was in.
    was_training = network.training
    network.eval()
    try:
        with torch.no_grad():
            device = next(network.parameters()).device

            # Load RGB images.
            images, indexes = next(iter(data_loader))

            # The loaded batch may hold fewer images than batch_size.
            nb_images = min(batch_size, images.shape[0])
            images = images[:nb_images]

            # Get l channel from RGB images.
            l_images = [rgb2lab(image.numpy())[:,:,0] for image in images]

            # Get a tensor of shape (Batch,1,H,W) out of the l images.
            l_images = torch.tensor(np.array(l_images))
            l_images = torch.unsqueeze(l_images,dim=1)

            # Predicting the colour bin probability yields tensor of size (B,Q,H/4,W/4).
            predicted_colour_probability = network(l_images.float().to(device)).cpu()

            # Permute prediction to get Q last to later get colour estimate.
            predicted_colour_probability = torch.permute(predicted_colour_probability,(0,2,3,1))

            # Upsampler back to the input resolution (defined once for the whole batch).
            bilinear_upsampler_to_224 = torch.nn.Upsample((224,224),mode='bilinear',align_corners=True)

            predicted_rgb_images = []
            for i in range(nb_images):

                # Get estimate in format (H/4,W/4,2). format (H/4,W/4,Q) required.
                estimate = getPictureEstimate(predicted_colour_probability[i,:,:,:], 0.38)

                # Permute estimate to get (2,H/4,W/4) format.
                estimate = torch.permute(estimate,(2,0,1))

                # Upsample. to get (2,224,224) format
                estimate = bilinear_upsampler_to_224(torch.unsqueeze(estimate, dim=0))[0,:,:,:]

                # Swap the two spatial axes. NOTE(review): presumably this
                # compensates for training_step permuting the network output
                # with (0,3,2,1), which trains the network against spatially
                # transposed targets - confirm before changing either side.
                estimate = torch.permute(estimate,(0,2,1))

                # Fuse predicted channels and light channel.
                predicted_lab_image = torch.cat((l_images[i,:,:,:],estimate), dim = 0)

                #permute to get (224,224,3)
                predicted_lab_image = predicted_lab_image.permute(1,2,0)

                # Convert to RGB and add to list. in format(3,224,224)
                predicted_rgb_images.append(color.lab2rgb(predicted_lab_image.numpy()).transpose(2,0,1))

            # Transform list to tensor.
            predicted_rgb_images = torch.tensor(np.asarray(predicted_rgb_images))

            # Print the ground truth pictures along with the grey version and recolored one.
            show_images(utils.make_grid(images,nrow=batch_size))
            show_images(utils.make_grid(torchvision.transforms.Grayscale()(images),nrow=batch_size))
            show_images(utils.make_grid(predicted_rgb_images,nrow=batch_size))
    finally:
        network.train(was_training)



In [None]:
PATH = os.getcwd() + "/network_cat_370_epochs_4_8.pth"
torch.save(colCNN.to('cpu').state_dict(), PATH)   #to save a network

# Load the network from file
#network = ColorizationCNN(nb_colour_bins = Q)
#network.load_state_dict(torch.load(PATH))
#network.eval()

network = colCNN.to('cpu')

# The loaders are built through the full module path: the short `data` alias
# may have been overwritten by a loop variable of the same name in an earlier cell.
dataset_for_test = CustomImageDataset("cat_dataset/training_set",transform=transform)
loader_for_test = torch.utils.data.DataLoader(dataset_for_test, batch_size = 4, shuffle = True)

dataset_for_test_new = CustomImageDataset("cat_dataset/testing_set",transform=transform)
loader_for_test_new = torch.utils.data.DataLoader(dataset_for_test_new, batch_size = 4, shuffle = True)

# Colorization of images seen during training.
show_col_vs_truth(network,loader_for_test, 4)

# Colorization of images from the testing set.
show_col_vs_truth(network,loader_for_test_new, 4)