In [35]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from PIL import Image
# For utilities
import os, shutil, time

In [1]:
# Load the TensorBoard notebook extension and launch it on the `runs` log directory.
%load_ext tensorboard
%tensorboard --logdir=runs

In [37]:
# TensorBoard writer shared (as a global) by train() and validate(), which log the
# per-epoch "Loss/train" and "Loss/test" scalars through it.
# NOTE(review): SummaryWriter() without arguments is expected to write under ./runs/,
# i.e. the --logdir passed to the %tensorboard magic -- confirm against the PyTorch docs.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [38]:
# Check if GPU is available.
# `use_gpu` is read as a global by train(), validate() and the cell that moves the
# model and loss to the GPU.
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [39]:
# # calculate mean and std deviation

# from pathlib import Path

# imageFilesDir = Path('../datasets/dataset/train/colour')
# files = list(imageFilesDir.rglob('*.jpg'))

# # Since the std can't be calculated by simply finding it for each image and averaging like  
# # the mean can be, to get the std we first calculate the overall mean in a first run then  
# # run it again to get the std.

# mean = np.array([0.,0.,0.])
# stdTemp = np.array([0.,0.,0.])
# std = np.array([0.,0.,0.])

# numSamples = len(files)

# for i in range(numSamples):
#     im = np.asarray(Image.open(str(files[i])).convert("RGB"))
#     im = im / 255.
    
#     for j in range(3):
#         mean[j] += np.mean(im[:,:,j])

# mean = (mean/numSamples)

# for i in range(numSamples):
#     im = np.asarray(Image.open(str(files[i])).convert("RGB"))
#     im = im / 255.
    
#     for j in range(3):
#         stdTemp[j] += ((im[:,:,j] - mean[j])**2).sum()/(im.shape[0]*im.shape[1])

# std = np.sqrt(stdTemp/numSamples)

# print(mean)
# print(std)

In [40]:
SIZE = 64
class LabImageFolder(torch.utils.data.Dataset):
    """Dataset yielding ``(grayscale, ab)`` tensor pairs for each image file in a folder.

    Every regular file directly inside ``paths`` is treated as an image. Each image is
    converted to RGB, resized to ``SIZE x SIZE`` and, for the training split, randomly
    flipped horizontally.

    Args:
        paths: Directory containing the image files (sub-directories are ignored).
        split: ``'train'`` (with horizontal-flip augmentation) or ``'val'``.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """

    _SPLITS = ('train', 'val')

    def __init__(self, paths, split='train'):
        # Fail fast: an unknown split used to leave `self.transforms` undefined and
        # only surfaced as an AttributeError on the first __getitem__ call.
        if split not in self._SPLITS:
            raise ValueError(
                "split must be one of {}, got {!r}".format(self._SPLITS, split))

        pipeline = [
            transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
            # The crop size equals the resized size, so this keeps the whole image.
            transforms.RandomCrop(SIZE),
        ]
        if split == 'train':
            pipeline.append(transforms.RandomHorizontalFlip())
        self.transforms = transforms.Compose(pipeline)

        self.split = split
        self.size = SIZE
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.paths = sorted(
            os.path.join(paths, name)
            for name in os.listdir(paths)
            if os.path.isfile(os.path.join(paths, name))
        )

    def __getitem__(self, index):
        """Return ``(gray, ab)`` float tensors for the image at ``index``.

        ``gray`` is the ``rgb2gray`` luminance with a leading channel axis added;
        ``ab`` holds the Lab a/b channels shifted and scaled by ``(x + 128) / 255``,
        channel-first.
        """
        # convert() returns an independent, fully loaded image, so the file handle
        # can be closed as soon as the conversion is done.
        with Image.open(self.paths[index]) as handle:
            img = handle.convert("RGB")

        img_original = np.asarray(self.transforms(img))

        img_lab = (rgb2lab(img_original) + 128) / 255
        img_ab = torch.from_numpy(img_lab[:, :, 1:3].transpose((2, 0, 1))).float()

        img_gray = torch.from_numpy(rgb2gray(img_original)).unsqueeze(0).float()
        return img_gray, img_ab

    def __len__(self):
        """Number of image files found in the folder."""
        return len(self.paths)

In [41]:
# Mini-batch size shared by both loaders.
batch_size = 128

# Training split: augmented dataset, shuffled by the loader.
train_imagefolder = LabImageFolder(
    paths='../../datasets/dataset/train/colour',
    split='train',
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_imagefolder,
    batch_size=batch_size,
    shuffle=True,
)

# Validation split: read from the `test` folder, kept in a fixed order.
val_imagefolder = LabImageFolder(
    paths='../../datasets/dataset/test/colour',
    split='val',
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_imagefolder,
    batch_size=batch_size,
    shuffle=False,
)

In [42]:
# Convolution hyper-parameters shared by every layer of the network.
kernel_size = 3
stride_en = 2           # encoder convolutions halve the spatial resolution
stride_de = 1           # decoder convolutions keep it; nn.Upsample does the growing
padding = 1
scale_factor = 2
padding_mode = 'zeros'


