In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 64
p1 = .5

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        
        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base // 2)
        self.batchnorm2 = nn.BatchNorm2d(channels_base)
        self.batchnorm3 = nn.BatchNorm2d(channels_base * 2)
        
        self.dropout1 = nn.Dropout(p=p1)        
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm2(self.conv1(input)), negative_slope=.1)
        x = self.dropout1(x)
        x = y = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm3(self.conv2(x)), negative_slope=.1)
        x = self.dropout1(x)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm2(self.convtrans1(x)), negative_slope=.1)
        x = self.dropout1(x)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm1(self.convtrans2(x + y)), negative_slope=.1)
        x = self.dropout1(x)
        x = self.convtrans3(F.interpolate(x, scale_factor=scale_factor) + input)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]             640
       BatchNorm2d-2           [-1, 64, 32, 32]             128
           Dropout-3           [-1, 64, 32, 32]               0
            Conv2d-4          [-1, 128, 16, 16]          73,856
       BatchNorm2d-5          [-1, 128, 16, 16]             256
           Dropout-6          [-1, 128, 16, 16]               0
   ConvTranspose2d-7             [-1, 64, 8, 8]          73,792
       BatchNorm2d-8             [-1, 64, 8, 8]             128
           Dropout-9             [-1, 64, 8, 8]               0
  ConvTranspose2d-10           [-1, 32, 16, 16]          18,464
      BatchNorm2d-11           [-1, 32, 16, 16]              64
          Dropout-12           [-1, 32, 16, 16]               0
  ConvTranspose2d-13            [-1, 2, 32, 32]             578
