In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

Reusing TensorBoard on port 6007 (pid 16107), started 0:58:03 ago. (Use '!kill 16107' to kill it.)

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .5

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 16, 16]           2,306
Total params: 2,662,274
Trainable params: 2,662,274
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.75
Params size (MB): 10.16
Estimated Total Size (MB): 16.91
-------------------------------------

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.10759954 (1.94897345), PSNR 9.68189526 (4.60876074), SSIM 0.04124395 (0.01382493)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.08022149 (0.10228657), PSNR 10.95709229 (10.25020947), SSIM 0.04317306 (0.04190793)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.05792886 (0.07710596), PSNR 12.37104988 (11.19899808), SSIM 0.22418387 (0.10879513)
Finished training epoch 1


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.04575937 (0.05302155), PSNR 13.39519882 (13.33253447), SSIM 0.17119032 (0.19321194)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.03758120 (0.04658081), PSNR 14.25029278 (13.35549822), SSIM 0.26839513 (0.24917847)
Finished training epoch 2


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02889898 (0.03515750), PSNR 15.39117336 (15.13252701), SSIM 0.26680142 (0.31588141)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.02387550 (0.03008388), PSNR 16.22047615 (15.25247860), SSIM 0.37600407 (0.35217189)
Finished training epoch 3


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02044395 (0.02583954), PSNR 16.89435196 (16.44143798), SSIM 0.32853615 (0.39421866)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.01617739 (0.01966981), PSNR 17.91091537 (17.09455260), SSIM 0.44457665 (0.42424622)
Finished training epoch 4


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01329490 (0.01622793), PSNR 18.76314926 (18.37991976), SSIM 0.37941021 (0.45821365)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.01018897 (0.01316150), PSNR 19.91869736 (18.83759282), SSIM 0.51549351 (0.48363659)
Finished training epoch 5


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01043332 (0.01172813), PSNR 19.81577492 (19.72911078), SSIM 0.42386913 (0.51201537)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 0.00719334 (0.00902564), PSNR 21.43069077 (20.47273463), SSIM 0.56424636 (0.53341954)
Finished training epoch 6


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00666613 (0.00745771), PSNR 21.76125908 (21.57483108), SSIM 0.48077494 (0.56020327)
Finished validation.
Starting training epoch 7
Epoch: 7, MSE 0.00561851 (0.00644206), PSNR 22.50378990 (21.93128272), SSIM 0.61309904 (0.57453763)
Finished training epoch 7


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00538449 (0.00558908), PSNR 22.68855095 (22.74398469), SSIM 0.51247001 (0.59357520)
Finished validation.
Starting training epoch 8
Epoch: 8, MSE 0.00456383 (0.00486607), PSNR 23.40670204 (23.14674838), SSIM 0.60853499 (0.60982206)
Finished training epoch 8
Validate: MSE 0.00421030 (0.00427802), PSNR 23.75687027 (23.82043450), SSIM 0.55014777 (0.62997597)
Finished validation.
Starting training epoch 9
Epoch: 9, MSE 0.00364085 (0.00391332), PSNR 24.38796806 (24.09305133), SSIM 0.64753616 (0.63998219)
Finished training epoch 9
Validate: MSE 0.00367949 (0.00353233), PSNR 24.34212112 (24.60717479), SSIM 0.58982980 (0.65993816)
Finished validation.
Starting training epoch 10
Epoch: 10, MSE 0.00304835 (0.00336919), PSNR 25.15934753 (24.74219407), SSIM 0.68892640 (0.66651481)
Finished training epoch 10


  return func(*args, **kwargs)


Validate: MSE 0.00340592 (0.00318801), PSNR 24.67765999 (25.03963680), SSIM 0.61779535 (0.67845076)
Finished validation.
Starting training epoch 11
Epoch: 11, MSE 0.00362838 (0.00305677), PSNR 24.40287590 (25.16863363), SSIM 0.68508917 (0.68931587)
Finished training epoch 11
Validate: MSE 0.00323373 (0.00296695), PSNR 24.90295601 (25.35515921), SSIM 0.62923193 (0.69834541)
Finished validation.
Starting training epoch 12
Epoch: 12, MSE 0.00240246 (0.00287992), PSNR 26.19342995 (25.42750319), SSIM 0.72012317 (0.70880786)
Finished training epoch 12
Validate: MSE 0.00328374 (0.00280938), PSNR 24.83631325 (25.58000973), SSIM 0.66949695 (0.72753991)
Finished validation.
Starting training epoch 13
Epoch: 13, MSE 0.00304349 (0.00277911), PSNR 25.16628075 (25.58328638), SSIM 0.70826352 (0.72402055)
Finished training epoch 13
Validate: MSE 0.00310060 (0.00272669), PSNR 25.08554268 (25.71196626), SSIM 0.67489541 (0.73570372)
Finished validation.
Starting training epoch 14
Epoch: 14, MSE 0.0027776