class ColorizationNet(nn.Module):
  """Small convolutional encoder/decoder mapping 1 input channel to 2 output channels.

  The encoder is three stride-``stride_en`` Conv-BatchNorm-ReLU stages
  (1 -> 16 -> 32 -> 64 channels). The decoder is three ConvTranspose-BatchNorm-ReLU
  stages (64 -> 32 -> 16 -> 8 channels) plus a final ConvTranspose to 2 channels,
  interleaved with three ``nn.Upsample`` layers.

  Args:
    input_size: Accepted for interface compatibility; not used by the network.
  """

  def __init__(self, input_size=128):
    super().__init__()

    shared = dict(kernel_size=kernel_size, padding=padding, padding_mode=padding_mode)

    def down(in_ch, out_ch):
      # One encoder stage: strided convolution, batch norm, ReLU.
      return [
        nn.Conv2d(in_ch, out_ch, stride=stride_en, **shared),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
      ]

    def up(in_ch, out_ch):
      # One decoder stage: transposed convolution, batch norm, ReLU.
      return [
        nn.ConvTranspose2d(in_ch, out_ch, stride=stride_de, **shared),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
      ]

    def grow():
      return [nn.Upsample(scale_factor=scale_factor)]

    # Layers are created in the same order as they appear in the Sequential, so
    # state_dict keys (encoder.0 ... decoder.12) and weight initialisation order
    # are unchanged.
    encoder_layers = down(1, 16) + down(16, 32) + down(32, 64)
    self.encoder = nn.Sequential(*encoder_layers)

    decoder_layers = (
      up(64, 32) + grow()
      + up(32, 16) + grow()
      + up(16, 8)
      + [nn.ConvTranspose2d(8, 2, stride=stride_de, **shared)]
      + grow()
    )
    self.decoder = nn.Sequential(*decoder_layers)

  def forward(self, input):
    """Encode ``input`` and decode the features into the 2-channel prediction."""
    return self.decoder(self.encoder(input))

In [43]:
# Instantiate the colorization network with its default constructor arguments.
model = ColorizationNet()

In [44]:
# Mean-squared-error loss; train() and validate() apply it to (predicted ab, target ab).
criterion = nn.MSELoss()

In [45]:
# Adam over every model parameter, learning rate 1e-2.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [46]:
# Move model and loss function to GPU when one is available.
# NOTE(review): the optimizer is constructed before this move; the torch.optim docs
# recommend moving the model to the GPU first -- confirm this cell order is intended.
if use_gpu: 
    criterion = criterion.cuda()
    model = model.cuda()

In [47]:
from torchsummary import summary

# Print a per-layer summary for a single-channel 64x64 input, i.e. the SIZE the
# dataset resizes images to. The commented calls are alternative input sizes.
# summary(model, (1, 32, 32))
summary(model, (1, 64, 64))
# summary(model, (1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             160
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 32, 16, 16]           4,640
       BatchNorm2d-5           [-1, 32, 16, 16]              64
              ReLU-6           [-1, 32, 16, 16]               0
            Conv2d-7             [-1, 64, 8, 8]          18,496
       BatchNorm2d-8             [-1, 64, 8, 8]             128
              ReLU-9             [-1, 64, 8, 8]               0
  ConvTranspose2d-10             [-1, 32, 8, 8]          18,464
      BatchNorm2d-11             [-1, 32, 8, 8]              64
             ReLU-12             [-1, 32, 8, 8]               0
         Upsample-13           [-1, 32, 16, 16]               0
  ConvTranspose2d-14           [-1, 16,

In [48]:
class AverageMeter(object):
  '''Running statistics for a metric (adapted from the PyTorch ImageNet tutorial).

  Attributes:
    val: most recent value passed to update().
    sum: weighted sum of all values seen since the last reset().
    count: total weight seen since the last reset().
    avg: sum / count.
  '''

  def __init__(self):
    self.reset()

  def reset(self):
    '''Set every statistic back to zero.'''
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0

  def update(self, val, n=1):
    '''Record `val` with weight `n` and refresh the running average.'''
    self.val = val
    self.count += n
    self.sum += val * n
    self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
  '''Save a grayscale image and its colorized (Lab -> RGB) counterpart to disk.

  Args:
    grayscale_input: CPU tensor concatenated with `ab_input` along dim 0 and used as
      the lightness channel after scaling by 100 (the dataset yields it as (1, H, W)).
    ab_input: CPU tensor holding the two colour channels in the dataset's
      `(x + 128) / 255` scaling (the dataset yields it as (2, H, W)).
    save_path: dict of the form {'grayscale': '/path/', 'colorized': '/path/'}.
    save_name: file name written inside both directories.

  Returns:
    None. When either `save_path` or `save_name` is None nothing is written.
  '''
  # Nothing to write: skip the (comparatively expensive) colour-space conversion.
  if save_path is None or save_name is None:
    return

  # Stack the channels and move them last, as lab2rgb / imsave expect (H, W, C).
  lab_image = torch.cat((grayscale_input, ab_input), 0).numpy()
  lab_image = lab_image.transpose((1, 2, 0))
  # Undo the dataset scaling: first channel * 100, colour channels * 255 - 128.
  lab_image[:, :, 0:1] = lab_image[:, :, 0:1] * 100
  lab_image[:, :, 1:3] = lab_image[:, :, 1:3] * 255 - 128
  color_image = lab2rgb(lab_image.astype(np.float64))

  gray_image = grayscale_input.squeeze().numpy()

  # os.path.join tolerates directories given with or without a trailing separator.
  # The previous plt.clf() call is dropped: the two imsave calls below are given the
  # arrays directly and do not need a cleared pyplot figure.
  plt.imsave(arr=gray_image, fname=os.path.join(save_path['grayscale'], save_name), cmap='gray')
  plt.imsave(arr=color_image, fname=os.path.join(save_path['colorized'], save_name))

In [49]:
# Directories for the colorized / grayscale sample images; validate() hands them to
# to_rgb(), and the setup cell below creates them.
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'

In [50]:
def validate(val_loader, model, criterion, save_images, epoch, max_images=10):
    """Evaluate `model` on `val_loader`, optionally saving sample colorizations.

    Reads the notebook globals `use_gpu`, `writer`, `gray_imgs` and `color_imgs`.

    Args:
        val_loader: iterable of (gray, ab) batches.
        model: network called as `model(gray)`.
        criterion: loss called as `criterion(output_ab, ab)`.
        save_images: when true, images from the first batch are written via to_rgb().
        epoch: epoch index, used in the image file names and as the TensorBoard step.
        max_images: upper bound on the number of images saved (default 10).

    Returns:
        The batch-size-weighted mean loss over the loader.
    """
    _loss = AverageMeter()

    model.eval()
    already_saved_images = False

    # Evaluation needs no gradients; disabling autograd here keeps the function
    # safe even when the caller forgets to wrap it in torch.no_grad().
    with torch.no_grad():
        for gray, ab in val_loader:
            if use_gpu:
                gray, ab = gray.cuda(), ab.cuda()

            # Run model and record loss
            output_ab = model(gray)
            loss = criterion(output_ab, ab)

            _loss.update(loss.item(), gray.size(0))

            # Save images to file (first batch only)
            if save_images and not already_saved_images:
                already_saved_images = True
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                for j in range(min(len(output_ab), max_images)):  # save at most `max_images` images
                    save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                    to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print('Validate: Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
         loss=_loss))

    print('Finished validation.')

    writer.add_scalar("Loss/test", _loss.avg, epoch)
    return _loss.avg

