In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
# Load the TensorBoard notebook extension and embed a TensorBoard view of the `runs` log directory.
%load_ext tensorboard
%tensorboard --logdir=runs

Reusing TensorBoard on port 6009 (pid 22749), started 1:07:57 ago. (Use '!kill 22749' to kill it.)

In [3]:
# Runtime environment switches (the Colab and Kaggle code paths are not tested yet).
# colab=True mounts Google Drive and uses MyDrive/MGU; kaggle=True reads from /kaggle/input.
colab, kaggle = False, False
# Experiment identifier, used in the output directory names on Colab/Kaggle.
test_number = '12_2'

In [4]:
# Default output locations, relative to the working directory.
color_imgs, gray_imgs, checkpoints = 'outputs/color/', 'outputs/gray/', 'checkpoints'

if colab:
    # Persist everything on Google Drive under the experiment's folder.
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}',
        f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}',
        f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}',
    )
elif kaggle:
    # Read-only input dataset; outputs go to a folder named after the experiment.
    dataset = '/kaggle/input/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'{test_number}/{color_imgs}',
        f'{test_number}/{gray_imgs}',
        f'{test_number}/{checkpoints}',
    )
else:
    # Local run: outputs stay at the defaults above.
    dataset = '../../datasets/dataset/'

In [5]:
# Create the output directories; existing ones are left as they are.
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)

# --- training configuration ---
save_images = True        # write sample colorizations during validation
# Best validation [MSE, PSNR, SSIM] so far, seeded with a large sentinel.
# NOTE(review): 1e10 is only a sensible starting "best" for lower-is-better
# metrics such as MSE; PSNR and SSIM improve upwards.
best_losses = [1e10] * 3
best_epoch = -1           # epoch of the best validation MSE (-1 = none yet)
patience = 50             # epochs without MSE improvement before early stopping
epochs = 500              # upper bound on training epochs
batch_size = 128
SIZE = 64                 # images are resized to SIZE x SIZE

In [6]:
# TensorBoard logger shared by train() and validate(). log_dir=None is the
# constructor default (PyTorch docs: runs/<datetime>_<hostname>).
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=None)