Validate: MSE 0.00278861 (0.00254577), PSNR 25.54611397 (25.99720335), SSIM 0.70798618 (0.76874453)
Finished validation.
Starting training epoch 41
Epoch: 41, MSE 0.00254789 (0.00242760), PSNR 25.93819618 (26.17405444), SSIM 0.80405998 (0.77241193)
Finished training epoch 41
Validate: MSE 0.00333477 (0.00283557), PSNR 24.76933479 (25.56643369), SSIM 0.70461470 (0.77075737)
Finished validation.
Starting training epoch 42
Epoch: 42, MSE 0.00240373 (0.00243386), PSNR 26.19114685 (26.15970129), SSIM 0.76522654 (0.77203588)
Finished training epoch 42
Validate: MSE 0.00267361 (0.00270297), PSNR 25.72901726 (25.78658418), SSIM 0.71007848 (0.76876568)
Finished validation.
Starting training epoch 43
Epoch: 43, MSE 0.00206798 (0.00242114), PSNR 26.84453392 (26.18248067), SSIM 0.79171526 (0.77258563)
Finished training epoch 43
Validate: MSE 0.00283778 (0.00243823), PSNR 25.47021866 (26.19715743), SSIM 0.70848644 (0.77144750)
Finished validation.
Starting training epoch 44
Epoch: 44, MSE 0.0019875

Validate: MSE 0.00300668 (0.00257494), PSNR 25.21912766 (25.96698466), SSIM 0.69902563 (0.76141212)
Finished validation.
Starting training epoch 71
Epoch: 71, MSE 0.00265454 (0.00233879), PSNR 25.76011467 (26.33114209), SSIM 0.77537072 (0.77330700)
Finished training epoch 71
Validate: MSE 0.00267297 (0.00258390), PSNR 25.73005295 (25.95570375), SSIM 0.71061313 (0.76378130)
Finished validation.
Starting training epoch 72
Epoch: 72, MSE 0.00239365 (0.00235506), PSNR 26.20940018 (26.30525733), SSIM 0.77584523 (0.77237587)
Finished training epoch 72
Validate: MSE 0.00291660 (0.00239632), PSNR 25.35122681 (26.27226289), SSIM 0.71134913 (0.77619179)
Finished validation.
Starting training epoch 73
Epoch: 73, MSE 0.00260398 (0.00234046), PSNR 25.84362984 (26.32930080), SSIM 0.75710428 (0.77323523)
Finished training epoch 73
Validate: MSE 0.00264735 (0.00255414), PSNR 25.77189255 (26.00217802), SSIM 0.70595753 (0.76343972)
Finished validation.
Starting training epoch 74
Epoch: 74, MSE 0.0021774

Validate: MSE 0.00255395 (0.00241746), PSNR 25.92787933 (26.23929469), SSIM 0.71099412 (0.76879286)
Finished validation.
Starting training epoch 101
Epoch: 101, MSE 0.00274619 (0.00230085), PSNR 25.61269951 (26.40474591), SSIM 0.75528467 (0.77301441)
Finished training epoch 101
Validate: MSE 0.00271367 (0.00248174), PSNR 25.66442108 (26.11692482), SSIM 0.70568693 (0.76712684)
Finished validation.
Starting training epoch 102
Epoch: 102, MSE 0.00225968 (0.00230365), PSNR 26.45953178 (26.40064864), SSIM 0.77655977 (0.77280927)
Finished training epoch 102
Validate: MSE 0.00269833 (0.00239623), PSNR 25.68904495 (26.26849534), SSIM 0.70918393 (0.76994517)
Finished validation.
Starting training epoch 103
Epoch: 103, MSE 0.00194333 (0.00229672), PSNR 27.11454010 (26.41176201), SSIM 0.77936059 (0.77295697)
Finished training epoch 103
Validate: MSE 0.00275627 (0.00233881), PSNR 25.59679031 (26.36933755), SSIM 0.70214462 (0.76732682)
Finished validation.
Starting training epoch 104
Epoch: 104, MS