In [51]:
# writer.add_graph(model, [model, train_loader.dataset])
# writer.close()

Tracer cannot infer type of (ColorizationNet(
  (encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2.0, mode=nearest)
    (4): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(16, eps=1e-05, mo

RuntimeError: Tracer cannot infer type of (ColorizationNet(
  (encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2.0, mode=nearest)
    (4): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Upsample(scale_factor=2.0, mode=nearest)
    (8): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): ConvTranspose2d(8, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): Upsample(scale_factor=2.0, mode=nearest)
  )
), <__main__.LabImageFolder object at 0x7fc9d812ce20>)
:Cannot infer concrete type of torch.nn.Module

In [52]:
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one optimisation pass over `train_loader` and log the mean loss.

    Reads the notebook globals `use_gpu` and `writer`. The batch-size-weighted mean
    loss of the epoch is written to TensorBoard under "Loss/train".
    """
    print(f'Starting training epoch {epoch}')

    model.train()
    _loss = AverageMeter()

    for batch_gray, batch_ab in train_loader:
        if use_gpu:
            batch_gray = batch_gray.cuda()
            batch_ab = batch_ab.cuda()

        # Standard step: clear old gradients, forward, backward, update weights.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_gray), batch_ab)
        batch_loss.backward()
        optimizer.step()

        _loss.update(batch_loss.item(), batch_gray.size(0))

    # `.val` is the last batch's loss, `.avg` the mean over the whole epoch.
    print(f'Epoch: {epoch}, Loss {_loss.val:.4f} ({_loss.avg:.4f})')
    print(f'Finished training epoch {epoch}')

    writer.add_scalar("Loss/train", _loss.avg, epoch)

In [53]:
# Make folders and set parameters
checkpoints = 'checkpoints'  # directory the training loop saves state_dict files into
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True  # forwarded to validate(); enables writing sample images
best_losses = 1e10  # large starting value so the first validation loss below it counts as an improvement
epochs = 200  # number of train + validate rounds

In [54]:
# Train model: one training pass followed by one validation pass per epoch.
for epoch in range(epochs):
    train(train_loader, model, criterion, optimizer, epoch)

    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)

    # Only a strictly lower validation loss triggers a checkpoint.
    if not (losses < best_losses):
        continue

    # Each improvement is written to its own file (earlier checkpoints are kept);
    # the file name carries the 1-based epoch number and the loss.
    best_losses = losses
    torch.save(
        model.state_dict(),
        '{}/model-epoch-{}-losses-{:.3f}.pth'.format(checkpoints, epoch + 1, losses),
    )

Starting training epoch 0
Epoch: 0, Loss 0.0099 (0.0343)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: Loss 0.0080 (0.0080)	
Finished validation.
Starting training epoch 1
Epoch: 1, Loss 0.0074 (0.0082)
Finished training epoch 1
Validate: Loss 0.0074 (0.0074)	
Finished validation.
Starting training epoch 2
Epoch: 2, Loss 0.0083 (0.0079)
Finished training epoch 2


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: Loss 0.0075 (0.0074)	
Finished validation.
Starting training epoch 3
Epoch: 3, Loss 0.0076 (0.0075)
Finished training epoch 3
Validate: Loss 0.0070 (0.0072)	
Finished validation.
Starting training epoch 4
Epoch: 4, Loss 0.0083 (0.0074)
Finished training epoch 4


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: Loss 0.0074 (0.0072)	
Finished validation.
Starting training epoch 5
Epoch: 5, Loss 0.0078 (0.0073)
Finished training epoch 5
Validate: Loss 0.0069 (0.0070)	
Finished validation.
Starting training epoch 6
Epoch: 6, Loss 0.0066 (0.0071)
Finished training epoch 6
Validate: Loss 0.0066 (0.0067)	
Finished validation.
Starting training epoch 7
Epoch: 7, Loss 0.0069 (0.0070)
Finished training epoch 7
Validate: Loss 0.0066 (0.0067)	
Finished validation.
Starting training epoch 8
Epoch: 8, Loss 0.0069 (0.0070)
Finished training epoch 8
Validate: Loss 0.0066 (0.0067)	
Finished validation.
Starting training epoch 9
Epoch: 9, Loss 0.0072 (0.0071)
Finished training epoch 9
Validate: Loss 0.0066 (0.0067)	
Finished validation.
Starting training epoch 10
Epoch: 10, Loss 0.0074 (0.0069)
Finished training epoch 10
Validate: Loss 0.0066 (0.0067)	
Finished validation.
Starting training epoch 11
Epoch: 11, Loss 0.0079 (0.0069)
Finished training epoch 11
Validate: Loss 0.0066 (0.0066)	
Finished v

  return func(*args, **kwargs)


Validate: Loss 0.0067 (0.0067)	
Finished validation.
Starting training epoch 13
Epoch: 13, Loss 0.0069 (0.0068)
Finished training epoch 13
Validate: Loss 0.0065 (0.0065)	
Finished validation.
Starting training epoch 14
Epoch: 14, Loss 0.0061 (0.0068)
Finished training epoch 14
Validate: Loss 0.0064 (0.0065)	
Finished validation.
Starting training epoch 15
Epoch: 15, Loss 0.0066 (0.0067)
Finished training epoch 15
Validate: Loss 0.0067 (0.0066)	
Finished validation.
Starting training epoch 16
Epoch: 16, Loss 0.0066 (0.0067)
Finished training epoch 16
Validate: Loss 0.0064 (0.0064)	
Finished validation.
Starting training epoch 17
Epoch: 17, Loss 0.0070 (0.0067)
Finished training epoch 17
Validate: Loss 0.0065 (0.0065)	
Finished validation.
Starting training epoch 18
Epoch: 18, Loss 0.0066 (0.0067)
Finished training epoch 18
Validate: Loss 0.0065 (0.0065)	
Finished validation.
Starting training epoch 19
Epoch: 19, Loss 0.0063 (0.0067)
Finished training epoch 19
Validate: Loss 0.0066 (0.00

  return func(*args, **kwargs)


Validate: Loss 0.0068 (0.0068)	
Finished validation.
Starting training epoch 21
Epoch: 21, Loss 0.0064 (0.0066)
Finished training epoch 21
Validate: Loss 0.0063 (0.0064)	
Finished validation.
Starting training epoch 22
Epoch: 22, Loss 0.0070 (0.0066)
Finished training epoch 22
Validate: Loss 0.0066 (0.0065)	
Finished validation.
Starting training epoch 23
Epoch: 23, Loss 0.0066 (0.0065)
Finished training epoch 23
Validate: Loss 0.0066 (0.0065)	
Finished validation.
Starting training epoch 24
Epoch: 24, Loss 0.0064 (0.0065)
Finished training epoch 24
Validate: Loss 0.0065 (0.0065)	
Finished validation.
Starting training epoch 25
Epoch: 25, Loss 0.0067 (0.0065)
Finished training epoch 25
Validate: Loss 0.0064 (0.0063)	
Finished validation.
Starting training epoch 26
Epoch: 26, Loss 0.0062 (0.0065)
Finished training epoch 26
Validate: Loss 0.0064 (0.0064)	
Finished validation.
Starting training epoch 27
Epoch: 27, Loss 0.0066 (0.0064)
Finished training epoch 27
Validate: Loss 0.0063 (0.00

  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: Loss 0.0062 (0.0062)	
Finished validation.
Starting training epoch 45
Epoch: 45, Loss 0.0062 (0.0062)
Finished training epoch 45
Validate: Loss 0.0061 (0.0062)	
Finished validation.
Starting training epoch 46
Epoch: 46, Loss 0.0064 (0.0061)
Finished training epoch 46
Validate: Loss 0.0063 (0.0062)	
Finished validation.
Starting training epoch 47
Epoch: 47, Loss 0.0058 (0.0061)
Finished training epoch 47
Validate: Loss 0.0065 (0.0065)	
Finished validation.
Starting training epoch 48
Epoch: 48, Loss 0.0059 (0.0061)
Finished training epoch 48
Validate: Loss 0.0064 (0.0065)	
Finished validation.
Starting training epoch 49
Epoch: 49, Loss 0.0060 (0.0061)
Finished training epoch 49
Validate: Loss 0.0060 (0.0060)	
Finished validation.
Starting training epoch 50
Epoch: 50, Loss 0.0061 (0.0061)
Finished training epoch 50


  return func(*args, **kwargs)


Validate: Loss 0.0068 (0.0067)	
Finished validation.
Starting training epoch 51
Epoch: 51, Loss 0.0061 (0.0060)
Finished training epoch 51
Validate: Loss 0.0065 (0.0066)	
Finished validation.
Starting training epoch 52
Epoch: 52, Loss 0.0060 (0.0060)
Finished training epoch 52
Validate: Loss 0.0064 (0.0063)	
Finished validation.
Starting training epoch 53
Epoch: 53, Loss 0.0057 (0.0060)
Finished training epoch 53
Validate: Loss 0.0060 (0.0061)	
Finished validation.
Starting training epoch 54
Epoch: 54, Loss 0.0061 (0.0060)
Finished training epoch 54
Validate: Loss 0.0060 (0.0060)	
Finished validation.
Starting training epoch 55
Epoch: 55, Loss 0.0061 (0.0060)
Finished training epoch 55
Validate: Loss 0.0065 (0.0063)	
Finished validation.
Starting training epoch 56
Epoch: 56, Loss 0.0070 (0.0060)
Finished training epoch 56


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: Loss 0.0062 (0.0062)	
Finished validation.
Starting training epoch 57
Epoch: 57, Loss 0.0058 (0.0060)
Finished training epoch 57
Validate: Loss 0.0062 (0.0062)	
Finished validation.
Starting training epoch 58
Epoch: 58, Loss 0.0055 (0.0059)
Finished training epoch 58
Validate: Loss 0.0059 (0.0061)	
Finished validation.
Starting training epoch 59
Epoch: 59, Loss 0.0053 (0.0059)
Finished training epoch 59
Validate: Loss 0.0061 (0.0063)	
Finished validation.
Starting training epoch 60
Epoch: 60, Loss 0.0059 (0.0059)
Finished training epoch 60
Validate: Loss 0.0060 (0.0061)	
Finished validation.
Starting training epoch 61
Epoch: 61, Loss 0.0054 (0.0059)
Finished training epoch 61
Validate: Loss 0.0059 (0.0060)	
Finished validation.
Starting training epoch 62
Epoch: 62, Loss 0.0055 (0.0059)
Finished training epoch 62
Validate: Loss 0.0060 (0.0061)	
Finished validation.
Starting training epoch 63
Epoch: 63, Loss 0.0056 (0.0058)
Finished training epoch 63


  return func(*args, **kwargs)


Validate: Loss 0.0066 (0.0063)	
Finished validation.
Starting training epoch 64
Epoch: 64, Loss 0.0065 (0.0058)
Finished training epoch 64
Validate: Loss 0.0061 (0.0063)	
Finished validation.
Starting training epoch 65
Epoch: 65, Loss 0.0054 (0.0058)
Finished training epoch 65
Validate: Loss 0.0060 (0.0060)	
Finished validation.
Starting training epoch 66
Epoch: 66, Loss 0.0057 (0.0058)
Finished training epoch 66
Validate: Loss 0.0064 (0.0064)	
Finished validation.
Starting training epoch 67
Epoch: 67, Loss 0.0063 (0.0058)
Finished training epoch 67
Validate: Loss 0.0059 (0.0060)	
Finished validation.
Starting training epoch 68
Epoch: 68, Loss 0.0055 (0.0058)
Finished training epoch 68
Validate: Loss 0.0060 (0.0061)	
Finished validation.
Starting training epoch 69
Epoch: 69, Loss 0.0051 (0.0057)
Finished training epoch 69
Validate: Loss 0.0060 (0.0062)	
Finished validation.
Starting training epoch 70
Epoch: 70, Loss 0.0056 (0.0058)
Finished training epoch 70
Validate: Loss 0.0060 (0.00

  return func(*args, **kwargs)


Validate: Loss 0.0064 (0.0064)	
Finished validation.
Starting training epoch 72
Epoch: 72, Loss 0.0063 (0.0057)
Finished training epoch 72
Validate: Loss 0.0061 (0.0060)	
Finished validation.
Starting training epoch 73
Epoch: 73, Loss 0.0059 (0.0057)
Finished training epoch 73
Validate: Loss 0.0063 (0.0067)	
Finished validation.
Starting training epoch 74
Epoch: 74, Loss 0.0052 (0.0057)
Finished training epoch 74
Validate: Loss 0.0057 (0.0060)	
Finished validation.
Starting training epoch 75
Epoch: 75, Loss 0.0058 (0.0057)
Finished training epoch 75
Validate: Loss 0.0061 (0.0061)	
Finished validation.
Starting training epoch 76
Epoch: 76, Loss 0.0061 (0.0057)
Finished training epoch 76
Validate: Loss 0.0065 (0.0064)	
Finished validation.
Starting training epoch 77
Epoch: 77, Loss 0.0057 (0.0056)
Finished training epoch 77
Validate: Loss 0.0059 (0.0061)	
Finished validation.
Starting training epoch 78
Epoch: 78, Loss 0.0059 (0.0056)
Finished training epoch 78
Validate: Loss 0.0059 (0.00

  return func(*args, **kwargs)


Validate: Loss 0.0062 (0.0061)	
Finished validation.
Starting training epoch 82
Epoch: 82, Loss 0.0060 (0.0055)
Finished training epoch 82
Validate: Loss 0.0061 (0.0061)	
Finished validation.
Starting training epoch 83
Epoch: 83, Loss 0.0058 (0.0056)
Finished training epoch 83
Validate: Loss 0.0065 (0.0063)	
Finished validation.
Starting training epoch 84
Epoch: 84, Loss 0.0060 (0.0055)
Finished training epoch 84
Validate: Loss 0.0060 (0.0061)	
Finished validation.
Starting training epoch 85
Epoch: 85, Loss 0.0054 (0.0055)
Finished training epoch 85
Validate: Loss 0.0059 (0.0061)	
Finished validation.
Starting training epoch 86
Epoch: 86, Loss 0.0057 (0.0055)
Finished training epoch 86
Validate: Loss 0.0062 (0.0060)	
Finished validation.
Starting training epoch 87
Epoch: 87, Loss 0.0050 (0.0055)
Finished training epoch 87
Validate: Loss 0.0060 (0.0060)	
Finished validation.
Starting training epoch 88
Epoch: 88, Loss 0.0058 (0.0055)
Finished training epoch 88


  return func(*args, **kwargs)


Validate: Loss 0.0062 (0.0063)	
Finished validation.
Starting training epoch 89
Epoch: 89, Loss 0.0058 (0.0054)
Finished training epoch 89
Validate: Loss 0.0066 (0.0064)	
Finished validation.
Starting training epoch 90
Epoch: 90, Loss 0.0052 (0.0054)
Finished training epoch 90
Validate: Loss 0.0058 (0.0060)	
Finished validation.
Starting training epoch 91
Epoch: 91, Loss 0.0065 (0.0054)
Finished training epoch 91
Validate: Loss 0.0062 (0.0062)	
Finished validation.
Starting training epoch 92
Epoch: 92, Loss 0.0060 (0.0055)
Finished training epoch 92


  return func(*args, **kwargs)


Validate: Loss 0.0059 (0.0060)	
Finished validation.
Starting training epoch 93
Epoch: 93, Loss 0.0051 (0.0054)
Finished training epoch 93
Validate: Loss 0.0062 (0.0060)	
Finished validation.
Starting training epoch 94
Epoch: 94, Loss 0.0061 (0.0054)
Finished training epoch 94


  return func(*args, **kwargs)


Validate: Loss 0.0063 (0.0063)	
Finished validation.
Starting training epoch 95
Epoch: 95, Loss 0.0055 (0.0054)
Finished training epoch 95
Validate: Loss 0.0059 (0.0059)	
Finished validation.
Starting training epoch 96
Epoch: 96, Loss 0.0047 (0.0054)
Finished training epoch 96
Validate: Loss 0.0060 (0.0061)	
Finished validation.
Starting training epoch 97
Epoch: 97, Loss 0.0053 (0.0054)
Finished training epoch 97
Validate: Loss 0.0061 (0.0063)	
Finished validation.
Starting training epoch 98
Epoch: 98, Loss 0.0052 (0.0054)
Finished training epoch 98
Validate: Loss 0.0059 (0.0060)	
Finished validation.
Starting training epoch 99
Epoch: 99, Loss 0.0048 (0.0054)
Finished training epoch 99
Validate: Loss 0.0062 (0.0062)	
Finished validation.
Starting training epoch 100
Epoch: 100, Loss 0.0055 (0.0053)
Finished training epoch 100
Validate: Loss 0.0061 (0.0061)	
Finished validation.
Starting training epoch 101
Epoch: 101, Loss 0.0058 (0.0053)
Finished training epoch 101
Validate: Loss 0.0061

  return func(*args, **kwargs)


Validate: Loss 0.0061 (0.0062)	
Finished validation.
Starting training epoch 108
Epoch: 108, Loss 0.0049 (0.0052)
Finished training epoch 108
Validate: Loss 0.0061 (0.0061)	
Finished validation.
Starting training epoch 109
Epoch: 109, Loss 0.0052 (0.0052)
Finished training epoch 109
Validate: Loss 0.0060 (0.0062)	
Finished validation.
Starting training epoch 110
Epoch: 110, Loss 0.0053 (0.0052)
Finished training epoch 110
Validate: Loss 0.0063 (0.0063)	
Finished validation.
Starting training epoch 111
Epoch: 111, Loss 0.0056 (0.0052)
Finished training epoch 111
Validate: Loss 0.0060 (0.0061)	
Finished validation.
Starting training epoch 112
Epoch: 112, Loss 0.0053 (0.0052)
Finished training epoch 112


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: Loss 0.0063 (0.0065)	
Finished validation.
Starting training epoch 113
Epoch: 113, Loss 0.0051 (0.0052)
Finished training epoch 113
Validate: Loss 0.0069 (0.0067)	
Finished validation.
Starting training epoch 114
Epoch: 114, Loss 0.0054 (0.0051)
Finished training epoch 114
Validate: Loss 0.0062 (0.0062)	
Finished validation.
Starting training epoch 115
Epoch: 115, Loss 0.0055 (0.0051)
Finished training epoch 115
Validate: Loss 0.0060 (0.0060)	
Finished validation.
Starting training epoch 116
Epoch: 116, Loss 0.0056 (0.0051)
Finished training epoch 116
Validate: Loss 0.0065 (0.0066)	
Finished validation.
Starting training epoch 117
Epoch: 117, Loss 0.0050 (0.0052)
Finished training epoch 117


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: Loss 0.0068 (0.0067)	
Finished validation.
Starting training epoch 118
Epoch: 118, Loss 0.0048 (0.0052)
Finished training epoch 118
Validate: Loss 0.0064 (0.0064)	
Finished validation.
Starting training epoch 119
Epoch: 119, Loss 0.0048 (0.0051)
Finished training epoch 119


  return func(*args, **kwargs)


Validate: Loss 0.0067 (0.0067)	
Finished validation.
Starting training epoch 120
Epoch: 120, Loss 0.0048 (0.0051)
Finished training epoch 120
Validate: Loss 0.0061 (0.0061)	
Finished validation.
Starting training epoch 121
Epoch: 121, Loss 0.0052 (0.0051)
Finished training epoch 121
Validate: Loss 0.0062 (0.0062)	
Finished validation.
Starting training epoch 122
Epoch: 122, Loss 0.0047 (0.0051)
Finished training epoch 122
Validate: Loss 0.0067 (0.0069)	
Finished validation.
Starting training epoch 123
Epoch: 123, Loss 0.0053 (0.0050)
Finished training epoch 123
Validate: Loss 0.0064 (0.0064)	
Finished validation.
Starting training epoch 124
Epoch: 124, Loss 0.0050 (0.0050)
Finished training epoch 124
Validate: Loss 0.0060 (0.0059)	
Finished validation.
Starting training epoch 125
Epoch: 125, Loss 0.0054 (0.0051)
Finished training epoch 125


  return func(*args, **kwargs)


Validate: Loss 0.0066 (0.0067)	
Finished validation.
Starting training epoch 126
Epoch: 126, Loss 0.0045 (0.0050)
Finished training epoch 126


  return func(*args, **kwargs)


Validate: Loss 0.0066 (0.0064)	
Finished validation.
Starting training epoch 127
Epoch: 127, Loss 0.0051 (0.0051)
Finished training epoch 127
Validate: Loss 0.0065 (0.0066)	
Finished validation.
Starting training epoch 128
Epoch: 128, Loss 0.0052 (0.0050)
Finished training epoch 128
Validate: Loss 0.0064 (0.0066)	
Finished validation.
Starting training epoch 129
Epoch: 129, Loss 0.0060 (0.0050)
Finished training epoch 129
Validate: Loss 0.0062 (0.0062)	
Finished validation.
Starting training epoch 130
Epoch: 130, Loss 0.0053 (0.0050)
Finished training epoch 130
Validate: Loss 0.0061 (0.0061)	
Finished validation.
Starting training epoch 131
Epoch: 131, Loss 0.0054 (0.0050)
Finished training epoch 131
Validate: Loss 0.0060 (0.0061)	
Finished validation.
Starting training epoch 132
Epoch: 132, Loss 0.0051 (0.0050)
Finished training epoch 132
Validate: Loss 0.0061 (0.0061)	
Finished validation.
Starting training epoch 133
Epoch: 133, Loss 0.0057 (0.0050)
Finished training epoch 133
Valida

  return func(*args, **kwargs)


Validate: Loss 0.0068 (0.0071)	
Finished validation.
Starting training epoch 138
Epoch: 138, Loss 0.0053 (0.0049)
Finished training epoch 138
Validate: Loss 0.0060 (0.0062)	
Finished validation.
Starting training epoch 139
Epoch: 139, Loss 0.0057 (0.0049)
Finished training epoch 139
Validate: Loss 0.0061 (0.0061)	
Finished validation.
Starting training epoch 140
Epoch: 140, Loss 0.0052 (0.0049)
Finished training epoch 140


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: Loss 0.0065 (0.0065)	
Finished validation.
Starting training epoch 141
Epoch: 141, Loss 0.0045 (0.0050)
Finished training epoch 141
Validate: Loss 0.0066 (0.0063)	
Finished validation.
Starting training epoch 142
Epoch: 142, Loss 0.0053 (0.0050)
Finished training epoch 142
Validate: Loss 0.0061 (0.0063)	
Finished validation.
Starting training epoch 143
Epoch: 143, Loss 0.0044 (0.0049)
Finished training epoch 143
Validate: Loss 0.0059 (0.0063)	
Finished validation.
Starting training epoch 144
Epoch: 144, Loss 0.0054 (0.0049)
Finished training epoch 144
Validate: Loss 0.0062 (0.0063)	
Finished validation.
Starting training epoch 145
Epoch: 145, Loss 0.0046 (0.0049)
Finished training epoch 145
Validate: Loss 0.0062 (0.0063)	
Finished validation.
Starting training epoch 146
Epoch: 146, Loss 0.0048 (0.0049)
Finished training epoch 146
Validate: Loss 0.0060 (0.0062)	
Finished validation.
Starting training epoch 147
Epoch: 147, Loss 0.0043 (0.0049)
Finished training epoch 147
Valida

  return func(*args, **kwargs)


Validate: Loss 0.0071 (0.0069)	
Finished validation.
Starting training epoch 151
Epoch: 151, Loss 0.0043 (0.0049)
Finished training epoch 151
Validate: Loss 0.0059 (0.0063)	
Finished validation.
Starting training epoch 152
Epoch: 152, Loss 0.0055 (0.0049)
Finished training epoch 152
Validate: Loss 0.0062 (0.0064)	
Finished validation.
Starting training epoch 153
Epoch: 153, Loss 0.0050 (0.0049)
Finished training epoch 153
Validate: Loss 0.0060 (0.0062)	
Finished validation.
Starting training epoch 154
Epoch: 154, Loss 0.0051 (0.0048)
Finished training epoch 154
Validate: Loss 0.0060 (0.0063)	
Finished validation.
Starting training epoch 155
Epoch: 155, Loss 0.0051 (0.0048)
Finished training epoch 155
Validate: Loss 0.0060 (0.0060)	
Finished validation.
Starting training epoch 156
Epoch: 156, Loss 0.0049 (0.0050)
Finished training epoch 156
Validate: Loss 0.0064 (0.0065)	
Finished validation.
Starting training epoch 157
Epoch: 157, Loss 0.0050 (0.0049)
Finished training epoch 157
Valida

  return func(*args, **kwargs)


Validate: Loss 0.0059 (0.0062)	
Finished validation.
Starting training epoch 161
Epoch: 161, Loss 0.0047 (0.0048)
Finished training epoch 161
Validate: Loss 0.0060 (0.0062)	
Finished validation.
Starting training epoch 162
Epoch: 162, Loss 0.0049 (0.0048)
Finished training epoch 162
Validate: Loss 0.0062 (0.0062)	
Finished validation.
Starting training epoch 163
Epoch: 163, Loss 0.0049 (0.0048)
Finished training epoch 163
Validate: Loss 0.0060 (0.0061)	
Finished validation.
Starting training epoch 164
Epoch: 164, Loss 0.0045 (0.0048)
Finished training epoch 164
Validate: Loss 0.0062 (0.0064)	
Finished validation.
Starting training epoch 165
Epoch: 165, Loss 0.0047 (0.0048)
Finished training epoch 165
Validate: Loss 0.0060 (0.0063)	
Finished validation.
Starting training epoch 166
Epoch: 166, Loss 0.0050 (0.0048)
Finished training epoch 166


  return func(*args, **kwargs)


Validate: Loss 0.0062 (0.0063)	
Finished validation.
Starting training epoch 167
Epoch: 167, Loss 0.0049 (0.0048)
Finished training epoch 167
Validate: Loss 0.0061 (0.0063)	
Finished validation.
Starting training epoch 168
Epoch: 168, Loss 0.0050 (0.0048)
Finished training epoch 168


  return func(*args, **kwargs)


Validate: Loss 0.0065 (0.0063)	
Finished validation.
Starting training epoch 169
Epoch: 169, Loss 0.0052 (0.0048)
Finished training epoch 169
Validate: Loss 0.0059 (0.0060)	
Finished validation.
Starting training epoch 170
Epoch: 170, Loss 0.0046 (0.0048)
Finished training epoch 170
Validate: Loss 0.0065 (0.0066)	
Finished validation.
Starting training epoch 171
Epoch: 171, Loss 0.0047 (0.0047)
Finished training epoch 171
Validate: Loss 0.0067 (0.0065)	
Finished validation.
Starting training epoch 172
Epoch: 172, Loss 0.0045 (0.0048)
Finished training epoch 172


  return func(*args, **kwargs)


Validate: Loss 0.0063 (0.0064)	
Finished validation.
Starting training epoch 173
Epoch: 173, Loss 0.0042 (0.0048)
Finished training epoch 173


  return func(*args, **kwargs)


Validate: Loss 0.0059 (0.0061)	
Finished validation.
Starting training epoch 174
Epoch: 174, Loss 0.0047 (0.0048)
Finished training epoch 174


  return func(*args, **kwargs)


Validate: Loss 0.0060 (0.0062)	
Finished validation.
Starting training epoch 175
Epoch: 175, Loss 0.0046 (0.0047)
Finished training epoch 175


  return func(*args, **kwargs)


Validate: Loss 0.0063 (0.0063)	
Finished validation.
Starting training epoch 176
Epoch: 176, Loss 0.0043 (0.0048)
Finished training epoch 176


  return func(*args, **kwargs)


Validate: Loss 0.0062 (0.0065)	
Finished validation.
Starting training epoch 177
Epoch: 177, Loss 0.0051 (0.0047)
Finished training epoch 177
Validate: Loss 0.0062 (0.0063)	
Finished validation.
Starting training epoch 178
Epoch: 178, Loss 0.0051 (0.0047)
Finished training epoch 178


  return func(*args, **kwargs)


Validate: Loss 0.0063 (0.0063)	
Finished validation.
Starting training epoch 179
Epoch: 179, Loss 0.0049 (0.0047)
Finished training epoch 179
Validate: Loss 0.0062 (0.0064)	
Finished validation.
Starting training epoch 180
Epoch: 180, Loss 0.0042 (0.0047)
Finished training epoch 180
Validate: Loss 0.0067 (0.0070)	
Finished validation.
Starting training epoch 181
Epoch: 181, Loss 0.0041 (0.0047)
Finished training epoch 181
Validate: Loss 0.0060 (0.0062)	
Finished validation.
Starting training epoch 182
Epoch: 182, Loss 0.0048 (0.0047)
Finished training epoch 182
Validate: Loss 0.0060 (0.0062)	
Finished validation.
Starting training epoch 183
Epoch: 183, Loss 0.0055 (0.0047)
Finished training epoch 183
Validate: Loss 0.0065 (0.0066)	
Finished validation.
Starting training epoch 184
Epoch: 184, Loss 0.0049 (0.0047)
Finished training epoch 184
Validate: Loss 0.0061 (0.0062)	
Finished validation.
Starting training epoch 185
Epoch: 185, Loss 0.0049 (0.0047)
Finished training epoch 185
Valida

  return func(*args, **kwargs)


Validate: Loss 0.0059 (0.0062)	
Finished validation.
Starting training epoch 197
Epoch: 197, Loss 0.0046 (0.0047)
Finished training epoch 197
Validate: Loss 0.0063 (0.0064)	
Finished validation.
Starting training epoch 198
Epoch: 198, Loss 0.0047 (0.0047)
Finished training epoch 198
Validate: Loss 0.0058 (0.0062)	
Finished validation.
Starting training epoch 199
Epoch: 199, Loss 0.0046 (0.0047)
Finished training epoch 199
Validate: Loss 0.0059 (0.0064)	
Finished validation.


<Figure size 432x288 with 0 Axes>

In [None]:
# Load model
# Epoch whose checkpoint is restored; the cells below reuse this value.
best_epoch = 44
# The file name embeds the epoch and a loss tag, so it must match a checkpoint on disk.
checkpoint_path = f'{checkpoints}/model-epoch-{best_epoch}-losses-0.002.pth'
# Deserialize onto the CPU regardless of the device the checkpoint was saved from.
pretrained = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(pretrained)

In [None]:
# Validate
# Run a single validation pass with gradient tracking disabled.
# The last argument is the epoch tag; it is passed as `best_epoch` (not a
# literal) so it stays in sync with the checkpoint that was loaded and with the
# file names the image-display cell builds from `best_epoch`.
# NOTE(review): presumably `validate` writes the sample images when
# `save_images` is True -- confirm against its definition.
save_images = True
with torch.no_grad():
  validate(val_loader, model, criterion, save_images, best_epoch)

In [None]:
# Show images
# Build (color path, gray path) tuples for image indices 0-9 at `best_epoch`.
image_pairs = [
    (f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg')
    for i in range(10)
]

# One figure per pair: gray image in the left panel, color image in the right.
for color_path, gray_path in image_pairs:
    color_img = mpimg.imread(color_path)
    gray_img = mpimg.imread(gray_path)
    fig, (gray_ax, color_ax) = plt.subplots(1, 2)
    fig.set_size_inches(15, 15)
    gray_ax.imshow(gray_img, cmap='gray')
    color_ax.imshow(color_img)
    # Hide ticks and frames on both panels.
    for panel in (gray_ax, color_ax):
        panel.axis('off')
    plt.show()