In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .5

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 16, 16]           2,306
Total params: 2,662,274
Trainable params: 2,662,274
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.75
Params size (MB): 10.16
Estimated Total Size (MB): 16.91
-------------------------------------

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.05675825 (1.00301149), PSNR 12.45970917 (6.74976383), SSIM 0.09060211 (0.03589147)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.05483253 (0.05225776), PSNR 12.60961723 (13.08420031), SSIM 0.09811215 (0.11079417)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.02652887 (0.03688454), PSNR 15.76281166 (14.43590253), SSIM 0.23030713 (0.21386879)
Finished training epoch 1


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02643820 (0.02509392), PSNR 15.77768040 (16.31297426), SSIM 0.25158435 (0.30609975)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.01379932 (0.01880001), PSNR 18.60142326 (17.32543476), SSIM 0.43525547 (0.36884350)
Finished training epoch 2


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01649961 (0.01397742), PSNR 17.82526398 (18.80790520), SSIM 0.33753020 (0.42684442)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00791664 (0.01072990), PSNR 21.01459312 (19.74350535), SSIM 0.50914496 (0.46477444)
Finished training epoch 3


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00863572 (0.00812210), PSNR 20.63701630 (21.09883583), SSIM 0.41971225 (0.50821948)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.00605220 (0.00679374), PSNR 22.18087006 (21.71056166), SSIM 0.57183427 (0.53613522)
Finished training epoch 4


  return func(*args, **kwargs)


Validate: MSE 0.00620105 (0.00542437), PSNR 22.07534599 (22.77328869), SSIM 0.48311591 (0.57481653)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.00347606 (0.00479096), PSNR 24.58913040 (23.21969713), SSIM 0.62304890 (0.59582549)
Finished training epoch 5


  return func(*args, **kwargs)


Validate: MSE 0.00491850 (0.00423600), PSNR 23.08166695 (23.81183794), SSIM 0.52991074 (0.61844974)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 0.00347048 (0.00377302), PSNR 24.59609985 (24.24992032), SSIM 0.68088698 (0.64471634)
Finished training epoch 6


  return func(*args, **kwargs)


Validate: MSE 0.00394245 (0.00345360), PSNR 24.04233742 (24.68487995), SSIM 0.58294851 (0.66440150)
Finished validation.
Starting training epoch 7
Epoch: 7, MSE 0.00266156 (0.00325686), PSNR 25.74864197 (24.89071874), SSIM 0.69991845 (0.68103303)
Finished training epoch 7


  return func(*args, **kwargs)