Epoch: 130, MSE 0.00183569 (0.00227259), PSNR 27.36200333 (26.45867663), SSIM 0.78263104 (0.77298757)
Finished training epoch 130
Validate: MSE 0.00259608 (0.00228766), PSNR 25.85682297 (26.46456035), SSIM 0.71547490 (0.77169670)
Finished validation.
Starting training epoch 131
Epoch: 131, MSE 0.00213751 (0.00227256), PSNR 26.70091057 (26.45722573), SSIM 0.77699649 (0.77304531)
Finished training epoch 131
Validate: MSE 0.00268329 (0.00242876), PSNR 25.71331596 (26.21590881), SSIM 0.71229267 (0.77003959)
Finished validation.
Starting training epoch 132
Epoch: 132, MSE 0.00315544 (0.00227572), PSNR 25.00940704 (26.45517661), SSIM 0.73487478 (0.77250011)
Finished training epoch 132
Validate: MSE 0.00276524 (0.00232408), PSNR 25.58267593 (26.40059693), SSIM 0.70592988 (0.76668060)
Finished validation.
Starting training epoch 133
Epoch: 133, MSE 0.00210512 (0.00227346), PSNR 26.76723480 (26.45624802), SSIM 0.78432941 (0.77254835)
Finished training epoch 133
Validate: MSE 0.00273858 (0.00263

Validate: MSE 0.00263382 (0.00239825), PSNR 25.79413033 (26.27620010), SSIM 0.71283853 (0.76830169)
Finished validation.
Starting training epoch 160
Epoch: 160, MSE 0.00192126 (0.00225358), PSNR 27.16412926 (26.49370339), SSIM 0.79219145 (0.77303131)
Finished training epoch 160
Validate: MSE 0.00271795 (0.00256139), PSNR 25.65758324 (25.99603258), SSIM 0.70947188 (0.76779246)
Finished validation.
Starting training epoch 161
Epoch: 161, MSE 0.00185905 (0.00224810), PSNR 27.30709267 (26.50577743), SSIM 0.78711116 (0.77322347)
Finished training epoch 161
Validate: MSE 0.00295685 (0.00250854), PSNR 25.29170990 (26.07868866), SSIM 0.69724572 (0.76334793)
Finished validation.
Starting training epoch 162
Epoch: 162, MSE 0.00190459 (0.00225559), PSNR 27.20197868 (26.49058116), SSIM 0.77926290 (0.77276255)
Finished training epoch 162
Validate: MSE 0.00262342 (0.00236997), PSNR 25.81131935 (26.31712680), SSIM 0.71301949 (0.76959504)
Finished validation.
Starting training epoch 163
Epoch: 163, MS

Epoch: 189, MSE 0.00227174 (0.00225248), PSNR 26.43641472 (26.49398932), SSIM 0.76299202 (0.77198486)
Finished training epoch 189
Validate: MSE 0.00274842 (0.00241191), PSNR 25.60916901 (26.24144894), SSIM 0.71416330 (0.76960252)
Finished validation.
Starting training epoch 190
Epoch: 190, MSE 0.00165684 (0.00223800), PSNR 27.80718231 (26.52918338), SSIM 0.79032528 (0.77277008)
Finished training epoch 190
Validate: MSE 0.00281543 (0.00234308), PSNR 25.50454521 (26.36440344), SSIM 0.70740211 (0.76784636)
Finished validation.
Starting training epoch 191
Epoch: 191, MSE 0.00240871 (0.00223832), PSNR 26.18214417 (26.52393309), SSIM 0.76111209 (0.77275803)
Finished training epoch 191
Validate: MSE 0.00266307 (0.00227190), PSNR 25.74617386 (26.49549029), SSIM 0.71146810 (0.77067641)
Finished validation.
Starting training epoch 192
Epoch: 192, MSE 0.00211682 (0.00227563), PSNR 26.74315834 (26.45190016), SSIM 0.79146200 (0.77095285)
Finished training epoch 192
Validate: MSE 0.00272696 (0.00260

Validate: MSE 0.00275222 (0.00229225), PSNR 25.60316658 (26.46227076), SSIM 0.71000719 (0.77269068)
Finished validation.
Starting training epoch 219
Epoch: 219, MSE 0.00177084 (0.00223724), PSNR 27.51820183 (26.52278425), SSIM 0.79855841 (0.77227858)
Finished training epoch 219
Validate: MSE 0.00268945 (0.00225620), PSNR 25.70336151 (26.52258314), SSIM 0.71451342 (0.77451330)
Finished validation.
Starting training epoch 220
Epoch: 220, MSE 0.00179367 (0.00223387), PSNR 27.46256828 (26.53068815), SSIM 0.77320325 (0.77259651)
Finished training epoch 220
Validate: MSE 0.00310258 (0.00268066), PSNR 25.08276939 (25.79052833), SSIM 0.70907271 (0.76562541)
Finished validation.
Starting training epoch 221
Epoch: 221, MSE 0.00259240 (0.00222718), PSNR 25.86297226 (26.54471552), SSIM 0.75821054 (0.77262895)
Finished training epoch 221
Validate: MSE 0.00259271 (0.00230671), PSNR 25.86245346 (26.42641298), SSIM 0.70894837 (0.77258932)
Finished validation.
Starting training epoch 222
Epoch: 222, MS

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00278930 (0.00239648), PSNR 25.54504013 (26.27838015), SSIM 0.71262717 (0.77283016)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()