## All Imports Merged

In [0]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from skimage.color import rgb2lab, lab2rgb
from skimage.io import imread
from skimage.transform import resize
import sklearn.neighbors as ne
from sklearn.model_selection import train_test_split
import scipy.misc

from math import sqrt, pi
import time
import os
from os import listdir, walk
from os.path import join, isfile, isdir
import pdb
import random
import sys
import getopt

# http://pytorch.org/
from os.path import exists
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision
import torch
from torch.utils.data import Dataset
import torchvision.datasets as dsets
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from google.colab import drive


!pip install --no-cache-dir -I pillow

from IPython.display import Math, HTML
display(HTML("<script src='https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.3/"
               "latest.js?config=default'></script>"))


Collecting pillow
[?25l  Downloading https://files.pythonhosted.org/packages/62/94/5430ebaa83f91cc7a9f687ff5238e26164a779cca2ef9903232268b0a318/Pillow-5.3.0-cp36-cp36m-manylinux1_x86_64.whl (2.0MB)
[K    100% |████████████████████████████████| 2.0MB 32.1MB/s 
[?25hInstalling collected packages: pillow
Successfully installed pillow-5.3.0


In [0]:
cuda = True if torch.cuda.is_available() else False
drive.mount('/content/gdrive')
#defining the main path of the drive where all contents are saved.
StatePath = "gdrive/My Drive/AIProject/PytorchVersion"
DatasetPath = StatePath+"/flowers"

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
os.makedirs(StatePath, exist_ok=True)
os.makedirs(StatePath+"/states", exist_ok=True)

## Hyper Parameters

In [0]:
epochs = 1000
batch_size = 10
imageSize = 128
learningRate = 0.001
print_freq = 10
save_freq = 2

## Color Utilities

The ab colorspace was quantized into bins with grid size 10. The number of quantized ab values $Q = 313$. These qauntized values are kept in $\texttt{pts_in_hull.npy}$. The following class $\texttt{NNEncode}$ implements important functions as discussed in research paper.



---


The function $\texttt{imgEncodeTorch}$ implements the $H_{gt}^{-1}$ function which converts ground truth colors to a vector $Z$ using a soft encoding scheme. Here the $ab$ colorspace (ground truth) is encoded into quantized $ab$ space according to the file $\texttt{pts_in_hull.npy}$.




In [0]:
class NNEncode():
    def __init__(self, NN=5, sigma=5, km_filepath=join(StatePath, 'static', 'pts_in_hull.npy'), train=True, location='cuda'):
        self.cc = np.load(km_filepath)
        self.NN = int(NN)
        self.sigma = sigma
        self.nbrs = ne.NearestNeighbors(
            n_neighbors=NN, algorithm='ball_tree').fit(self.cc)
        if train:
            self.weights = torch.load(StatePath+'/static/weights_test')
            if ('cuda' in location):
                self.weights = self.weights.cuda()


    # computes soft encoding of ground truth ab image, multiplied by weight (for class rebalancing)
    #for training
    def imgEncodeTorch(self, abimg):
        abimg = abimg.cuda()
        w, h = abimg.shape[1], abimg.shape[2]
        label = torch.zeros((w*h, 313))
        label = label.cuda()

        (dists, indexes) = self.nbrs.kneighbors(
            abimg.view(abimg.shape[0], -1).t(), self.NN)
        dists = torch.from_numpy(dists).float().cuda()
        indexes = torch.from_numpy(indexes).cuda()

        weights = torch.exp(-dists**2/(2*self.sigma**2)).cuda()
        weights = weights/torch.sum(weights, dim=1).view(-1, 1)

        pixel_indexes = torch.Tensor.long(torch.arange(
            start=0, end=abimg.shape[1]*abimg.shape[2])[:, np.newaxis])
        pixel_indexes = pixel_indexes.cuda()
        label[pixel_indexes, indexes] = weights
        label = label.t().contiguous().view(313, w, h)

        rebal_indexes = indexes[:, 0]
        rebal_weights = self.weights[rebal_indexes]
        rebal_weights = rebal_weights.view(w, h)
        rebal_label = rebal_weights * label

        return rebal_label
    def bin2color(self, idx):
        return self.cc[idx]
    def uint_color2tanh_range(img):
        return img / 128.0 - 1.0
    def tanh_range2uint_color(img):
        return (img * 128.0 + 128.0).astype(np.uint8)
    def modelimg2cvimg(img):
        cvimg = np.array(img[0, :, :, :]).transpose(1, 2, 0)
        return tanh_range2uint_color(cvimg)

This function is implemented to save the results of every $10^{th}$ epoch and show us how the model is learning an image.

In [0]:
def sample_image(grayImage, predImage, actualImage, batch, index):
    gen_imgs = np.concatenate((predImage, actualImage), axis=1)
    os.makedirs(StatePath+"/images/"+str(batch), exist_ok=True)
    scipy.misc.imsave(StatePath+"/images/"+str(batch)+"/"+str(index)+'.jpg', gen_imgs)

## Making Dataset

This function is used to make train, validate and tests datasets 

In [0]:
class CustomImages(Dataset):
    def __init__(self, root, train=True, val=False, color_space='lab', transform=None, test_size=0.1, val_size=0.125, location='cuda'):

        self.root_dir = root
        all_files = []
        for r, _, files in walk(self.root_dir):
          for f in files:
            if f.endswith('.jpg'):
              all_files.append(join(r, f))
        
        train_val_files, test_files = train_test_split(
            all_files, test_size=test_size, random_state=69)
        
        
        train_files, val_files = train_test_split(train_val_files,
                                                  test_size=val_size, random_state=69)
        
        if (train and val):
            self.filenames = val_files
        elif train:
            self.filenames = train_files
        else:
            self.filenames = test_files

        self.color_space = color_space
        if (self.color_space not in ['rgb', 'lab']):
            raise(NotImplementedError)
        self.transform = transform
        self.location = location
        self.nnenc = NNEncode(location=self.location)
        self.train = train

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img = imread(self.filenames[idx])
        if self.color_space == 'lab':
            img = rgb2lab(img)
        if self.transform is not None:
            img = self.transform(img)
        bwimg = img[:, :, 0:1].transpose(2, 0, 1)
        bwimg = torch.from_numpy(bwimg).float()
        abimg = img[:, :, 1:].transpose(2, 0, 1)    # abimg dim: 2, h, w
        abimg = torch.from_numpy(abimg).float()
        label = -1
        if (self.train):
            if ('cuda' in self.location):
                label = self.nnenc.imgEncodeTorch(abimg)
            #else:
             #   label = self.nnenc.imgEncode(abimg)
        return (bwimg, label, abimg)

If the image is of size greater than 128 by 128, we will rescale it using the following function.

In [0]:
class Rescale(object):

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image = sample
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        img = resize(image, (new_h, new_w))[:self.output_size, :self.output_size, :]
        return img

## Class Rebalancing

The loss function is dominated by desaturated $ab$ values if the distribution of $ab$ values is strongly biased towards low ab values. 

This biasness is removed by reweighting the loss of each pixel at train time based on the pixel color rarity. Each pixel is weighed by factor $w \in R^Q$, based on its closest $ab$ bin.






In [0]:
# calculate the weight for each bin based on empirical probability, for class rebalancing
# only needs to be run once
def cal_emp_weights(dset, bins_num=313, sigma=5, lamda=0.5):
    cc = np.load(os.path.join(StatePath, 'static', 'pts_in_hull.npy'))
    nbrs = ne.NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(cc)

    bins_prob = torch.zeros(bins_num)

    print('Dataset length:', len(dset))
    for i in range(len(dset)):
        if (i%100==0):
            print('Reading Image:', i)
        _, _, abimg = dset[i]
        _, indexes = nbrs.kneighbors(abimg.view(abimg.shape[0],-1).t(), 1)
        bins_prob[torch.from_numpy(indexes).view(-1)] += 1
    bins_sum = bins_prob.sum()
    bins_prob /= bins_sum

    w = 1/((1 - lamda) * bins_prob + lamda / bins_num)
    w /= ((bins_prob * w).sum())
    torch.save(w, StatePath+'/static/weights_test')
    return w

In [0]:
entire_dataset = CustomImages(DatasetPath, train=True, test_size=0.1, val_size=0) #40 images for test
print("final lenght",len(entire_dataset))
a = cal_emp_weights(entire_dataset, 313)

final lenght 1440
Dataset length: 1440
Reading Image: 0
Reading Image: 100
Reading Image: 200
Reading Image: 300
Reading Image: 400
Reading Image: 500
Reading Image: 600
Reading Image: 700
Reading Image: 800
Reading Image: 900
Reading Image: 1000
Reading Image: 1100
Reading Image: 1200
Reading Image: 1300
Reading Image: 1400


## Loss Function

Euclidean loss is not robust to the inherent ambiguity and multimodal
nature of the colorization problem. If an image can contain a set of distinct $ab$ values, the optimal solution to the Euclidean loss will be the mean of the set. In color prediction, this averaging effect favors grayish, desaturated results. Thus, the research paper uses multinomial cross entropy loss to element desaturation of images. 

In [0]:
print(1)
class MultinomialCELoss(nn.Module):
    def __init__(self):
        super(MultinomialCELoss, self).__init__()

    # x dim: n, q, h, w
    # y dim: n, q, h, w
    # n number of cases
    # h, w height width
    # q number of bins
    # output: loss, as a float
    def forward(self, x, y):
        # softmax 
        x = x + 1e-8 #add a small number in x to avoid number 0.
        x = torch.log(x)
        zlogz = y*x
        loss = - zlogz.sum()
        loss /= (x.shape[0] * x.shape[2] * x.shape[3])
        return loss

## CNN architecture

This architecture uses multiple layers of CNN and maps the image pixels to a probability distribution of depth $313$. This result is described as $\hat Z$ in the research paper. The probability distribution that the model learns is then evaluated with the multinomial loss function described above.

$L_{cl}(\hat Z, Z) = -\sum{v(Z_{h,w})} \sum Z_{h,w,q} log (\hat Z_{h,w,q}) $


In [0]:
class ColorfulColorizer(nn.Module):
    def __init__(self):
        super(ColorfulColorizer, self).__init__()

        self.op_1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )
        self.op_2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128)
        )
        self.op_3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256)
        )
        self.op_4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        self.op_5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        self.op_6 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        self.op_7 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        self.op_8 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 313, kernel_size=1),
            nn.UpsamplingBilinear2d(scale_factor=4)
        )

        self.op_9 = nn.Sequential(
            nn.Softmax(dim=1)
        )
        self.op_1.apply(self.init_weights)
        self.op_2.apply(self.init_weights)
        self.op_3.apply(self.init_weights)
        self.op_4.apply(self.init_weights)
        self.op_5.apply(self.init_weights)
        self.op_6.apply(self.init_weights)
        self.op_7.apply(self.init_weights)
        self.op_8.apply(self.init_weights)

    def init_weights(self, m):
        if type(m) == nn.Conv2d:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def forward(self, x):
        out = self.op_1(x)
        out = self.op_2(out)
        out = self.op_3(out)
        out = self.op_4(out)
        out = self.op_5(out)
        out = self.op_6(out)
        out = self.op_7(out)
        out = self.op_8(out)
        out = self.op_9(out)
        return out

