In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256

class Autoencoder(nn.Module):
    def __init__(self, input_size=128):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x)), negative_slope=0.1)
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 16, 16]           2,306
Total params: 2,662,274
Trainable params: 2,662,274
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.75
Params size (MB): 10.16
Estimated Total Size (MB): 16.91
-------------------------------------

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.01403700 (1.10604152), PSNR 18.52725601 (13.45730746), SSIM 0.25151959 (0.16115290)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01299332 (0.01219544), PSNR 18.86279869 (19.16979901), SSIM 0.20267975 (0.27511389)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00638667 (0.00863926), PSNR 21.94725609 (20.71664290), SSIM 0.40723848 (0.34501353)
Finished training epoch 1


  return func(*args, **kwargs)


Validate: MSE 0.00705020 (0.00621563), PSNR 21.51798820 (22.08946611), SSIM 0.30919349 (0.40057519)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00486796 (0.00575800), PSNR 23.12652588 (22.42987719), SSIM 0.47686148 (0.43843427)
Finished training epoch 2


  return func(*args, **kwargs)


Validate: MSE 0.00583069 (0.00559902), PSNR 22.34279823 (22.57581513), SSIM 0.37776470 (0.46822066)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00433237 (0.00474015), PSNR 23.63274765 (23.26658913), SSIM 0.53503323 (0.49857273)
Finished training epoch 3
Validate: MSE 0.00550643 (0.00487058), PSNR 22.59129715 (23.15691014), SSIM 0.42187586 (0.51986861)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.00395494 (0.00434421), PSNR 24.02859497 (23.65808513), SSIM 0.59011328 (0.54000273)
Finished training epoch 4
Validate: MSE 0.00575043 (0.00582091), PSNR 22.40299988 (22.48346262), SSIM 0.44516003 (0.53823552)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.00459446 (0.00497631), PSNR 23.37765884 (23.20637043), SSIM 0.59345829 (0.56314319)
Finished training epoch 5
Validate: MSE 0.00689051 (0.00752965), PSNR 21.61748695 (21.42993427), SSIM 0.45940948 (0.54983176)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 0.00437830 (0.005679

  return func(*args, **kwargs)


Validate: MSE 0.00430464 (0.00424570), PSNR 23.66062927 (23.83279915), SSIM 0.57066810 (0.65928069)
Finished validation.
Starting training epoch 13
Epoch: 13, MSE 0.00468285 (0.00514263), PSNR 23.29489326 (23.16980596), SSIM 0.62744683 (0.65307865)
Finished training epoch 13
Validate: MSE 0.00451306 (0.00464948), PSNR 23.45528793 (23.41726406), SSIM 0.57446420 (0.66249702)
Finished validation.
Starting training epoch 14
Epoch: 14, MSE 0.00502409 (0.00506608), PSNR 22.98942757 (23.32739623), SSIM 0.63140601 (0.66034466)
Finished training epoch 14
Validate: MSE 0.00527541 (0.00629407), PSNR 22.77744102 (22.17141524), SSIM 0.55878580 (0.64116774)
Finished validation.
Starting training epoch 15
Epoch: 15, MSE 0.00239081 (0.00433925), PSNR 26.21455193 (23.91964270), SSIM 0.70529079 (0.67297887)
Finished training epoch 15
Validate: MSE 0.00360616 (0.00324129), PSNR 24.42955208 (24.94248177), SSIM 0.61235571 (0.69638186)
Finished validation.
Starting training epoch 16
Epoch: 16, MSE 0.0027331

Validate: MSE 0.00294292 (0.00249846), PSNR 25.31221962 (26.10013888), SSIM 0.71045989 (0.77806440)
Finished validation.
Starting training epoch 43
Epoch: 43, MSE 0.00234330 (0.00236362), PSNR 26.30171585 (26.28605596), SSIM 0.78631914 (0.77849882)
Finished training epoch 43
Validate: MSE 0.00285498 (0.00280611), PSNR 25.44396400 (25.62436049), SSIM 0.71333897 (0.77099556)
Finished validation.
Starting training epoch 44
Epoch: 44, MSE 0.00244112 (0.00236866), PSNR 26.12411118 (26.27758950), SSIM 0.78689259 (0.77862889)
Finished training epoch 44
Validate: MSE 0.00276272 (0.00256298), PSNR 25.58663559 (25.98084699), SSIM 0.71664178 (0.77943189)
Finished validation.
Starting training epoch 45
Epoch: 45, MSE 0.00194610 (0.00236732), PSNR 27.10833740 (26.27993754), SSIM 0.77793759 (0.77890254)
Finished training epoch 45
Validate: MSE 0.00308168 (0.00369695), PSNR 25.11212349 (24.51479241), SSIM 0.70731997 (0.76281651)
Finished validation.
Starting training epoch 46
Epoch: 46, MSE 0.0018342

Validate: MSE 0.00266173 (0.00240140), PSNR 25.74836159 (26.24902934), SSIM 0.71600777 (0.77553653)
Finished validation.
Starting training epoch 73
Epoch: 73, MSE 0.00205699 (0.00222874), PSNR 26.86767960 (26.54532503), SSIM 0.77208233 (0.78254075)
Finished training epoch 73
Validate: MSE 0.00269836 (0.00258027), PSNR 25.68899536 (25.98589485), SSIM 0.72150981 (0.77965741)
Finished validation.
Starting training epoch 74
Epoch: 74, MSE 0.00238356 (0.00222777), PSNR 26.22773552 (26.54492237), SSIM 0.78238124 (0.78254921)
Finished training epoch 74
Validate: MSE 0.00280192 (0.00244416), PSNR 25.52543449 (26.18706679), SSIM 0.71742052 (0.77755220)
Finished validation.
Starting training epoch 75
Epoch: 75, MSE 0.00287452 (0.00223320), PSNR 25.41434479 (26.53727113), SSIM 0.75718009 (0.78257837)
Finished training epoch 75
Validate: MSE 0.00266335 (0.00274213), PSNR 25.74571419 (25.74121021), SSIM 0.71286047 (0.77329797)
Finished validation.
Starting training epoch 76
Epoch: 76, MSE 0.0019747

Validate: MSE 0.00315013 (0.00247961), PSNR 25.01670837 (26.13041531), SSIM 0.71552718 (0.77642312)
Finished validation.
Starting training epoch 103
Epoch: 103, MSE 0.00281350 (0.00211153), PSNR 25.50753593 (26.77691934), SSIM 0.76455700 (0.78312237)
Finished training epoch 103
Validate: MSE 0.00256008 (0.00242409), PSNR 25.91745758 (26.22137594), SSIM 0.71764195 (0.77846082)
Finished validation.
Starting training epoch 104
Epoch: 104, MSE 0.00212555 (0.00211269), PSNR 26.72527695 (26.77561047), SSIM 0.77240580 (0.78303348)
Finished training epoch 104
Validate: MSE 0.00272138 (0.00228068), PSNR 25.65210152 (26.47537342), SSIM 0.72003520 (0.77484576)
Finished validation.
Starting training epoch 105
Epoch: 105, MSE 0.00185847 (0.00209742), PSNR 27.30844307 (26.80615264), SSIM 0.79587346 (0.78282235)
Finished training epoch 105
Validate: MSE 0.00276972 (0.00232731), PSNR 25.57564163 (26.39920384), SSIM 0.71613395 (0.77529388)
Finished validation.
Starting training epoch 106
Epoch: 106, MS

Epoch: 132, MSE 0.00200828 (0.00201843), PSNR 26.97176552 (26.97108126), SSIM 0.78243315 (0.78220208)
Finished training epoch 132
Validate: MSE 0.00244369 (0.00228061), PSNR 26.11953545 (26.47062492), SSIM 0.71949476 (0.77088816)
Finished validation.
Starting training epoch 133
Epoch: 133, MSE 0.00200382 (0.00200676), PSNR 26.98140144 (26.99708574), SSIM 0.77584708 (0.78223384)
Finished training epoch 133
Validate: MSE 0.00249378 (0.00231218), PSNR 26.03141785 (26.42287705), SSIM 0.72519898 (0.77112905)
Finished validation.
Starting training epoch 134
Epoch: 134, MSE 0.00192218 (0.00199780), PSNR 27.16205025 (27.01659740), SSIM 0.77656806 (0.78207760)
Finished training epoch 134
Validate: MSE 0.00273002 (0.00230033), PSNR 25.63834190 (26.44675025), SSIM 0.71176392 (0.77371619)
Finished validation.
Starting training epoch 135
Epoch: 135, MSE 0.00224262 (0.00199823), PSNR 26.49244690 (27.01564144), SSIM 0.78203720 (0.78223235)
Finished training epoch 135
Validate: MSE 0.00264042 (0.00230

Validate: MSE 0.00256616 (0.00246275), PSNR 25.90716553 (26.13174468), SSIM 0.70543724 (0.76583651)
Finished validation.
Starting training epoch 162
Epoch: 162, MSE 0.00230615 (0.00189825), PSNR 26.37112617 (27.23727981), SSIM 0.75957263 (0.78160853)
Finished training epoch 162
Validate: MSE 0.00260065 (0.00227776), PSNR 25.84918213 (26.48146028), SSIM 0.71317405 (0.76297945)
Finished validation.
Starting training epoch 163
Epoch: 163, MSE 0.00156064 (0.00189945), PSNR 28.06698227 (27.23436867), SSIM 0.78908968 (0.78162139)
Finished training epoch 163
Validate: MSE 0.00300086 (0.00273927), PSNR 25.22754669 (25.70497955), SSIM 0.70109117 (0.75365598)
Finished validation.
Starting training epoch 164
Epoch: 164, MSE 0.00188296 (0.00189202), PSNR 27.25158310 (27.25181412), SSIM 0.76831013 (0.78158006)
Finished training epoch 164
Validate: MSE 0.00250807 (0.00238145), PSNR 26.00661087 (26.28090714), SSIM 0.70979190 (0.76118318)
Finished validation.
Starting training epoch 165
Epoch: 165, MS

Epoch: 191, MSE 0.00183773 (0.00181660), PSNR 27.35719109 (27.42534922), SSIM 0.77788228 (0.78109417)
Finished training epoch 191
Validate: MSE 0.00311663 (0.00248235), PSNR 25.06314087 (26.11803687), SSIM 0.69687200 (0.76009504)
Finished validation.
Starting training epoch 192
Epoch: 192, MSE 0.00162699 (0.00180183), PSNR 27.88615227 (27.46035644), SSIM 0.78394830 (0.78101321)
Finished training epoch 192
Validate: MSE 0.00278837 (0.00238315), PSNR 25.54649353 (26.28521685), SSIM 0.69094640 (0.76144506)
Finished validation.
Starting training epoch 193
Epoch: 193, MSE 0.00184486 (0.00179570), PSNR 27.34036446 (27.47501983), SSIM 0.78821409 (0.78132550)
Finished training epoch 193
Validate: MSE 0.00366073 (0.00289240), PSNR 24.36431694 (25.45067295), SSIM 0.69759631 (0.76054877)
Finished validation.
Starting training epoch 194
Epoch: 194, MSE 0.00208884 (0.00179895), PSNR 26.80095673 (27.47004719), SSIM 0.78238904 (0.78118910)
Finished training epoch 194
Validate: MSE 0.00282504 (0.00229

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00301909 (0.00241591), PSNR 25.20123863 (26.23248492), SSIM 0.68872857 (0.75954099)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()