Validate: MSE 0.00361849 (0.00306340), PSNR 24.41472435 (25.19629222), SSIM 0.61546385 (0.69841643)
Finished validation.
Starting training epoch 8
Epoch: 8, MSE 0.00292058 (0.00299408), PSNR 25.34531021 (25.25815550), SSIM 0.71094257 (0.70732131)
Finished training epoch 8
Validate: MSE 0.00338600 (0.00291742), PSNR 24.70313263 (25.41549366), SSIM 0.64021641 (0.71733366)
Finished validation.
Starting training epoch 9
Epoch: 9, MSE 0.00310421 (0.00287091), PSNR 25.08049583 (25.44062768), SSIM 0.72396213 (0.72664850)
Finished training epoch 9
Validate: MSE 0.00323184 (0.00282706), PSNR 24.90550613 (25.55651500), SSIM 0.66317689 (0.73470078)
Finished validation.
Starting training epoch 10
Epoch: 10, MSE 0.00216200 (0.00282097), PSNR 26.65143967 (25.51439193), SSIM 0.75441432 (0.74107134)
Finished training epoch 10
Validate: MSE 0.00329384 (0.00276413), PSNR 24.82297325 (25.65084862), SSIM 0.67692071 (0.74869692)
Finished validation.
Starting training epoch 11
Epoch: 11, MSE 0.00269127 (0.0

Validate: MSE 0.00294451 (0.00255697), PSNR 25.30986977 (25.98901967), SSIM 0.70500422 (0.77157472)
Finished validation.
Starting training epoch 38
Epoch: 38, MSE 0.00282857 (0.00241859), PSNR 25.48432541 (26.18962403), SSIM 0.72720915 (0.77122363)
Finished training epoch 38
Validate: MSE 0.00264978 (0.00241442), PSNR 25.76789474 (26.22721105), SSIM 0.70210218 (0.76376586)
Finished validation.
Starting training epoch 39
Epoch: 39, MSE 0.00246529 (0.00240192), PSNR 26.08131981 (26.21715522), SSIM 0.78466046 (0.77082598)
Finished training epoch 39
Validate: MSE 0.00324685 (0.00316224), PSNR 24.88537788 (25.08667813), SSIM 0.69731665 (0.75420414)
Finished validation.
Starting training epoch 40
Epoch: 40, MSE 0.00214257 (0.00240515), PSNR 26.69065666 (26.20792957), SSIM 0.76975441 (0.77165527)
Finished training epoch 40
Validate: MSE 0.00276210 (0.00240247), PSNR 25.58760834 (26.25829703), SSIM 0.70831436 (0.77330469)
Finished validation.
Starting training epoch 41
Epoch: 41, MSE 0.0021913

Validate: MSE 0.00274575 (0.00228862), PSNR 25.61339378 (26.46991375), SSIM 0.71503603 (0.77678566)
Finished validation.
Starting training epoch 68
Epoch: 68, MSE 0.00254398 (0.00229974), PSNR 25.94485855 (26.40662061), SSIM 0.76680851 (0.77313820)
Finished training epoch 68
Validate: MSE 0.00257482 (0.00246991), PSNR 25.89253235 (26.14188271), SSIM 0.70697486 (0.76959016)
Finished validation.
Starting training epoch 69
Epoch: 69, MSE 0.00277970 (0.00229146), PSNR 25.56001091 (26.42380879), SSIM 0.75622970 (0.77336237)
Finished training epoch 69
Validate: MSE 0.00268577 (0.00241185), PSNR 25.70931435 (26.22360109), SSIM 0.70400232 (0.76762572)
Finished validation.
Starting training epoch 70
Epoch: 70, MSE 0.00264675 (0.00228140), PSNR 25.77287102 (26.43967831), SSIM 0.76105064 (0.77349316)
Finished training epoch 70
Validate: MSE 0.00262487 (0.00233661), PSNR 25.80892372 (26.37624176), SSIM 0.70936513 (0.77279798)
Finished validation.
Starting training epoch 71
Epoch: 71, MSE 0.0027732

Validate: MSE 0.00249736 (0.00230403), PSNR 26.02517891 (26.43149627), SSIM 0.70929170 (0.76967903)
Finished validation.
Starting training epoch 98
Epoch: 98, MSE 0.00245353 (0.00222658), PSNR 26.10209084 (26.54602585), SSIM 0.76282179 (0.77296989)
Finished training epoch 98
Validate: MSE 0.00275845 (0.00233269), PSNR 25.59334946 (26.38448173), SSIM 0.70867395 (0.76542351)
Finished validation.
Starting training epoch 99
Epoch: 99, MSE 0.00291647 (0.00222799), PSNR 25.35142899 (26.54163703), SSIM 0.74186009 (0.77273536)
Finished training epoch 99
Validate: MSE 0.00284717 (0.00234745), PSNR 25.45586205 (26.35658797), SSIM 0.70455039 (0.76915674)
Finished validation.
Starting training epoch 100
Epoch: 100, MSE 0.00190735 (0.00222088), PSNR 27.19570541 (26.55793535), SSIM 0.80439556 (0.77295436)
Finished training epoch 100
Validate: MSE 0.00278541 (0.00228254), PSNR 25.55111313 (26.47845856), SSIM 0.70910442 (0.77424890)
Finished validation.
Starting training epoch 101
Epoch: 101, MSE 0.00

Epoch: 127, MSE 0.00176092 (0.00219398), PSNR 27.54260826 (26.61037206), SSIM 0.76388234 (0.77144336)
Finished training epoch 127
Validate: MSE 0.00269751 (0.00246539), PSNR 25.69037247 (26.16015206), SSIM 0.70959318 (0.76879728)
Finished validation.
Starting training epoch 128
Epoch: 128, MSE 0.00192404 (0.00218510), PSNR 27.15785789 (26.62858299), SSIM 0.77897203 (0.77241944)
Finished training epoch 128
Validate: MSE 0.00283169 (0.00227104), PSNR 25.47954178 (26.50123633), SSIM 0.70613742 (0.77365964)
Finished validation.
Starting training epoch 129
Epoch: 129, MSE 0.00155138 (0.00218903), PSNR 28.09282112 (26.62126862), SSIM 0.79892522 (0.77187222)
Finished training epoch 129
Validate: MSE 0.00282737 (0.00225376), PSNR 25.48617554 (26.53649307), SSIM 0.70528078 (0.77274943)
Finished validation.
Starting training epoch 130
Epoch: 130, MSE 0.00199436 (0.00218511), PSNR 27.00197411 (26.62970444), SSIM 0.76464164 (0.77218095)
Finished training epoch 130
Validate: MSE 0.00286253 (0.00259

Validate: MSE 0.00287328 (0.00228979), PSNR 25.41622543 (26.45999716), SSIM 0.70298207 (0.76677933)
Finished validation.
Starting training epoch 157
Epoch: 157, MSE 0.00221350 (0.00218099), PSNR 26.54920197 (26.63688230), SSIM 0.77699471 (0.77052610)
Finished training epoch 157
Validate: MSE 0.00256952 (0.00230184), PSNR 25.90147209 (26.44361135), SSIM 0.71090388 (0.77210734)
Finished validation.
Starting training epoch 158
Epoch: 158, MSE 0.00165139 (0.00216349), PSNR 27.82149696 (26.66980932), SSIM 0.79717219 (0.77100568)
Finished training epoch 158
Validate: MSE 0.00264313 (0.00238157), PSNR 25.77880669 (26.29362756), SSIM 0.71404111 (0.76627089)
Finished validation.
Starting training epoch 159
Epoch: 159, MSE 0.00207706 (0.00214759), PSNR 26.82551765 (26.70187122), SSIM 0.77083379 (0.77190727)
Finished training epoch 159
Validate: MSE 0.00268848 (0.00223845), PSNR 25.70492744 (26.55897820), SSIM 0.71060610 (0.77364394)
Finished validation.
Starting training epoch 160
Epoch: 160, MS

Epoch: 186, MSE 0.00202234 (0.00213622), PSNR 26.94145584 (26.72725761), SSIM 0.78681523 (0.77137185)
Finished training epoch 186
Validate: MSE 0.00286217 (0.00241791), PSNR 25.43303871 (26.23289335), SSIM 0.70880926 (0.76828812)
Finished validation.
Starting training epoch 187
Epoch: 187, MSE 0.00180399 (0.00213481), PSNR 27.43766403 (26.72704483), SSIM 0.80172586 (0.77126331)
Finished training epoch 187
Validate: MSE 0.00262658 (0.00224283), PSNR 25.80608749 (26.54833339), SSIM 0.71557450 (0.77149707)
Finished validation.
Starting training epoch 188
Epoch: 188, MSE 0.00214806 (0.00213423), PSNR 26.67953300 (26.73172245), SSIM 0.78715694 (0.77142919)
Finished training epoch 188
Validate: MSE 0.00255876 (0.00229072), PSNR 25.91969490 (26.45216907), SSIM 0.70770639 (0.77048412)
Finished validation.
Starting training epoch 189
Epoch: 189, MSE 0.00231151 (0.00213021), PSNR 26.36103630 (26.73428542), SSIM 0.75869304 (0.77141118)
Finished training epoch 189
Validate: MSE 0.00275064 (0.00226

Validate: MSE 0.00298354 (0.00237302), PSNR 25.25267410 (26.30967276), SSIM 0.69955480 (0.76153605)
Finished validation.
Starting training epoch 216
Epoch: 216, MSE 0.00233150 (0.00213406), PSNR 26.32364655 (26.72971761), SSIM 0.76050842 (0.77050710)
Finished training epoch 216
Validate: MSE 0.00250968 (0.00230820), PSNR 26.00381470 (26.42193132), SSIM 0.71296608 (0.76467439)
Finished validation.
Starting training epoch 217
Epoch: 217, MSE 0.00210929 (0.00210402), PSNR 26.75864601 (26.79077559), SSIM 0.77030408 (0.77141125)
Finished training epoch 217
Validate: MSE 0.00268900 (0.00230076), PSNR 25.70408249 (26.43819660), SSIM 0.69935924 (0.76229114)
Finished validation.
Starting training epoch 218
Epoch: 218, MSE 0.00182187 (0.00211443), PSNR 27.39482498 (26.76800657), SSIM 0.78586859 (0.77102332)
Finished training epoch 218
Validate: MSE 0.00287287 (0.00229548), PSNR 25.41684532 (26.45741310), SSIM 0.70295972 (0.76661336)
Finished validation.
Starting training epoch 219
Epoch: 219, MS

Epoch: 245, MSE 0.00213501 (0.00208735), PSNR 26.70600510 (26.82418721), SSIM 0.78760028 (0.77137382)
Finished training epoch 245
Validate: MSE 0.00255289 (0.00227327), PSNR 25.92967415 (26.48781636), SSIM 0.71030551 (0.76517601)
Finished validation.
Starting training epoch 246
Epoch: 246, MSE 0.00215516 (0.00209440), PSNR 26.66519928 (26.81026450), SSIM 0.77577055 (0.77099874)
Finished training epoch 246
Validate: MSE 0.00265945 (0.00225468), PSNR 25.75208282 (26.52954000), SSIM 0.71195328 (0.76794630)
Finished validation.
Starting training epoch 247
Epoch: 247, MSE 0.00182766 (0.00208472), PSNR 27.38103294 (26.82869618), SSIM 0.79321748 (0.77098961)
Finished training epoch 247
Validate: MSE 0.00285371 (0.00232407), PSNR 25.44589615 (26.40463817), SSIM 0.71223575 (0.77300377)
Finished validation.
Starting training epoch 248
Epoch: 248, MSE 0.00199330 (0.00211045), PSNR 27.00426292 (26.77958617), SSIM 0.76950204 (0.76962815)
Finished training epoch 248
Validate: MSE 0.00260772 (0.00228

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00265156 (0.00230324), PSNR 25.76498795 (26.42936661), SSIM 0.70446002 (0.75928113)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()