## Main - Training Data

In [0]:
def main(dset_root, batch_size, num_epochs, print_freq, encoder, criterion,
         optimizer, step_every_iteration=False):
    continue_training = True
    location = 'cuda'
    rescale = Rescale(imageSize)
    train_dataset = CustomImages(
        root=dset_root, train=True, location=location, transform=rescale, test_size=0)
    
    val_dataset = CustomImages(
        root=dset_root, train=True, val=True, location=location, transform=rescale) #val files
    
    
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

    if continue_training and os.path.isfile('best_model.pkl'):
        encoder.load_state_dict(torch.load(
            'best_model.pkl', map_location=location))
        print('Model loaded!')


    if 'cuda' in location:
        print('Using:', torch.cuda.get_device_name(torch.cuda.current_device()))
        encoder.cuda()
        criterion.cuda()

    best_loss = 100
    losses = []

    for epoch in range(num_epochs):
        # train for one epoch
        epoch_losses = train(train_loader, encoder, criterion, optimizer, epoch, location, step_every_iteration, num_epochs, print_freq)
        losses.append(epoch_losses)

        if epoch % save_freq == 0:
          save_checkpoint(encoder.state_dict(), str(epoch)+".pkl")
          save_model_results(train_dataset, encoder, epoch)
          # coloring 5 random images and saving the output
          
          
          

        # evaluate on validation set
        val_loss = validate(val_loader, encoder, criterion, location, num_epochs, print_freq)