Total params: 167,906
Trainable params:

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.01579515 (0.16102453), PSNR 18.01476288 (13.77749583), SSIM 0.21847618 (0.11140301)
Finished training epoch 0
Validate: MSE 0.00876218 (0.00640827), PSNR 20.57387924 (21.97239111), SSIM 0.37430716 (0.49390418)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00751342 (0.01071239), PSNR 21.24162292 (19.77268370), SSIM 0.53342354 (0.42452353)
Finished training epoch 1
Validate: MSE 0.00378541 (0.00304306), PSNR 24.21887016 (25.24605891), SSIM 0.68414509 (0.76413183)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00597144 (0.00670460), PSNR 22.23920631 (21.76585656), SSIM 0.57873458 (0.55396768)
Finished training epoch 2
Validate: MSE 0.00334013 (0.00281006), PSNR 24.76236725 (25.58433429), SSIM 0.69902939 (0.77490507)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00449400 (0.00499803), PSNR 23.47366714 (23.02723142), SSIM 0.61164105 (0.59628953)
Finished training epoch 3
Validate: MSE 0.00366516 (0.0

Epoch: 30, MSE 0.00289964 (0.00311555), PSNR 25.37655640 (25.17379223), SSIM 0.76591700 (0.74231376)
Finished training epoch 30
Validate: MSE 0.00334742 (0.00292694), PSNR 24.75289154 (25.41130697), SSIM 0.70418531 (0.77535697)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00361930 (0.00340949), PSNR 24.41375923 (24.83173038), SSIM 0.72683012 (0.72955913)
Finished training epoch 31
Validate: MSE 0.00357439 (0.00307057), PSNR 24.46797562 (25.19720572), SSIM 0.66544580 (0.74753260)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00317664 (0.00301087), PSNR 24.98032570 (25.28118221), SSIM 0.70798969 (0.74071470)
Finished training epoch 32
Validate: MSE 0.00351382 (0.00298796), PSNR 24.54220390 (25.32084614), SSIM 0.68931615 (0.75691785)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00230920 (0.00296705), PSNR 26.36538315 (25.33933406), SSIM 0.77667969 (0.74306791)
Finished training epoch 33
Validate: MSE 0.00318705 (0.00283896), PSNR 

  return func(*args, **kwargs)


Validate: MSE 0.00456892 (0.00485047), PSNR 23.40186119 (23.25197393), SSIM 0.63362682 (0.69772698)
Finished validation.
Starting training epoch 48
Epoch: 48, MSE 0.00303975 (0.00286211), PSNR 25.17162704 (25.47426622), SSIM 0.70813257 (0.71647102)
Finished training epoch 48
Validate: MSE 0.00343896 (0.00334391), PSNR 24.63572884 (24.87430264), SSIM 0.68107814 (0.76245477)
Finished validation.
Starting training epoch 49
Epoch: 49, MSE 0.00263400 (0.00278673), PSNR 25.79384804 (25.59798874), SSIM 0.71400452 (0.72087612)
Finished training epoch 49
Validate: MSE 0.00404687 (0.00430781), PSNR 23.92880440 (23.78978765), SSIM 0.67499977 (0.74104920)
Finished validation.
Starting training epoch 50
Epoch: 50, MSE 0.00317968 (0.00296350), PSNR 24.97616196 (25.35614650), SSIM 0.72358710 (0.70745784)
Finished training epoch 50
Validate: MSE 0.00389420 (0.00398417), PSNR 24.09581566 (24.16288121), SSIM 0.68919998 (0.76769037)
Finished validation.
Starting training epoch 51
Epoch: 51, MSE 0.0025428

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01040766 (0.00955157), PSNR 19.82646942 (20.23363876), SSIM 0.51055324 (0.58728012)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00281005 (0.00285706), PSNR 25.51285744 (25.50912763), SSIM 0.72569960 (0.71361512)
Finished training epoch 62
Validate: MSE 0.01043666 (0.01130649), PSNR 19.81438255 (19.57906034), SSIM 0.64285851 (0.72320781)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00353289 (0.00284981), PSNR 24.51870155 (25.49805151), SSIM 0.69105160 (0.71102785)
Finished training epoch 63
Validate: MSE 0.00796952 (0.00659340), PSNR 20.98567772 (21.87148614), SSIM 0.49372002 (0.58377812)
Finished validation.
Starting training epoch 64
Epoch: 64, MSE 0.00241955 (0.00267649), PSNR 26.16265488 (25.76253257), SSIM 0.72250104 (0.72490759)
Finished training epoch 64


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00529193 (0.00523775), PSNR 22.76385689 (22.88924182), SSIM 0.64666104 (0.71185264)
Finished validation.
Starting training epoch 65
Epoch: 65, MSE 0.00294622 (0.00281880), PSNR 25.30734253 (25.55088234), SSIM 0.70191562 (0.71209656)
Finished training epoch 65


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00671264 (0.00627027), PSNR 21.73106575 (22.08659934), SSIM 0.55701536 (0.63231635)
Finished validation.
Starting training epoch 66
Epoch: 66, MSE 0.00290985 (0.00284108), PSNR 25.36129379 (25.50705967), SSIM 0.72699690 (0.70806970)
Finished training epoch 66
Validate: MSE 0.00567411 (0.00599176), PSNR 22.46102333 (22.38149716), SSIM 0.62447810 (0.70482333)
Finished validation.
Starting training epoch 67
Epoch: 67, MSE 0.00412855 (0.00293311), PSNR 23.84202194 (25.40249501), SSIM 0.56877059 (0.70169991)
Finished training epoch 67
Validate: MSE 0.00783565 (0.00652175), PSNR 21.05925179 (21.86926026), SSIM 0.47660613 (0.56791509)
Finished validation.
Starting training epoch 68
Epoch: 68, MSE 0.00285481 (0.00277530), PSNR 25.44422913 (25.63575580), SSIM 0.71407282 (0.71640291)
Finished training epoch 68
Validate: MSE 0.00475708 (0.00474491), PSNR 23.22659111 (23.36603164), SSIM 0.68695146 (0.75980035)
Finished validation.
Starting training epoch 69
Epoch: 69, MSE 0.0021356

  return func(*args, **kwargs)


Validate: MSE 0.00555000 (0.00555412), PSNR 22.55707359 (22.65645133), SSIM 0.67723775 (0.74231193)
Finished validation.
Starting training epoch 72
Epoch: 72, MSE 0.00216674 (0.00269960), PSNR 26.64193726 (25.72409745), SSIM 0.73250854 (0.71708212)
Finished training epoch 72
Validate: MSE 0.00546608 (0.00582396), PSNR 22.62323952 (22.48317336), SSIM 0.67367852 (0.74321613)
Finished validation.
Starting training epoch 73
Epoch: 73, MSE 0.00289580 (0.00266690), PSNR 25.38231850 (25.77491496), SSIM 0.74372286 (0.72184561)
Finished training epoch 73
Validate: MSE 0.00476023 (0.00510309), PSNR 23.22372055 (23.08670646), SSIM 0.67665112 (0.75393102)
Finished validation.
Starting training epoch 74
Epoch: 74, MSE 0.00246539 (0.00272878), PSNR 26.08115005 (25.68633898), SSIM 0.73023212 (0.71557912)
Finished training epoch 74


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00672279 (0.00640832), PSNR 21.72450447 (21.98732769), SSIM 0.57259637 (0.64144753)
Finished validation.
Starting training epoch 75
Epoch: 75, MSE 0.00236272 (0.00277565), PSNR 26.26588058 (25.61036469), SSIM 0.73973101 (0.70896792)
Finished training epoch 75
Validate: MSE 0.00656484 (0.00677118), PSNR 21.82775497 (21.84780444), SSIM 0.63571978 (0.72293238)
Finished validation.
Starting training epoch 76
Epoch: 76, MSE 0.00238362 (0.00260180), PSNR 26.22762680 (25.87480593), SSIM 0.75073874 (0.72599321)
Finished training epoch 76
Validate: MSE 0.00691301 (0.00727406), PSNR 21.60332870 (21.50073842), SSIM 0.65553558 (0.72695724)
Finished validation.
Starting training epoch 77
Epoch: 77, MSE 0.00233271 (0.00279961), PSNR 26.32139778 (25.57208703), SSIM 0.73578477 (0.70879813)
Finished training epoch 77
Validate: MSE 0.00450952 (0.00467588), PSNR 23.45869255 (23.44110293), SSIM 0.69815481 (0.76671941)
Finished validation.
Starting training epoch 78
Epoch: 78, MSE 0.0031984

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00984705 (0.00997594), PSNR 20.06693649 (20.09186264), SSIM 0.63138747 (0.70054598)
Finished validation.
Starting training epoch 79
Epoch: 79, MSE 0.00312076 (0.00274764), PSNR 25.05739403 (25.64535313), SSIM 0.69872296 (0.71328950)
Finished training epoch 79
Validate: MSE 0.01166597 (0.01102862), PSNR 19.33079147 (19.65566320), SSIM 0.49196839 (0.57845925)
Finished validation.
Starting training epoch 80
Epoch: 80, MSE 0.00361743 (0.00275008), PSNR 24.41599274 (25.65478955), SSIM 0.60798013 (0.71294378)
Finished training epoch 80
Validate: MSE 0.01052099 (0.00938587), PSNR 19.77943039 (20.34500561), SSIM 0.49517062 (0.56191199)
Finished validation.
Starting training epoch 81
Epoch: 81, MSE 0.00230450 (0.00265830), PSNR 26.37423706 (25.79218319), SSIM 0.70592821 (0.72160543)
Finished training epoch 81


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00637851 (0.00636645), PSNR 21.95280838 (22.04023243), SSIM 0.62863839 (0.69522878)
Finished validation.
Starting training epoch 82
Epoch: 82, MSE 0.00256450 (0.00278603), PSNR 25.90997314 (25.59986873), SSIM 0.72964841 (0.70843110)
Finished training epoch 82
Validate: MSE 0.00776600 (0.00768195), PSNR 21.09802246 (21.25263286), SSIM 0.62552369 (0.70565169)
Finished validation.
Starting training epoch 83
Epoch: 83, MSE 0.00264693 (0.00265266), PSNR 25.77258301 (25.78914694), SSIM 0.68002123 (0.72223258)
Finished training epoch 83
Validate: MSE 0.00584350 (0.00625943), PSNR 22.33326721 (22.18998118), SSIM 0.61929476 (0.69868905)
Finished validation.
Starting training epoch 84
Epoch: 84, MSE 0.00246485 (0.00269497), PSNR 26.08209038 (25.72984265), SSIM 0.69842017 (0.71480003)
Finished training epoch 84
Validate: MSE 0.00526012 (0.00531962), PSNR 22.79004478 (22.88782999), SSIM 0.65299380 (0.73294063)
Finished validation.
Starting training epoch 85
Epoch: 85, MSE 0.0032497

  return func(*args, **kwargs)


Validate: MSE 0.00588916 (0.00612008), PSNR 22.29946709 (22.24658584), SSIM 0.69598436 (0.76024268)
Finished validation.
Starting training epoch 86
Epoch: 86, MSE 0.00291573 (0.00268977), PSNR 25.35253143 (25.73552865), SSIM 0.67452401 (0.71900296)
Finished training epoch 86


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01028118 (0.01062400), PSNR 19.87957001 (19.80110782), SSIM 0.59693164 (0.66068334)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.01028118 (0.01062400), PSNR 19.87957001 (19.80110782), SSIM 0.59693158 (0.66068335)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()