In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [4]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [5]:
SIZE = 32
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(), 
                transforms.ToTensor(),
                transforms.Normalize((0.4918019, 0.48263696, 0.44733968), (0.24705184, 0.2433789, 0.26162848)),
                transforms.ToPILImage()
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE),
                transforms.ToTensor(),
                transforms.Normalize((0.4918019, 0.48263696, 0.44733968), (0.24705184, 0.2433789, 0.26162848)),
                transforms.ToPILImage()
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [6]:
# Training
batch_size = 128
train_imagefolder = LabImageFolder('../../datasets/cifar10/train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder('../../datasets/cifar10/val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [7]:
kernel_size=3
stride_en=2
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'


class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        
        self.convtrans1 = nn.ConvTranspose2d(64, 32, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(32, 16, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(16, 8, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans4 = nn.ConvTranspose2d(8, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm8 = nn.BatchNorm2d(8)
        self.batchnorm16 = nn.BatchNorm2d(16)
        self.batchnorm32 = nn.BatchNorm2d(32)
        self.batchnorm64 = nn.BatchNorm2d(64)
        
        
    def forward(self, input):
        # encoder
        x = F.relu(self.batchnorm16(self.conv1(input)))
        x = F.relu(self.batchnorm32(self.conv2(x)))
        x = F.relu(self.batchnorm64(self.conv3(x)))
        
        # decoder
        x = F.relu(self.batchnorm32(self.convtrans1(x)))
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.relu(self.batchnorm16(self.convtrans2(x)))
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.relu(self.batchnorm8(self.convtrans3(x)))
        x = F.interpolate(self.convtrans4(x), scale_factor=scale_factor)

        return x

In [8]:
model = Autoencoder()

In [9]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [11]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [12]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             160
       BatchNorm2d-2           [-1, 16, 16, 16]              32
            Conv2d-3             [-1, 32, 8, 8]           4,640
       BatchNorm2d-4             [-1, 32, 8, 8]              64
            Conv2d-5             [-1, 64, 4, 4]          18,496
       BatchNorm2d-6             [-1, 64, 4, 4]             128
   ConvTranspose2d-7             [-1, 32, 4, 4]          18,464
       BatchNorm2d-8             [-1, 32, 4, 4]              64
   ConvTranspose2d-9             [-1, 16, 8, 8]           4,624
      BatchNorm2d-10             [-1, 16, 8, 8]              32
  ConvTranspose2d-11            [-1, 8, 16, 16]           1,160
      BatchNorm2d-12            [-1, 8, 16, 16]              16
  ConvTranspose2d-13            [-1, 2, 16, 16]             146
Total params: 48,026
Trainable params: 

In [13]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [14]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'

In [15]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:        
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [16]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [17]:
# Make folders and set parameters
checkpoints = 'checkpoints'
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500

In [18]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.02296321 (0.03213845), PSNR 16.38967514 (15.63661413), SSIM 0.08838975 (0.08002797)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02577635 (0.02721111), PSNR 15.88778591 (15.65782379), SSIM 0.05900122 (0.08049460)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.02303727 (0.02416200), PSNR 16.37569046 (16.17018074), SSIM 0.10474984 (0.09548833)
Finished training epoch 1
Validate: MSE 0.02636596 (0.02764731), PSNR 15.78956318 (15.58915251), SSIM 0.05732761 (0.08201012)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.02377292 (0.02386727), PSNR 16.23917389 (16.22324094), SSIM 0.09601790 (0.10199876)
Finished training epoch 2
Validate: MSE 0.02602686 (0.02734291), PSNR 15.84578228 (15.63724310), SSIM 0.05749888 (0.08103151)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.02499909 (0.02362285), PSNR 16.02075768 (16.26788302), SSIM 0.11226132 (0.10886607)
Finished training epoch 3
Validate: MSE 0.02647737 (0.02779053), PSNR 15.77124977 (15.56665451), SSIM 0.06021140 (0.08090390)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.02271152 (0.023461

  return func(*args, **kwargs)


Validate: MSE 0.02902504 (0.03047368), PSNR 15.37227058 (15.16410019), SSIM 0.06600967 (0.07165848)
Finished validation.
Starting training epoch 28
Epoch: 28, MSE 0.02180263 (0.02178687), PSNR 16.61491203 (16.61927848), SSIM 0.20906714 (0.20615296)
Finished training epoch 28


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.03194495 (0.03333779), PSNR 14.95597649 (14.77348051), SSIM 0.05447592 (0.06390880)
Finished validation.
Starting training epoch 29
Epoch: 29, MSE 0.02148290 (0.02176656), PSNR 16.67906952 (16.62341918), SSIM 0.22996943 (0.20707053)
Finished training epoch 29


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.03577066 (0.03762915), PSNR 14.46472931 (14.24710272), SSIM 0.05690526 (0.06650925)
Finished validation.
Starting training epoch 30
Epoch: 30, MSE 0.02189891 (0.02175281), PSNR 16.59577560 (16.62595023), SSIM 0.21111903 (0.20750393)
Finished training epoch 30


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.03283965 (0.03461972), PSNR 14.83601379 (14.60932981), SSIM 0.05890835 (0.06719450)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.02137048 (0.02172793), PSNR 16.70185661 (16.63088001), SSIM 0.21260500 (0.20830315)
Finished training epoch 31


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06229769 (0.06474145), PSNR 12.05528069 (11.88942542), SSIM 0.06148652 (0.06810257)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.02230737 (0.02169733), PSNR 16.51551437 (16.63705117), SSIM 0.20859821 (0.20911570)
Finished training epoch 32


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.04656847 (0.04870580), PSNR 13.31907940 (13.12587366), SSIM 0.06219275 (0.07064142)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.02228877 (0.02169524), PSNR 16.51913834 (16.63745532), SSIM 0.20305243 (0.20944072)
Finished training epoch 33


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.05206514 (0.05423606), PSNR 12.83452892 (12.65844577), SSIM 0.05441045 (0.06148927)
Finished validation.
Starting training epoch 34
Epoch: 34, MSE 0.02145355 (0.02167917), PSNR 16.68500710 (16.64082905), SSIM 0.21258171 (0.20993583)
Finished training epoch 34


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.05206527 (0.05412078), PSNR 12.83451843 (12.66770542), SSIM 0.05577087 (0.06582712)
Finished validation.
Starting training epoch 35
Epoch: 35, MSE 0.02201724 (0.02165733), PSNR 16.57237053 (16.64507375), SSIM 0.19677106 (0.21060926)
Finished training epoch 35


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.04943369 (0.05170355), PSNR 13.05976963 (12.86632395), SSIM 0.06453466 (0.07287499)
Finished validation.
Starting training epoch 36
Epoch: 36, MSE 0.02162805 (0.02165514), PSNR 16.64982605 (16.64554918), SSIM 0.21571901 (0.21063187)
Finished training epoch 36


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.07595317 (0.07857755), PSNR 11.19454098 (11.04796492), SSIM 0.05618010 (0.06419631)
Finished validation.
Starting training epoch 37
Epoch: 37, MSE 0.02120622 (0.02163664), PSNR 16.73536491 (16.64926109), SSIM 0.21058910 (0.21131597)
Finished training epoch 37


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.07319042 (0.07591916), PSNR 11.35545731 (11.19748212), SSIM 0.05662475 (0.06235569)
Finished validation.
Starting training epoch 38
Epoch: 38, MSE 0.02150553 (0.02162015), PSNR 16.67449760 (16.65257553), SSIM 0.22060160 (0.21156559)
Finished training epoch 38


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.05811544 (0.06054928), PSNR 12.35708427 (12.18011744), SSIM 0.05746503 (0.06389285)
Finished validation.
Starting training epoch 39
Epoch: 39, MSE 0.02274364 (0.02161164), PSNR 16.43139839 (16.65432341), SSIM 0.21264747 (0.21189165)
Finished training epoch 39


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06713283 (0.07012769), PSNR 11.73064995 (11.54218253), SSIM 0.06802955 (0.06617098)
Finished validation.
Starting training epoch 40
Epoch: 40, MSE 0.02136628 (0.02160785), PSNR 16.70270920 (16.65496807), SSIM 0.20604356 (0.21195109)
Finished training epoch 40


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.07573870 (0.07809987), PSNR 11.20682144 (11.07441971), SSIM 0.05084012 (0.05547778)
Finished validation.
Starting training epoch 41
Epoch: 41, MSE 0.02309768 (0.02159397), PSNR 16.36431503 (16.65785160), SSIM 0.21123648 (0.21245527)
Finished training epoch 41


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.09487887 (0.09747487), PSNR 10.22830486 (10.11184203), SSIM 0.05941420 (0.06516623)
Finished validation.
Starting training epoch 42
Epoch: 42, MSE 0.02030238 (0.02160473), PSNR 16.92453003 (16.65582573), SSIM 0.21451071 (0.21222066)
Finished training epoch 42


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06980397 (0.07237999), PSNR 11.56119919 (11.40479637), SSIM 0.06457946 (0.06702186)
Finished validation.
Starting training epoch 43
Epoch: 43, MSE 0.02264879 (0.02158749), PSNR 16.44954872 (16.65923409), SSIM 0.20728111 (0.21260796)
Finished training epoch 43


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06783709 (0.07031244), PSNR 11.68532753 (11.53070109), SSIM 0.05492622 (0.05942199)
Finished validation.
Starting training epoch 44
Epoch: 44, MSE 0.02176495 (0.02160331), PSNR 16.62242317 (16.65614935), SSIM 0.20638239 (0.21232381)
Finished training epoch 44


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.05335092 (0.05559487), PSNR 12.72858047 (12.55100381), SSIM 0.06797072 (0.06957464)
Finished validation.
Starting training epoch 45
Epoch: 45, MSE 0.02221427 (0.02156390), PSNR 16.53367805 (16.66397057), SSIM 0.20432273 (0.21311403)
Finished training epoch 45


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.08372941 (0.08579346), PSNR 10.77121925 (10.66631155), SSIM 0.05982638 (0.06177848)
Finished validation.
Starting training epoch 46
Epoch: 46, MSE 0.02136158 (0.02157268), PSNR 16.70366478 (16.66231117), SSIM 0.21377745 (0.21294960)
Finished training epoch 46


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06332123 (0.06606729), PSNR 11.98450565 (11.80124850), SSIM 0.06442861 (0.06360338)
Finished validation.
Starting training epoch 47
Epoch: 47, MSE 0.02249391 (0.02155743), PSNR 16.47935104 (16.66528709), SSIM 0.21244986 (0.21314905)
Finished training epoch 47


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.08601670 (0.08794626), PSNR 10.65417194 (10.55859681), SSIM 0.05640973 (0.05681605)
Finished validation.
Starting training epoch 48
Epoch: 48, MSE 0.02093522 (0.02154981), PSNR 16.79122543 (16.66662180), SSIM 0.21526022 (0.21351865)
Finished training epoch 48


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.07257536 (0.07492012), PSNR 11.39210796 (11.25491973), SSIM 0.05926995 (0.05807974)
Finished validation.
Starting training epoch 49
Epoch: 49, MSE 0.02175934 (0.02155210), PSNR 16.62354088 (16.66629283), SSIM 0.21972713 (0.21335171)
Finished training epoch 49


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06209818 (0.06446897), PSNR 12.06921101 (11.90756685), SSIM 0.07160815 (0.06603606)
Finished validation.
Starting training epoch 50
Epoch: 50, MSE 0.02188122 (0.02153564), PSNR 16.59928513 (16.66960761), SSIM 0.21806729 (0.21376426)
Finished training epoch 50


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06146006 (0.06379455), PSNR 12.11406898 (11.95323207), SSIM 0.06247158 (0.05848121)
Finished validation.
Starting training epoch 51
Epoch: 51, MSE 0.02101708 (0.02154234), PSNR 16.77427483 (16.66817258), SSIM 0.20917884 (0.21379438)
Finished training epoch 51


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.07670142 (0.07895828), PSNR 11.15196514 (11.02692707), SSIM 0.05503855 (0.05506257)
Finished validation.
Starting training epoch 52
Epoch: 52, MSE 0.02241582 (0.02153796), PSNR 16.49445152 (16.66906349), SSIM 0.20852847 (0.21386982)
Finished training epoch 52


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.05843434 (0.06061635), PSNR 12.33331871 (12.17528222), SSIM 0.07024387 (0.06775566)
Finished validation.
Starting training epoch 53
Epoch: 53, MSE 0.02239886 (0.02152914), PSNR 16.49774170 (16.67094223), SSIM 0.19658743 (0.21395905)
Finished training epoch 53


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06006768 (0.06242184), PSNR 12.21359158 (12.04765277), SSIM 0.05669860 (0.05534262)
Finished validation.
Starting training epoch 54
Epoch: 54, MSE 0.02029816 (0.02151747), PSNR 16.92543221 (16.67331308), SSIM 0.21828265 (0.21426536)
Finished training epoch 54


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06914373 (0.07157573), PSNR 11.60247135 (11.45326476), SSIM 0.06332284 (0.05569111)
Finished validation.
Starting training epoch 55
Epoch: 55, MSE 0.02022731 (0.02151578), PSNR 16.94061852 (16.67372625), SSIM 0.22825353 (0.21429145)
Finished training epoch 55


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.07677383 (0.07913303), PSNR 11.14786720 (11.01725380), SSIM 0.06197778 (0.05797596)
Finished validation.
Starting training epoch 56
Epoch: 56, MSE 0.02127771 (0.02151278), PSNR 16.72075081 (16.67415587), SSIM 0.20675275 (0.21439647)
Finished training epoch 56


  return func(*args, **kwargs)


Validate: MSE 0.05442329 (0.05661988), PSNR 12.64215183 (12.47146997), SSIM 0.06341606 (0.05797265)
Finished validation.
Starting training epoch 57
Epoch: 57, MSE 0.02251576 (0.02151263), PSNR 16.47513390 (16.67428967), SSIM 0.20113197 (0.21453909)
Finished training epoch 57


  return func(*args, **kwargs)


Validate: MSE 0.05651628 (0.05946521), PSNR 12.47826385 (12.25842671), SSIM 0.06434094 (0.05905219)
Finished validation.
Starting training epoch 58
Epoch: 58, MSE 0.02161378 (0.02151646), PSNR 16.65269279 (16.67350961), SSIM 0.22744524 (0.21436879)
Finished training epoch 58


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06455753 (0.06742218), PSNR 11.90053082 (11.71283241), SSIM 0.06393441 (0.05834963)
Finished validation.
Starting training epoch 59
Epoch: 59, MSE 0.02195504 (0.02150364), PSNR 16.58465576 (16.67598281), SSIM 0.20991436 (0.21450668)
Finished training epoch 59


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.05428514 (0.05657092), PSNR 12.65318871 (12.47504369), SSIM 0.05346554 (0.05154267)
Finished validation.
Starting training epoch 60
Epoch: 60, MSE 0.02070709 (0.02149291), PSNR 16.83880997 (16.67832418), SSIM 0.20824757 (0.21477239)
Finished training epoch 60
Validate: MSE 0.05195842 (0.05492006), PSNR 12.84344101 (12.60377670), SSIM 0.06140968 (0.05554355)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.02157236 (0.02149556), PSNR 16.66102409 (16.67774591), SSIM 0.20609005 (0.21479862)
Finished training epoch 61


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.06694514 (0.07010063), PSNR 11.74280930 (11.54354114), SSIM 0.05697448 (0.05593866)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.02189090 (0.02149201), PSNR 16.59736252 (16.67828171), SSIM 0.21083418 (0.21486531)
Finished training epoch 62


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.08790126 (0.09051277), PSNR 10.56004906 (10.43354481), SSIM 0.04878916 (0.05056742)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.02098606 (0.02150044), PSNR 16.78068924 (16.67672721), SSIM 0.22896294 (0.21488185)
Finished training epoch 63


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.07486358 (0.07784533), PSNR 11.25729370 (11.08836121), SSIM 0.05542227 (0.05677809)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [19]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [20]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.07486358 (0.07784533), PSNR 11.25729370 (11.08836121), SSIM 0.05542227 (0.05677809)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [21]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()