#         if (not step_every_iteration):
#             scheduler.step(val_loss.data.item())
        is_best = val_loss.data.item() < best_loss

        if is_best:
            print('New best score! Model saved as best_model.pkl')
            best_loss = val_loss.data.item()
            save_checkpoint(encoder.state_dict(), is_best)
    return losses

In [0]:
def save_checkpoint(state, is_best=False, filename='colorizer2.pkl'):
    torch.save(state, StatePath+"/states/"+filename)
    if is_best:
        torch.save(state, 'best_model.pkl')

After calculating the loss of each image between the ground truth encoded/ quantized ab space and the learned probability distribution $\hat Z$, the prediction of ab colorspace of images is done via taking annealed mean of the learned probability distribution. This is because taking mean of this distribution poses the same problems as they were with computing Euclidean Loss, desaturated images. Hence a function $H(\hat Z_{h,w})$ which takes the learned probability distribution as an input is implemented as described in research paper, and it outputs the annealed mean of the distribution for every pixel. This gives us the predicted ab colorspace for that image which is then converted to rgb colorspace to give results. 

According to the research paper, a temperature value $T = 0.38$ captures the vibrancy of the mode while maintaining the spatial coherence of the mean. 

In [0]:
def save_model_results(dset, model, batchesDone, location='cuda'):
  test_cases = np.floor(np.random.rand(5) * len(dset)).astype(int)
  test_cases = np.append(test_cases, [0], 0)
  outputs = []
  images = []
  labels = []
  for c in test_cases:
      image,_, label = dset[c]
      image = image.unsqueeze(0)
      with torch.no_grad():
          if 'cuda' in location:
            image = image.cuda()
            label = label.cuda()
          images.append(image)
          labels.append(label)
          output = model(image)
          outputs.append(output)
          
  T = 0.38
  q = 313  # number of colours
  nnenc = NNEncode()
  bin_index = np.arange(q)
  ab_list = nnenc.bin2color(bin_index)
  for i in range(len(test_cases)):
    l_layer = images[i].data[0].cpu().numpy()
    bin_probabilities = outputs[i].data[0].cpu().numpy()  # bin_probabilities dim: q, h, w
    ab_label = labels[i].data.cpu().numpy().astype('float64')

    # convert bin_probab -> ab_pred
    bin_probabilities = np.exp(np.log(bin_probabilities)/T)
    bin_sum = bin_probabilities.sum(0)
    bin_sum = bin_sum.reshape((1, bin_sum.shape[0], bin_sum.shape[1]))
    bin_probabilities /= bin_sum

    # ab_pred dim: 2, h, w
    ab_pred = (bin_probabilities[:, np.newaxis, :, :] * ab_list[:, :, np.newaxis, np.newaxis]).sum(0)

    img_input = l_layer[0]