In [7]:
# True when PyTorch sees a CUDA device; every later .cuda() call is gated on it.
use_gpu = bool(torch.cuda.is_available())
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    """Flat image folder yielding ``(grayscale, ab)`` tensor pairs for colorization.

    Every regular file directly inside ``paths`` (not searched recursively) is
    treated as an image. Each item is opened as RGB, passed through the split's
    transforms, and converted to:

    * ``img_gray``: ``rgb2gray`` of the transformed RGB image, shape (1, H, W), float32.
    * ``img_ab``: the a/b channels of ``rgb2lab`` rescaled with ``(x + 128) / 255``,
      shape (2, H, W), float32.

    Args:
        paths: directory containing the images.
        split: ``'train'`` (adds a random horizontal flip) or ``'val'``.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                # NOTE(review): after resizing to exactly SIZE x SIZE this crop is a
                # no-op, so it adds no augmentation; kept to preserve behavior.
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),  # no-op after the exact resize above
            ])
        else:
            # Fail fast: an unknown split would otherwise leave self.transforms
            # unset and only surface as an AttributeError inside __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        # Sorted so the item order (and therefore the validation batches and the
        # sample images saved from them) does not depend on os.listdir order.
        self.paths = sorted(
            os.path.join(paths, name)
            for name in os.listdir(paths)
            if os.path.isfile(os.path.join(paths, name))
        )

    def __getitem__(self, index):
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(self.paths[index]) as src:
            img = src.convert("RGB")
        img_original = np.asarray(self.transforms(img))
        img_lab = rgb2lab(img_original)
        # Shift/scale Lab values; only the a/b channels are used below.
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        # NOTE(review): to_rgb() later treats this channel as Lab L / 100, but
        # rgb2gray returns luma rather than CIE L*; the two only roughly agree.
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab

    def __len__(self):
        return len(self.paths)

In [9]:
# Training data: randomly flipped images, reshuffled every epoch.
train_imagefolder = LabImageFolder(dataset + 'train/colour')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)

# Validation data (the dataset's test/colour folder) in a fixed order.
# NOTE(review): this split also drives checkpoint selection and early stopping,
# so it is not an untouched held-out test set.
val_imagefolder = LabImageFolder(dataset + 'test/colour', split='val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
# Convolution hyper-parameters shared by every (transposed) conv in Autoencoder.
kernel_size = 3
stride_en = 1        # stride of the encoder convolutions
stride_de = 1        # stride of the decoder transposed convolutions
padding = 1
scale_factor = 2     # decoder upsampling factor
padding_mode = 'zeros'


class Autoencoder(nn.Module):
    """Convolutional encoder/decoder mapping a 1-channel image to 2 channels.

    Encoder: three Conv -> BatchNorm -> LeakyReLU(0.1) -> AvgPool(stride 2)
    stages (1 -> 16 -> 32 -> 64 channels), shrinking H and W by 8x.
    Decoder: transposed convolutions 64 -> 32 -> 16 -> 8 -> 2 with three 2x
    upsamplings, restoring the input resolution. No output activation.

    Args:
        input_size: accepted for API compatibility; the network does not use it.
    """

    def __init__(self, input_size=128):
        super().__init__()

        def down_stage(c_in, c_out):
            # Conv block followed by a stride-2 average pool (halves H and W).
            return [
                nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=stride_en,
                          padding=padding, padding_mode=padding_mode),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.1),
                nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
            ]

        def deconv(c_in, c_out):
            # Size-preserving transposed convolution (stride_de=1, padding=1).
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel_size, stride=stride_de,
                                      padding=padding, padding_mode=padding_mode)

        encoder_layers = []
        for c_in, c_out in ((1, 16), (16, 32), (32, 64)):
            encoder_layers.extend(down_stage(c_in, c_out))
        self.encoder = nn.Sequential(*encoder_layers)

        self.decoder = nn.Sequential(
            deconv(64, 32), nn.BatchNorm2d(32), nn.LeakyReLU(0.1),
            nn.Upsample(scale_factor=scale_factor),
            deconv(32, 16), nn.BatchNorm2d(16), nn.LeakyReLU(0.1),
            nn.Upsample(scale_factor=scale_factor),
            deconv(16, 8), nn.BatchNorm2d(8), nn.LeakyReLU(0.1),
            deconv(8, 2),
            nn.Upsample(scale_factor=scale_factor),
        )

    def forward(self, input):
        # Compress to the 64-channel bottleneck, then decode back to 2 channels.
        return self.decoder(self.encoder(input))

In [11]:
model = Autoencoder(input_size=128)  # same as the default; input_size is unused by the network

In [12]:
# Metrics in the fixed order used throughout: [MSE, PSNR, SSIM].
# data_range=1.0 because the ab targets are scaled to roughly [0, 1].
criterion = [
    MeanSquaredError(),
    PeakSignalNoiseRatio(data_range=1.0),
    StructuralSimilarityIndexMeasure(data_range=1.0),
]

In [13]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)  # Adam over all network weights

In [14]:
# Move the metric objects and the model to the GPU when one is available.
if use_gpu:
    criterion = [metric.to("cuda") for metric in criterion]
    model = model.cuda()

In [15]:
# Layer-by-layer output shapes and parameter counts for a (1, SIZE, SIZE) input.
# Only run on GPU, presumably because torchsummary's summary() defaults to device="cuda".
if use_gpu:
    from torchsummary import summary
    summary(model, input_size=(1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 64, 64]             160
       BatchNorm2d-2           [-1, 16, 64, 64]              32
         LeakyReLU-3           [-1, 16, 64, 64]               0
         AvgPool2d-4           [-1, 16, 32, 32]               0
            Conv2d-5           [-1, 32, 32, 32]           4,640
       BatchNorm2d-6           [-1, 32, 32, 32]              64
         LeakyReLU-7           [-1, 32, 32, 32]               0
         AvgPool2d-8           [-1, 32, 16, 16]               0
            Conv2d-9           [-1, 64, 16, 16]          18,496
      BatchNorm2d-10           [-1, 64, 16, 16]             128
        LeakyReLU-11           [-1, 64, 16, 16]               0
        AvgPool2d-12             [-1, 64, 8, 8]               0
  ConvTranspose2d-13             [-1, 32, 8, 8]          18,464
      BatchNorm2d-14             [-1, 3

In [16]:
class AverageMeter:
    """Weighted running-average tracker (pattern from the PyTorch ImageNet example).

    Attributes:
        val:   most recent value passed to ``update``.
        sum:   running sum of ``val * n``.
        count: running sum of the weights ``n``.
        avg:   ``sum / count`` (0 before the first update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. the batch size it was averaged over)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Rebuild an RGB image from grayscale + ab channels and optionally save it.

    Args:
        grayscale_input: CPU tensor of shape (1, H, W), interpreted as Lab
            lightness divided by 100.
        ab_input: CPU tensor of shape (2, H, W) holding a/b channels in the
            ``(x + 128) / 255`` scaling produced by LabImageFolder.
        save_path: dict ``{'grayscale': dir, 'colorized': dir}``.
        save_name: file name (with extension) written into both directories.

    Returns:
        The reconstructed RGB image as an (H, W, 3) float64 numpy array.

    Files are written only when both ``save_path`` and ``save_name`` are given.
    Nothing is displayed.
    '''
    # torch.cat allocates a new tensor, so the in-place rescaling below never
    # modifies the caller's tensors. (3, H, W) -> (H, W, 3) for skimage/matplotlib.
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy()
    color_image = color_image.transpose((1, 2, 0))
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100         # L back to [0, 100]
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   # undo (x + 128) / 255
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None:
        # os.path.join works whether or not the directory string ends with a separator.
        plt.imsave(arr=grayscale_input, fname=os.path.join(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname=os.path.join(save_path['colorized'], save_name))
    return color_image

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    '''Evaluate ``model`` on ``val_loader`` and return batch-size-weighted metric averages.

    Args:
        val_loader: iterable of ``(gray, ab)`` batches.
        model: network mapping gray -> ab; switched to eval mode here.
        criterion: three metric callables in the order MSE, PSNR, SSIM, each
            called as ``metric(prediction, target)``.
        save_images: if true, up to 10 predictions from the first batch are
            written via ``to_rgb`` into ``gray_imgs`` / ``color_imgs``.
        epoch: epoch index; used in saved file names and, when >= 0, as the
            TensorBoard step for the "*/test" scalars.

    Returns:
        Tuple ``(mse_avg, psnr_avg, ssim_avg)``.
    '''
    meters = [AverageMeter(), AverageMeter(), AverageMeter()]
    save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
    max_saved = 10  # sample images written per validation pass

    model.eval()
    already_saved_images = False
    # no_grad here too, so validate() builds no autograd graphs even when it is
    # called outside the training loop's own torch.no_grad() block.
    with torch.no_grad():
        for gray, ab in val_loader:
            if use_gpu:
                gray, ab = gray.cuda(), ab.cuda()

            # The model returns only the predicted ab channels.
            output_ab = model(gray)
            n_batch = gray.size(0)
            # NOTE(review): calling a torchmetrics object also accumulates its
            # global state, which is never reset here; harmless for scalar-sum
            # states, but confirm for the installed SSIM version.
            for meter, metric in zip(meters, criterion):
                meter.update(metric(output_ab, ab).item(), n_batch)

            # Save sample colorizations from the first batch only.
            if save_images and not already_saved_images:
                already_saved_images = True
                for j in range(min(len(output_ab), max_saved)):
                    save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                    to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(),
                           save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {meters[0].val:.8f} ({meters[0].avg:.8f}), PSNR {meters[1].val:.8f} ({meters[1].avg:.8f}), SSIM {meters[2].val:.8f} ({meters[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        for tag, meter in zip(("MSE/test", "PSNR/test", "SSIM/test"), meters):
            writer.add_scalar(tag, meter.avg, epoch)
    return meters[0].avg, meters[1].avg, meters[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    '''Run one training epoch, optimizing MSE and logging MSE/PSNR/SSIM.

    Args:
        train_loader: iterable of ``(gray, ab)`` batches.
        model: network mapping gray -> ab; switched to train mode here.
        criterion: three metric callables in the order MSE, PSNR, SSIM; only
            ``criterion[0]`` (MSE) is back-propagated.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index; used in log messages and, when >= 0, as the
            TensorBoard step for the "*/train" scalars.
    '''
    print(f'Starting training epoch {epoch}')
    meters = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.train()

    for gray, ab in train_loader:
        if use_gpu:
            gray, ab = gray.cuda(), ab.cuda()

        optimizer.zero_grad()

        output_ab = model(gray)
        mse = criterion[0](output_ab, ab)
        # PSNR and SSIM are only monitored, never back-propagated: evaluate them on
        # a detached prediction so they do not build (and keep alive) their own
        # autograd graphs. Their values are unchanged.
        with torch.no_grad():
            prediction = output_ab.detach()
            psnr = criterion[1](prediction, ab)
            ssim = criterion[2](prediction, ab)

        mse.backward()
        optimizer.step()

        n_batch = gray.size(0)
        for meter, value in zip(meters, (mse, psnr, ssim)):
            meter.update(value.item(), n_batch)

    print(f'Epoch: {epoch}, MSE {meters[0].val:.8f} ({meters[0].avg:.8f}), PSNR {meters[1].val:.8f} ({meters[1].avg:.8f}), SSIM {meters[2].val:.8f} ({meters[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        for tag, meter in zip(("MSE/train", "PSNR/train", "SSIM/train"), meters):
            writer.add_scalar(tag, meter.avg, epoch)

In [19]:
# Train model: each epoch runs one training pass and one validation pass, and a
# checkpoint is saved whenever a validation metric reaches a new best value.
# PSNR and SSIM are higher-is-better, so their running best must start at -inf;
# a 1e10 sentinel could never be exceeded and would only suit MSE.
best_losses[1] = float('-inf')
best_losses[2] = float('-inf')

for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)

    # MSE: lower is better. It also drives early stopping through best_epoch.
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    # PSNR: higher is better, so an improvement is a strictly larger value.
    if losses[1] > best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    # SSIM: higher is better as well.
    if losses[2] > best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')

    # Early stop once validation MSE has not improved for `patience` epochs
    # (never before epoch 100).
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break

    # Always keep the final weights when the epoch budget runs out.
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00798380 (0.02834848), PSNR 20.97790337 (18.51072498), SSIM 0.62755382 (0.54083144)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00814617 (0.00813828), PSNR 20.89046478 (20.90227846), SSIM 0.63394707 (0.62463297)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00912327 (0.00836441), PSNR 20.39849472 (20.79145772), SSIM 0.64618057 (0.63137375)
Finished training epoch 1


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00733609 (0.00737797), PSNR 21.34535027 (21.33064982), SSIM 0.65420568 (0.64476258)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00751531 (0.00782307), PSNR 21.24053192 (21.07943420), SSIM 0.64886248 (0.64920109)
Finished training epoch 2
Validate: MSE 0.00708755 (0.00710014), PSNR 21.49503899 (21.49823015), SSIM 0.67103583 (0.66018918)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00742368 (0.00776491), PSNR 21.29380608 (21.11348326), SSIM 0.66258341 (0.65671030)
Finished training epoch 3
Validate: MSE 0.00721324 (0.00738051), PSNR 21.41869736 (21.33153319), SSIM 0.67281538 (0.66164296)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.00755763 (0.00754047), PSNR 21.21614456 (21.23421232), SSIM 0.66974527 (0.66291001)
Finished training epoch 4
Validate: MSE 0.00674413 (0.00685097), PSNR 21.71073723 (21.65419039), SSIM 0.68369764 (0.67194777)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.00777070 (0.007400

  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00699177 (0.00705377), PSNR 21.55412865 (21.52656889), SSIM 0.68589038 (0.67390404)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 0.00705608 (0.00736508), PSNR 21.51436234 (21.34338799), SSIM 0.68715841 (0.67178809)
Finished training epoch 6
Validate: MSE 0.00687723 (0.00701915), PSNR 21.62586212 (21.55214586), SSIM 0.69161135 (0.67938491)
Finished validation.
Starting training epoch 7
Epoch: 7, MSE 0.00671132 (0.00719062), PSNR 21.73192215 (21.44664050), SSIM 0.68864912 (0.67491968)
Finished training epoch 7
Validate: MSE 0.00694566 (0.00687312), PSNR 21.58286476 (21.64281308), SSIM 0.69486272 (0.68286768)
Finished validation.
Starting training epoch 8
Epoch: 8, MSE 0.00762915 (0.00722428), PSNR 21.17523766 (21.42363687), SSIM 0.67892802 (0.67678398)
Finished training epoch 8
Validate: MSE 0.00728968 (0.00711690), PSNR 21.37291527 (21.49059531), SSIM 0.69143319 (0.67959907)
Finished validation.
Starting training epoch 9
Epoch: 9, MSE 0.00725162 (0.007067

  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00690092 (0.00686970), PSNR 21.61092567 (21.64241328), SSIM 0.69677812 (0.68420461)
Finished validation.
Starting training epoch 10
Epoch: 10, MSE 0.00725279 (0.00721339), PSNR 21.39495087 (21.42878938), SSIM 0.68979609 (0.68055316)
Finished training epoch 10
Validate: MSE 0.00684791 (0.00684413), PSNR 21.64442062 (21.66190894), SSIM 0.69766736 (0.68492909)
Finished validation.
Starting training epoch 11
Epoch: 11, MSE 0.00726839 (0.00701679), PSNR 21.38561630 (21.54676778), SSIM 0.67129636 (0.68266315)
Finished training epoch 11


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00664153 (0.00666703), PSNR 21.77731895 (21.77384045), SSIM 0.70253986 (0.69032404)
Finished validation.
Starting training epoch 12
Epoch: 12, MSE 0.00717523 (0.00707456), PSNR 21.44163895 (21.51239973), SSIM 0.70021319 (0.68374247)
Finished training epoch 12
Validate: MSE 0.00674936 (0.00695128), PSNR 21.70737076 (21.59306008), SSIM 0.70039189 (0.68756724)
Finished validation.
Starting training epoch 13
Epoch: 13, MSE 0.00731472 (0.00693264), PSNR 21.35802078 (21.60054837), SSIM 0.67685473 (0.68654072)
Finished training epoch 13
Validate: MSE 0.00768052 (0.00755170), PSNR 21.14609337 (21.23116714), SSIM 0.69820052 (0.68616108)
Finished validation.
Starting training epoch 14
Epoch: 14, MSE 0.00685426 (0.00688569), PSNR 21.64039230 (21.62844244), SSIM 0.67042297 (0.68862510)
Finished training epoch 14


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00656977 (0.00664804), PSNR 21.82449913 (21.78884408), SSIM 0.70742244 (0.69482195)
Finished validation.
Starting training epoch 15
Epoch: 15, MSE 0.00665102 (0.00685772), PSNR 21.77112007 (21.64787458), SSIM 0.69039309 (0.68921169)
Finished training epoch 15
Validate: MSE 0.00686973 (0.00678241), PSNR 21.63059998 (21.69989953), SSIM 0.70748633 (0.69568683)
Finished validation.
Starting training epoch 16
Epoch: 16, MSE 0.00680219 (0.00676693), PSNR 21.67351151 (21.70645870), SSIM 0.69540912 (0.69071894)
Finished training epoch 16
Validate: MSE 0.00646408 (0.00654055), PSNR 21.89493370 (21.85712459), SSIM 0.70849878 (0.69613552)
Finished validation.
Starting training epoch 17
Epoch: 17, MSE 0.00692817 (0.00676897), PSNR 21.59381485 (21.70941504), SSIM 0.69177836 (0.69098593)
Finished training epoch 17
Validate: MSE 0.00663943 (0.00660415), PSNR 21.77869225 (21.81465737), SSIM 0.70224446 (0.69078265)
Finished validation.
Starting training epoch 18
Epoch: 18, MSE 0.0069090

Validate: MSE 0.00615609 (0.00638302), PSNR 22.10694695 (21.96545481), SSIM 0.71671033 (0.70552650)
Finished validation.
Starting training epoch 45
Epoch: 45, MSE 0.00561972 (0.00619756), PSNR 22.50285149 (22.08709114), SSIM 0.72173977 (0.70258527)
Finished training epoch 45
Validate: MSE 0.00642562 (0.00667858), PSNR 21.92084694 (21.77276499), SSIM 0.72036058 (0.70671199)
Finished validation.
Starting training epoch 46
Epoch: 46, MSE 0.00602386 (0.00616939), PSNR 22.20125008 (22.10512337), SSIM 0.70389420 (0.70306751)
Finished training epoch 46
Validate: MSE 0.00651207 (0.00653645), PSNR 21.86280823 (21.86369526), SSIM 0.70954239 (0.69835498)
Finished validation.
Starting training epoch 47
Epoch: 47, MSE 0.00674263 (0.00612435), PSNR 21.71170425 (22.13995028), SSIM 0.70209807 (0.70297594)
Finished training epoch 47
Validate: MSE 0.00646242 (0.00660909), PSNR 21.89604759 (21.81137854), SSIM 0.71052450 (0.69941726)
Finished validation.
Starting training epoch 48
Epoch: 48, MSE 0.0057379

  return func(*args, **kwargs)


Validate: MSE 0.00659392 (0.00683543), PSNR 21.80855942 (21.66553223), SSIM 0.71778917 (0.70482601)
Finished validation.
Starting training epoch 52
Epoch: 52, MSE 0.00629021 (0.00608733), PSNR 22.01334572 (22.16726451), SSIM 0.68520492 (0.70401501)
Finished training epoch 52
Validate: MSE 0.00616767 (0.00624163), PSNR 22.09878922 (22.06196235), SSIM 0.71180439 (0.70152650)
Finished validation.
Starting training epoch 53
Epoch: 53, MSE 0.00605739 (0.00605541), PSNR 22.17714119 (22.19005097), SSIM 0.71823961 (0.70449438)
Finished training epoch 53
Validate: MSE 0.00631333 (0.00641997), PSNR 21.99741745 (21.94163316), SSIM 0.72086960 (0.70934394)
Finished validation.
Starting training epoch 54
Epoch: 54, MSE 0.00544876 (0.00606994), PSNR 22.63702583 (22.17866275), SSIM 0.69761688 (0.70457084)
Finished training epoch 54
Validate: MSE 0.00656765 (0.00723443), PSNR 21.82589722 (21.42107682), SSIM 0.70992404 (0.69770856)
Finished validation.
Starting training epoch 55
Epoch: 55, MSE 0.0054498

Validate: MSE 0.00738855 (0.00757283), PSNR 21.31440735 (21.22513203), SSIM 0.70386690 (0.69437144)
Finished validation.
Starting training epoch 82
Epoch: 82, MSE 0.00581011 (0.00562945), PSNR 22.35815239 (22.50690530), SSIM 0.69543904 (0.71029261)
Finished training epoch 82
Validate: MSE 0.00663309 (0.00641377), PSNR 21.78283882 (21.93457520), SSIM 0.70901000 (0.70297695)
Finished validation.
Starting training epoch 83
Epoch: 83, MSE 0.00551029 (0.00568948), PSNR 22.58825111 (22.45803405), SSIM 0.70295823 (0.71029999)
Finished training epoch 83


  return func(*args, **kwargs)


Validate: MSE 0.00821648 (0.00761217), PSNR 20.85314369 (21.19549848), SSIM 0.69603878 (0.69072509)
Finished validation.
Starting training epoch 84
Epoch: 84, MSE 0.00574931 (0.00564256), PSNR 22.40384293 (22.49852321), SSIM 0.71946245 (0.71055241)
Finished training epoch 84
Validate: MSE 0.00580010 (0.00583075), PSNR 22.36564255 (22.35933468), SSIM 0.72930765 (0.71673324)
Finished validation.
Starting training epoch 85
Epoch: 85, MSE 0.00525738 (0.00562122), PSNR 22.79230690 (22.50996848), SSIM 0.72298789 (0.71052507)
Finished training epoch 85
Validate: MSE 0.00578367 (0.00568274), PSNR 22.37796211 (22.46733151), SSIM 0.72180676 (0.71181185)
Finished validation.
Starting training epoch 86
Epoch: 86, MSE 0.00553345 (0.00558702), PSNR 22.57003593 (22.53504432), SSIM 0.71687013 (0.71091294)
Finished training epoch 86
Validate: MSE 0.00587742 (0.00602436), PSNR 22.30813026 (22.21181608), SSIM 0.71475893 (0.70468666)
Finished validation.
Starting training epoch 87
Epoch: 87, MSE 0.0059767

  return func(*args, **kwargs)


Validate: MSE 0.00703604 (0.00699987), PSNR 21.52671242 (21.55587969), SSIM 0.69989616 (0.69403579)
Finished validation.
Starting training epoch 88
Epoch: 88, MSE 0.00539724 (0.00560478), PSNR 22.67828178 (22.52392110), SSIM 0.72638500 (0.71092943)
Finished training epoch 88


  return func(*args, **kwargs)


Validate: MSE 0.00675238 (0.00657770), PSNR 21.70542908 (21.83402639), SSIM 0.71562672 (0.70717357)
Finished validation.
Starting training epoch 89
Epoch: 89, MSE 0.00589630 (0.00559827), PSNR 22.29420471 (22.53158519), SSIM 0.69657737 (0.71140375)
Finished training epoch 89
Validate: MSE 0.00560728 (0.00602669), PSNR 22.51247406 (22.21045306), SSIM 0.71960431 (0.70672257)
Finished validation.
Starting training epoch 90
Epoch: 90, MSE 0.00550745 (0.00555074), PSNR 22.59049034 (22.56574460), SSIM 0.72260898 (0.71150492)
Finished training epoch 90
Validate: MSE 0.00557019 (0.00571423), PSNR 22.54129601 (22.43879702), SSIM 0.71993047 (0.71053860)
Finished validation.
Starting training epoch 91
Epoch: 91, MSE 0.00631505 (0.00553085), PSNR 21.99622726 (22.58296781), SSIM 0.70816737 (0.71185939)
Finished training epoch 91
Validate: MSE 0.00583096 (0.00594393), PSNR 22.34259415 (22.27763029), SSIM 0.71817386 (0.70713236)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.0053060

  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00615373 (0.00590983), PSNR 22.10861397 (22.29424497), SSIM 0.71303701 (0.70444591)
Finished validation.
Starting training epoch 94
Epoch: 94, MSE 0.00548538 (0.00552488), PSNR 22.60793304 (22.58702524), SSIM 0.72299343 (0.71179395)
Finished training epoch 94
Validate: MSE 0.00656695 (0.00661403), PSNR 21.82636261 (21.81499998), SSIM 0.70187908 (0.69419556)
Finished validation.
Starting training epoch 95
Epoch: 95, MSE 0.00536633 (0.00551400), PSNR 22.70322037 (22.59599485), SSIM 0.70234209 (0.71185531)
Finished training epoch 95
Validate: MSE 0.00540862 (0.00550980), PSNR 22.66913605 (22.59896452), SSIM 0.72555608 (0.71461182)
Finished validation.
Starting training epoch 96
Epoch: 96, MSE 0.00500327 (0.00550953), PSNR 23.00745964 (22.59889787), SSIM 0.71401709 (0.71181925)
Finished training epoch 96
Validate: MSE 0.00558740 (0.00569355), PSNR 22.52790260 (22.45614913), SSIM 0.72483450 (0.71561255)
Finished validation.
Starting training epoch 97
Epoch: 97, MSE 0.0054047

  return func(*args, **kwargs)


Validate: MSE 0.00602007 (0.00582753), PSNR 22.20398521 (22.35756388), SSIM 0.71846145 (0.71093889)
Finished validation.
Starting training epoch 98
Epoch: 98, MSE 0.00529297 (0.00547075), PSNR 22.76300049 (22.62946016), SSIM 0.72982657 (0.71188660)
Finished training epoch 98


  return func(*args, **kwargs)


Validate: MSE 0.00644243 (0.00626119), PSNR 21.90950012 (22.03960206), SSIM 0.70975357 (0.70246366)
Finished validation.
Starting training epoch 99
Epoch: 99, MSE 0.00562750 (0.00543266), PSNR 22.49684334 (22.65871676), SSIM 0.70517170 (0.71256741)
Finished training epoch 99
Validate: MSE 0.00599759 (0.00646978), PSNR 22.22023201 (21.90553986), SSIM 0.71447235 (0.70363340)
Finished validation.
Starting training epoch 100
Epoch: 100, MSE 0.00560797 (0.00543486), PSNR 22.51194000 (22.65944770), SSIM 0.72005028 (0.71214472)
Finished training epoch 100
Validate: MSE 0.00594578 (0.00614601), PSNR 22.25790977 (22.13020234), SSIM 0.71844476 (0.70868390)
Finished validation.
Starting training epoch 101
Epoch: 101, MSE 0.00574875 (0.00543528), PSNR 22.40426445 (22.65962070), SSIM 0.70723140 (0.71223204)
Finished training epoch 101
Validate: MSE 0.00554227 (0.00594479), PSNR 22.56312180 (22.27039719), SSIM 0.72051233 (0.70542057)
Finished validation.
Starting training epoch 102
Epoch: 102, MSE 0

  return func(*args, **kwargs)


Validate: MSE 0.00599885 (0.00583201), PSNR 22.21931839 (22.35380008), SSIM 0.72081220 (0.71066695)
Finished validation.
Starting training epoch 104
Epoch: 104, MSE 0.00563547 (0.00540802), PSNR 22.49069595 (22.68030187), SSIM 0.72310060 (0.71265561)
Finished training epoch 104


  return func(*args, **kwargs)


Validate: MSE 0.00559997 (0.00568025), PSNR 22.51814079 (22.47215936), SSIM 0.72196960 (0.71278756)
Finished validation.
Starting training epoch 105
Epoch: 105, MSE 0.00524511 (0.00539839), PSNR 22.80245781 (22.68693535), SSIM 0.71432364 (0.71268964)
Finished training epoch 105
Validate: MSE 0.00544197 (0.00581596), PSNR 22.64243698 (22.36822215), SSIM 0.72280246 (0.70977243)
Finished validation.
Starting training epoch 106
Epoch: 106, MSE 0.00607415 (0.00542149), PSNR 22.16514778 (22.66834301), SSIM 0.70079058 (0.71257806)
Finished training epoch 106
Validate: MSE 0.00563996 (0.00564756), PSNR 22.48724174 (22.49141118), SSIM 0.72318447 (0.71365700)
Finished validation.
Starting training epoch 107
Epoch: 107, MSE 0.00516067 (0.00541045), PSNR 22.87293625 (22.67818805), SSIM 0.70862365 (0.71257210)
Finished training epoch 107
Validate: MSE 0.00573253 (0.00575562), PSNR 22.41653252 (22.40716138), SSIM 0.71902835 (0.70896627)
Finished validation.
Starting training epoch 108
Epoch: 108, MS

  return func(*args, **kwargs)


Validate: MSE 0.00595752 (0.00589621), PSNR 22.24934196 (22.29829941), SSIM 0.71701294 (0.70840382)
Finished validation.
Starting training epoch 110
Epoch: 110, MSE 0.00544351 (0.00534186), PSNR 22.64120674 (22.73847585), SSIM 0.70616490 (0.71288236)
Finished training epoch 110
Validate: MSE 0.00561965 (0.00565815), PSNR 22.50290680 (22.47818578), SSIM 0.71855962 (0.70992165)
Finished validation.
Starting training epoch 111
Epoch: 111, MSE 0.00551163 (0.00534923), PSNR 22.58720016 (22.72729110), SSIM 0.70863813 (0.71319463)
Finished training epoch 111
Validate: MSE 0.00644572 (0.00590323), PSNR 21.90728378 (22.30104907), SSIM 0.71394145 (0.70768628)
Finished validation.
Starting training epoch 112
Epoch: 112, MSE 0.00502017 (0.00529692), PSNR 22.99281311 (22.77116929), SSIM 0.73448539 (0.71327576)
Finished training epoch 112
Validate: MSE 0.00593509 (0.00595787), PSNR 22.26572418 (22.25981081), SSIM 0.72308302 (0.71477559)
Finished validation.
Starting training epoch 113
Epoch: 113, MS

  return func(*args, **kwargs)


Validate: MSE 0.00612889 (0.00607549), PSNR 22.12617874 (22.17293721), SSIM 0.71768391 (0.70785041)
Finished validation.
Starting training epoch 119
Epoch: 119, MSE 0.00506423 (0.00529894), PSNR 22.95486450 (22.76774696), SSIM 0.71242589 (0.71325597)
Finished training epoch 119
Validate: MSE 0.00553682 (0.00549218), PSNR 22.56739616 (22.61658496), SSIM 0.72985411 (0.72011301)
Finished validation.
Starting training epoch 120
Epoch: 120, MSE 0.00505546 (0.00529475), PSNR 22.96239281 (22.77131573), SSIM 0.71661127 (0.71283018)
Finished training epoch 120
Validate: MSE 0.00580177 (0.00577255), PSNR 22.36439323 (22.39733431), SSIM 0.71862072 (0.70691874)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 0.00485288 (0.00527766), PSNR 23.13999939 (22.78618993), SSIM 0.71386212 (0.71335656)
Finished training epoch 121
Validate: MSE 0.00595315 (0.00575163), PSNR 22.25253296 (22.41206195), SSIM 0.72318310 (0.71428388)
Finished validation.
Starting training epoch 122
Epoch: 122, MS

  return func(*args, **kwargs)


Validate: MSE 0.00605411 (0.00590989), PSNR 22.17949486 (22.30588836), SSIM 0.72071481 (0.71233436)
Finished validation.
Starting training epoch 128
Epoch: 128, MSE 0.00494395 (0.00528413), PSNR 23.05925560 (22.78043746), SSIM 0.73218852 (0.71288809)
Finished training epoch 128
Validate: MSE 0.00553781 (0.00562147), PSNR 22.56661987 (22.51513093), SSIM 0.72480100 (0.71124997)
Finished validation.
Starting training epoch 129
Epoch: 129, MSE 0.00573439 (0.00521994), PSNR 22.41513062 (22.83233949), SSIM 0.70741427 (0.71357988)
Finished training epoch 129
Validate: MSE 0.00597614 (0.00602167), PSNR 22.23579025 (22.21478450), SSIM 0.70459980 (0.69817243)
Finished validation.
Starting training epoch 130
Epoch: 130, MSE 0.00600562 (0.00519222), PSNR 22.21441841 (22.85816413), SSIM 0.69065100 (0.71317197)
Finished training epoch 130


  return func(*args, **kwargs)


Validate: MSE 0.00655714 (0.00632710), PSNR 21.83285141 (22.00241915), SSIM 0.71107703 (0.70344919)
Finished validation.
Starting training epoch 131
Epoch: 131, MSE 0.00504999 (0.00517671), PSNR 22.96709633 (22.87083787), SSIM 0.71106124 (0.71350349)
Finished training epoch 131
Validate: MSE 0.00661154 (0.00629068), PSNR 21.79697418 (22.02023183), SSIM 0.70769364 (0.70124823)
Finished validation.
Starting training epoch 132
Epoch: 132, MSE 0.00524041 (0.00520935), PSNR 22.80634880 (22.84428974), SSIM 0.71705872 (0.71325128)
Finished training epoch 132
Validate: MSE 0.00567947 (0.00579883), PSNR 22.45692444 (22.37596326), SSIM 0.71883357 (0.71181237)
Finished validation.
Starting training epoch 133
Epoch: 133, MSE 0.00558606 (0.00518244), PSNR 22.52894592 (22.86690415), SSIM 0.70925742 (0.71346438)
Finished training epoch 133
Validate: MSE 0.00565024 (0.00575708), PSNR 22.47933197 (22.40625775), SSIM 0.72452807 (0.71403287)
Finished validation.
Starting training epoch 134
Epoch: 134, MS

  return func(*args, **kwargs)


Validate: MSE 0.00534135 (0.00548138), PSNR 22.72349167 (22.62083409), SSIM 0.72351646 (0.71186043)
Finished validation.
Starting training epoch 140
Epoch: 140, MSE 0.00451643 (0.00510782), PSNR 23.45204353 (22.93022238), SSIM 0.70903462 (0.71410943)
Finished training epoch 140
Validate: MSE 0.00627203 (0.00593871), PSNR 22.02592087 (22.26805380), SSIM 0.71682674 (0.70855086)
Finished validation.
Starting training epoch 141
Epoch: 141, MSE 0.00529484 (0.00514105), PSNR 22.76147270 (22.90028074), SSIM 0.71326518 (0.71313012)
Finished training epoch 141
Validate: MSE 0.00551110 (0.00540224), PSNR 22.58761215 (22.68600452), SSIM 0.72138566 (0.71121370)
Finished validation.
Starting training epoch 142
Epoch: 142, MSE 0.00585177 (0.00506340), PSNR 22.32712746 (22.96524891), SSIM 0.70608586 (0.71368195)
Finished training epoch 142


  return func(*args, **kwargs)


Validate: MSE 0.00584672 (0.00596937), PSNR 22.33087730 (22.24628471), SSIM 0.71380574 (0.70528518)
Finished validation.
Starting training epoch 143
Epoch: 143, MSE 0.00604661 (0.00515800), PSNR 22.18488121 (22.88623533), SSIM 0.69301337 (0.71323240)
Finished training epoch 143
Validate: MSE 0.00603710 (0.00589168), PSNR 22.19171715 (22.31156092), SSIM 0.71806723 (0.70842945)
Finished validation.
Starting training epoch 144
Epoch: 144, MSE 0.00547448 (0.00511286), PSNR 22.61656952 (22.92356308), SSIM 0.72040248 (0.71327326)
Finished training epoch 144
Validate: MSE 0.00578987 (0.00556045), PSNR 22.37330818 (22.55822183), SSIM 0.72460747 (0.71459599)
Finished validation.
Starting training epoch 145
Epoch: 145, MSE 0.00568978 (0.00510733), PSNR 22.44904137 (22.92837284), SSIM 0.71389711 (0.71338301)
Finished training epoch 145
Validate: MSE 0.00558174 (0.00565497), PSNR 22.53229904 (22.48466121), SSIM 0.72256386 (0.71107315)
Finished validation.
Starting training epoch 146
Epoch: 146, MS

  return func(*args, **kwargs)


Validate: MSE 0.00601563 (0.00578119), PSNR 22.20718384 (22.39646075), SSIM 0.71460968 (0.70513800)
Finished validation.
Starting training epoch 161
Epoch: 161, MSE 0.00483711 (0.00504148), PSNR 23.15413475 (22.98276028), SSIM 0.72739923 (0.71342976)
Finished training epoch 161
Validate: MSE 0.00552408 (0.00546462), PSNR 22.57739449 (22.63540019), SSIM 0.72605127 (0.71400883)
Finished validation.
Starting training epoch 162
Epoch: 162, MSE 0.00507716 (0.00499760), PSNR 22.94379234 (23.02267793), SSIM 0.71127099 (0.71382287)
Finished training epoch 162
Validate: MSE 0.00524683 (0.00532572), PSNR 22.80103302 (22.74318306), SSIM 0.72032988 (0.70932589)
Finished validation.
Starting training epoch 163
Epoch: 163, MSE 0.00552106 (0.00494755), PSNR 22.57977295 (23.06809382), SSIM 0.69973844 (0.71385491)
Finished training epoch 163
Validate: MSE 0.00626224 (0.00590697), PSNR 22.03269958 (22.29046462), SSIM 0.70987988 (0.70075487)
Finished validation.
Starting training epoch 164
Epoch: 164, MS

  return func(*args, **kwargs)


Validate: MSE 0.00579968 (0.00555319), PSNR 22.36595535 (22.56536566), SSIM 0.71770871 (0.70954439)
Finished validation.
Starting training epoch 186
Epoch: 186, MSE 0.00475508 (0.00483453), PSNR 23.22841835 (23.16596341), SSIM 0.71665508 (0.71406910)
Finished training epoch 186
Validate: MSE 0.00577209 (0.00564326), PSNR 22.38666916 (22.49769090), SSIM 0.71974617 (0.71202115)
Finished validation.
Starting training epoch 187
Epoch: 187, MSE 0.00410083 (0.00487885), PSNR 23.87128258 (23.12673379), SSIM 0.73099571 (0.71393811)
Finished training epoch 187
Validate: MSE 0.00576545 (0.00579700), PSNR 22.39166832 (22.38441743), SSIM 0.72053039 (0.71086182)
Finished validation.
Starting training epoch 188
Epoch: 188, MSE 0.00628558 (0.00483194), PSNR 22.01654434 (23.17212415), SSIM 0.70119309 (0.71466452)
Finished training epoch 188
Validate: MSE 0.00570187 (0.00552311), PSNR 22.43982506 (22.58762506), SSIM 0.71858335 (0.70922443)
Finished validation.
Starting training epoch 189
Epoch: 189, MS

  return func(*args, **kwargs)


Validate: MSE 0.00565497 (0.00562912), PSNR 22.47569466 (22.50533714), SSIM 0.71993065 (0.70831514)
Finished validation.
Starting training epoch 208
Epoch: 208, MSE 0.00516202 (0.00481305), PSNR 22.87180328 (23.19028195), SSIM 0.70974356 (0.71394235)
Finished training epoch 208
Validate: MSE 0.00592966 (0.00589407), PSNR 22.26969910 (22.31187290), SSIM 0.71550483 (0.70658748)
Finished validation.
Starting training epoch 209
Epoch: 209, MSE 0.00447156 (0.00475862), PSNR 23.49541092 (23.23407139), SSIM 0.72609085 (0.71411248)
Finished training epoch 209
Validate: MSE 0.00534430 (0.00523468), PSNR 22.72109032 (22.82273467), SSIM 0.71975648 (0.71029715)
Finished validation.
Starting training epoch 210
Epoch: 210, MSE 0.00438355 (0.00477599), PSNR 23.58173752 (23.21929383), SSIM 0.71671814 (0.71423564)
Finished training epoch 210


  return func(*args, **kwargs)


Validate: MSE 0.00701593 (0.00653145), PSNR 21.53914452 (21.86178587), SSIM 0.69648242 (0.68993047)
Finished validation.
Starting training epoch 211
Epoch: 211, MSE 0.00462247 (0.00476862), PSNR 23.35125542 (23.22697816), SSIM 0.73216617 (0.71417022)
Finished training epoch 211
Validate: MSE 0.00520974 (0.00543351), PSNR 22.83183479 (22.66260555), SSIM 0.72828287 (0.71597534)
Finished validation.
Starting training epoch 212
Epoch: 212, MSE 0.00490119 (0.00474650), PSNR 23.09698868 (23.24612335), SSIM 0.73035586 (0.71461359)
Finished training epoch 212
Validate: MSE 0.00527669 (0.00540125), PSNR 22.77638054 (22.68302219), SSIM 0.71804452 (0.70916855)
Finished validation.
Starting training epoch 213
Epoch: 213, MSE 0.00466858 (0.00472009), PSNR 23.30815315 (23.27186627), SSIM 0.71072692 (0.71434838)
Finished training epoch 213
Validate: MSE 0.00543653 (0.00526083), PSNR 22.64678192 (22.79852825), SSIM 0.71895242 (0.70889992)
Finished validation.
Starting training epoch 214
Epoch: 214, MS

  return func(*args, **kwargs)


Validate: MSE 0.00559314 (0.00529723), PSNR 22.52344131 (22.77128355), SSIM 0.71729356 (0.70799379)
Finished validation.
Starting training epoch 219
Epoch: 219, MSE 0.00528982 (0.00473289), PSNR 22.76559258 (23.26004026), SSIM 0.71743798 (0.71412821)
Finished training epoch 219


  return func(*args, **kwargs)


Validate: MSE 0.00564162 (0.00543538), PSNR 22.48595810 (22.66040127), SSIM 0.71487141 (0.70637450)
Finished validation.
Starting training epoch 220
Epoch: 220, MSE 0.00482352 (0.00472918), PSNR 23.16635895 (23.26430561), SSIM 0.71693975 (0.71430153)
Finished training epoch 220
Validate: MSE 0.00521626 (0.00527854), PSNR 22.82640648 (22.78202461), SSIM 0.72321945 (0.71096930)
Finished validation.
Starting training epoch 221
Epoch: 221, MSE 0.00460095 (0.00470299), PSNR 23.37152481 (23.28595294), SSIM 0.71073121 (0.71467368)
Finished training epoch 221


  return func(*args, **kwargs)


Validate: MSE 0.00515165 (0.00527450), PSNR 22.88053322 (22.78768304), SSIM 0.72460854 (0.71214341)
Finished validation.
Starting training epoch 222
Epoch: 222, MSE 0.00500656 (0.00469352), PSNR 23.00460625 (23.29215326), SSIM 0.70926422 (0.71450509)
Finished training epoch 222
Validate: MSE 0.00592639 (0.00561903), PSNR 22.27210045 (22.50910813), SSIM 0.71341848 (0.70359042)
Finished validation.
Starting training epoch 223
Epoch: 223, MSE 0.00493817 (0.00469201), PSNR 23.06433678 (23.29811936), SSIM 0.70959127 (0.71418314)
Finished training epoch 223
Validate: MSE 0.00513268 (0.00514066), PSNR 22.89655685 (22.90321260), SSIM 0.72644967 (0.71619999)
Finished validation.
Starting training epoch 224
Epoch: 224, MSE 0.00500126 (0.00474354), PSNR 23.00920105 (23.24994770), SSIM 0.70997345 (0.71432640)
Finished training epoch 224
Validate: MSE 0.00545459 (0.00528022), PSNR 22.63237381 (22.78241684), SSIM 0.72727877 (0.71852507)
Finished validation.
Starting training epoch 225
Epoch: 225, MS

  return func(*args, **kwargs)


Validate: MSE 0.00574282 (0.00561931), PSNR 22.40875053 (22.51677381), SSIM 0.71380949 (0.70328516)
Finished validation.
Starting training epoch 231
Epoch: 231, MSE 0.00484009 (0.00470749), PSNR 23.15146255 (23.28208664), SSIM 0.70377833 (0.71420151)
Finished training epoch 231
Validate: MSE 0.00544581 (0.00547029), PSNR 22.63937759 (22.62887784), SSIM 0.72485352 (0.71264658)
Finished validation.
Starting training epoch 232
Epoch: 232, MSE 0.00454532 (0.00466418), PSNR 23.42435074 (23.31981378), SSIM 0.71274418 (0.71443824)
Finished training epoch 232
Validate: MSE 0.00546440 (0.00525015), PSNR 22.62457848 (22.80958932), SSIM 0.72137433 (0.71152633)
Finished validation.
Starting training epoch 233
Epoch: 233, MSE 0.00419828 (0.00465622), PSNR 23.76927948 (23.33180367), SSIM 0.73247379 (0.71452200)
Finished training epoch 233
Validate: MSE 0.00606890 (0.00605407), PSNR 22.16890335 (22.18377717), SSIM 0.71036094 (0.70120345)
Finished validation.
Starting training epoch 234
Epoch: 234, MS

<Figure size 432x288 with 0 Axes>

In [20]:
# Save the current model weights as the "last" checkpoint. The file name embeds the
# three entries of `losses`, each formatted with 8 decimal places.
torch.save(
    model.state_dict(),
    checkpoints + '/last-' + '-'.join(format(losses[k], '.8f') for k in range(3)) + '.pth',
)

In [21]:
# Final validation pass with image saving switched on. Gradient tracking is disabled
# for the duration of the call; the trailing -1 is presumably an epoch index marking a
# run outside the training loop (TODO confirm against validate()'s signature).
save_images = True
with torch.set_grad_enabled(False):
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00552353 (0.00528050), PSNR 22.57783127 (22.78507738), SSIM 0.71978343 (0.70994222)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()