#     img_input = np.concatenate((l_layer, torch.zeros([2,128,128])), axis=0)
    img_pred = np.concatenate((l_layer, ab_pred), axis=0)
    img_actual = np.concatenate((l_layer, ab_label), axis=0)
    
#     img_input = lab2rgb(img_input.transpose(1, 2, 0))
    img_pred = lab2rgb(img_pred.transpose(1, 2, 0))
    img_actual = lab2rgb(img_actual.transpose(1, 2, 0))
    
    sample_image(img_input, img_pred, img_actual, batchesDone, i)

In [0]:
def train(train_loader, model, criterion, optimizer, epoch,
          location, step_every_iteration,num_epochs, print_freq):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    epoch_losses = []

    # switch to train mode
    model.train()

    end = time.time()
    for i, (image, target, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        image_var = Variable(image)
        target_var = Variable(target)

        if 'cuda' in location:
            image_var = image_var.cuda()
            target_var = target_var.cuda()

        # compute output
        output = model(image_var)
        
        loss = criterion(output, target_var)
        
        losses.update(loss.data.item(), image.size(0))
        epoch_losses.append(loss.data.item())
        
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        batchDone = epoch * len(train_loader) + i
        
        if batchDone % print_freq == 0:
          print('Epoch: [{0}/{1}][{2}/{3}]\t'
            'BatchTime(Average) {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
            'DataTime(Average) {data_time.val:.3f} ({data_time.avg:.3f})\t'
            'Loss(Average) {loss.val:.4f} ({loss.avg:.4f})\t'
            .format(
             epoch, num_epochs, i, len(train_loader), batch_time=batch_time,
              data_time=data_time, loss=losses))
    return epoch_losses

In [0]:
def validate(val_loader, model, criterion, location,num_epochs, print_freq):
    batch_time = AverageMeter()
    losses = AverageMeter()
    loss = 0
    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (image, target, _) in enumerate(val_loader):
        with torch.no_grad():
          image_var = Variable(image)
          target_var = Variable(target)

        if 'cuda' in location:
            image_var = image_var.cuda()
            target_var = target_var.cuda()

        # compute output
        output = model(image_var)
        loss = criterion(output, target_var)
        losses.update(loss.data.item(), image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    return loss

In [0]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [0]:
## Training the model here by calling main() which will run the training loop
dset_root = DatasetPath
encoder = ColorfulColorizer()
criterion = MultinomialCELoss()
optimizer = torch.optim.SGD(encoder.parameters(), lr=learningRate)
main(dset_root, batch_size, epochs, print_freq, encoder, criterion, optimizer)

Using: Tesla K80


  warn("The default mode, 'constant', will be changed to 'reflect' in "


Epoch: [0/1000][0/140]	BatchTime(Average) 0.885 (0.885)	DataTime(Average) 0.626 (0.626)	Loss(Average) 4.8431 (4.8431)	
Epoch: [0/1000][10/140]	BatchTime(Average) 2.037 (1.664)	DataTime(Average) 1.843 (1.464)	Loss(Average) 4.9223 (4.8267)	
Epoch: [0/1000][20/140]	BatchTime(Average) 1.266 (1.598)	DataTime(Average) 1.069 (1.402)	Loss(Average) 5.8862 (4.9357)	
Epoch: [0/1000][30/140]	BatchTime(Average) 2.144 (1.618)	DataTime(Average) 1.951 (1.423)	Loss(Average) 4.9003 (4.8795)	
Epoch: [0/1000][40/140]	BatchTime(Average) 1.503 (1.599)	DataTime(Average) 1.311 (1.405)	Loss(Average) 4.2474 (4.8750)	
Epoch: [0/1000][50/140]	BatchTime(Average) 2.494 (1.616)	DataTime(Average) 2.309 (1.423)	Loss(Average) 5.0968 (4.8343)	
Epoch: [0/1000][60/140]	BatchTime(Average) 1.988 (1.647)	DataTime(Average) 1.794 (1.452)	Loss(Average) 4.7179 (4.7546)	
Epoch: [0/1000][70/140]	BatchTime(Average) 1.354 (1.630)	DataTime(Average) 1.161 (1.436)	Loss(Average) 5.6158 (4.7691)	
Epoch: [0/1000][80/140]	BatchTime(Average

`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  after removing the cwd from sys.path.


New best score! Model saved as best_model.pkl
Epoch: [1/1000][0/140]	BatchTime(Average) 1.055 (1.055)	DataTime(Average) 0.876 (0.876)	Loss(Average) 5.8726 (5.8726)	
Epoch: [1/1000][10/140]	BatchTime(Average) 1.292 (1.304)	DataTime(Average) 1.097 (1.110)	Loss(Average) 5.9741 (4.9421)	
Epoch: [1/1000][20/140]	BatchTime(Average) 1.266 (1.295)	DataTime(Average) 1.069 (1.100)	Loss(Average) 4.2507 (4.8344)	
Epoch: [1/1000][30/140]	BatchTime(Average) 1.333 (1.287)	DataTime(Average) 1.127 (1.092)	Loss(Average) 4.9865 (4.9288)	
Epoch: [1/1000][40/140]	BatchTime(Average) 1.316 (1.288)	DataTime(Average) 1.119 (1.093)	Loss(Average) 5.0666 (4.8857)	
Epoch: [1/1000][50/140]	BatchTime(Average) 1.265 (1.284)	DataTime(Average) 1.070 (1.089)	Loss(Average) 4.8100 (4.8825)	
Epoch: [1/1000][60/140]	BatchTime(Average) 1.309 (1.284)	DataTime(Average) 1.112 (1.089)	Loss(Average) 4.9589 (4.8713)	
Epoch: [1/1000][70/140]	BatchTime(Average) 1.333 (1.287)	DataTime(Average) 1.140 (1.091)	Loss(Average) 4.0681 (4.83

## Testing Images from Test Dataset

In [0]:
rescale = Rescale(imageSize)
test_dataset = CustomImages(
    root=DatasetPath, train=False, transform=rescale)

location = 'cuda'
test_cases = np.floor(np.random.rand(5) * len(test_dataset)).astype(int)
test_cases = np.append(test_cases, [0], 0)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

encoder = ColorfulColorizer()
encoder.load_state_dict(torch.load(StatePath+'/states/colorizer.pkl'))
if 'cuda' in location:
    print('Using:', torch.cuda.get_device_name(torch.cuda.current_device()))
    encoder.cuda()

encoder.eval()
# encoder.parameters()

outputs = []
images = []
labels = []
for c in test_cases:
    print('Encoding image number:', c)
    image,_, label = test_dataset[c]
    image = image.unsqueeze(0)
    with torch.no_grad():
        if 'cuda' in location:
          image = image.cuda()
          label = label.cuda()
        images.append(image)
        labels.append(label)
        print(image.shape)
        output = encoder(image)
        outputs.append(output)

In [0]:
T = 0.38
q = 313  # number of colours
nnenc = NNEncode()

In [0]:
bin_index = np.arange(q)
print('Getting ab_list')
ab_list = nnenc.bin2color(bin_index)   # q, 2

In [0]:
f, axarr = plt.subplots(len(test_cases), 3)
for i in range(len(test_cases)):
    l_layer = images[i].data[0].cpu().numpy()
    bin_probabilities = outputs[i].data[0].cpu().numpy()  # bin_probabilities dim: q, h, w
    ab_label = labels[i].data.cpu().numpy().astype('float64')

    # convert bin_probab -> ab_pred
    bin_probabilities = np.exp(np.log(bin_probabilities)/T)
    bin_sum = bin_probabilities.sum(0)
    bin_sum = bin_sum.reshape((1, bin_sum.shape[0], bin_sum.shape[1]))
    bin_probabilities /= bin_sum

    # ab_pred dim: 2, h, w
    ab_pred = (bin_probabilities[:, np.newaxis, :, :] * ab_list[:, :, np.newaxis, np.newaxis]).sum(0)

    img_input = l_layer[0]
#     img_input = np.concatenate((l_layer, torch.zeros([2,128,128])), axis=0)
    img_pred = np.concatenate((l_layer, ab_pred), axis=0)
    img_actual = np.concatenate((l_layer, ab_label), axis=0)
    
#     img_input = lab2rgb(img_input.transpose(1, 2, 0))
    img_pred = lab2rgb(img_pred.transpose(1, 2, 0))
    img_actual = lab2rgb(img_actual.transpose(1, 2, 0))
    
    axarr[i][0].imshow(img_input)
    axarr[i][1].imshow(img_pred)
    axarr[i][2].imshow(img_actual)
    sample_image(img_input, img_pred, img_actual, 1, i)
plt.show()