In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
# Load the TensorBoard notebook extension and start it inline, reading event files under ./runs.
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# Environment switches. The Colab and Kaggle setups have not been tested yet.
colab = False
kaggle = False
# Run identifier; on Colab/Kaggle it becomes the output sub-folder for images and checkpoints.
test_number = '12_2'

In [4]:
# Output locations, relative to the working directory unless an environment below re-roots them.
color_imgs, gray_imgs, checkpoints = 'outputs/color/', 'outputs/gray/', 'checkpoints'

if colab:
    # Colab: dataset and outputs live on the mounted Google Drive; outputs go under the run folder.
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'/content/drive/MyDrive/MGU/{test_number}/{sub}' for sub in (color_imgs, gray_imgs, checkpoints)
    )
elif kaggle:
    # Kaggle: dataset comes from the attached input; outputs go under the run folder.
    dataset = '/kaggle/input/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'{test_number}/{sub}' for sub in (color_imgs, gray_imgs, checkpoints)
    )
else:
    # Local run: outputs stay at their default locations.
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)

# Seed the global NumPy and PyTorch RNGs so weight init, DataLoader shuffling and the random
# augmentations are repeatable between runs. GPU kernels may still be nondeterministic.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

save_images = True                # write sample colorizations during validation
best_losses = [1e10, 1e10, 1e10]  # best-so-far validation metrics: MSE, PSNR, SSIM
best_epoch = -1                   # epoch of the best validation MSE
patience = 50                     # early stopping: epochs allowed without an MSE improvement
epochs = 500
batch_size = 128
SIZE = 32                         # images are resized to SIZE x SIZE

In [6]:
# TensorBoard writer for the per-epoch train/test metrics logged by train() and validate().
# NOTE(review): no log_dir is given, so the default location is used — presumably under ./runs,
# which the %tensorboard cell watches; confirm.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available; later cells branch on this flag to move data, model and metrics to CUDA.
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    """Flat image folder yielding ``(gray, ab)`` tensor pairs for colorization.

    Every regular file directly inside ``paths`` is treated as an image (no recursion).
    Each sample is:
      * ``img_gray``: float32 tensor of shape (1, SIZE, SIZE) from ``rgb2gray``;
      * ``img_ab``:   float32 tensor of shape (2, SIZE, SIZE), the Lab a/b channels
        mapped by ``(v + 128) / 255``.

    Args:
        paths: directory containing the images.
        split: ``'train'`` (resize, random crop, random horizontal flip) or
            ``'val'`` (resize and crop; the crop is a no-op after resizing to SIZE).

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train'):
        resize = transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomCrop(SIZE),
            ])
        else:
            # Fail fast instead of leaving self.transforms unset until the first __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        # os.listdir order is arbitrary; sorting makes sample order (and hence which
        # validation images get saved) reproducible across machines and runs.
        self.paths = sorted(
            os.path.join(paths, file) for file in os.listdir(paths)
            if os.path.isfile(os.path.join(paths, file))
        )

    def __getitem__(self, index):
        # The context manager closes the file; convert() returns an independent in-memory image.
        with Image.open(self.paths[index]) as raw:
            img = raw.convert("RGB")
        img_original = np.asarray(self.transforms(img))
        # Map Lab into roughly [0, 1]; only the a/b channels are kept as the target.
        img_lab = (rgb2lab(img_original) + 128) / 255
        img_ab = torch.from_numpy(img_lab[:, :, 1:3].transpose((2, 0, 1))).float()
        img_gray = torch.from_numpy(rgb2gray(img_original)).unsqueeze(0).float()
        return img_gray, img_ab

    def __len__(self):
        return len(self.paths)

In [9]:
# Training: reshuffled every epoch.
# os.path.join works whether or not `dataset` ends with a path separator.
train_imagefolder = LabImageFolder(os.path.join(dataset, 'train'))
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation: fixed order, so the images saved from the first batch are the same every epoch.
val_imagefolder = LabImageFolder(os.path.join(dataset, 'val'), 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
# Model hyperparameters (read when Autoencoder() is constructed)
kernel_size = 3
stride_en = 1
stride_de = 1
padding = 1
scale_factor = 2
padding_mode = 'zeros'
channels_base = 64
p1 = .5


class Autoencoder(nn.Module):
    """Convolutional encoder-decoder mapping a 1-channel gray image to 2 ab channels.

    Encoder: two (conv -> BN -> LeakyReLU(0.1) -> avg-pool /2 -> dropout) stages.
    Decoder: two (transposed conv -> BN -> LeakyReLU(0.1) -> dropout) stages, each followed by
    an upsample by ``scale_factor``, then a final transposed conv to 2 channels. Skip
    connections add the first encoder stage's output before the second decoder conv, and the
    raw input before the final conv.

    NOTE(review): ``batchnorm2`` is applied both after ``conv1`` (encoder) and after
    ``convtrans1`` (decoder), so its running statistics mix two different feature
    distributions. It is kept shared here so that existing checkpoints still load.
    """

    def __init__(self):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding, padding_mode=padding_mode)
        c = channels_base

        # Encoder convolutions
        self.conv1 = nn.Conv2d(1, c, stride=stride_en, **conv_kwargs)
        self.conv2 = nn.Conv2d(c, 2 * c, stride=stride_en, **conv_kwargs)

        # Decoder (transposed) convolutions
        self.convtrans1 = nn.ConvTranspose2d(2 * c, c, stride=stride_de, **conv_kwargs)
        self.convtrans2 = nn.ConvTranspose2d(c, c // 2, stride=stride_de, **conv_kwargs)
        self.convtrans3 = nn.ConvTranspose2d(c // 2, 2, stride=stride_de, **conv_kwargs)

        self.batchnorm1 = nn.BatchNorm2d(c // 2)
        self.batchnorm2 = nn.BatchNorm2d(c)
        self.batchnorm3 = nn.BatchNorm2d(2 * c)

        self.dropout1 = nn.Dropout(p=p1)

    def forward(self, input):
        def act(t):
            return F.leaky_relu(t, negative_slope=.1)

        def down(t):
            return F.avg_pool2d(t, kernel_size=3, stride=2, padding=1)

        def up(t):
            return F.interpolate(t, scale_factor=scale_factor)

        # Encoder (32x32 input -> 16x16 -> 8x8)
        skip = self.dropout1(down(act(self.batchnorm2(self.conv1(input)))))
        h = self.dropout1(down(act(self.batchnorm3(self.conv2(skip)))))

        # Decoder, with skip connections from the first encoder stage and from the input
        h = up(self.dropout1(act(self.batchnorm2(self.convtrans1(h)))))
        h = self.dropout1(act(self.batchnorm1(self.convtrans2(h + skip))))
        return self.convtrans3(up(h) + input)

In [11]:
# Instantiate on CPU; it is moved to the GPU (when available) in a later cell.
model = Autoencoder()

In [12]:
# Metrics in a fixed order used everywhere: [0] MSE (also the training loss), [1] PSNR, [2] SSIM.
# data_range=1.0 matches the ab targets, which the dataset maps into roughly [0, 1] via (v + 128) / 255.
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
# Adam over all model parameters.
# NOTE(review): lr=1e-2 is high for Adam; the validation metrics in the training log swing widely
# between epochs — consider a lower learning rate or a scheduler.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# Move the metrics and the model to the GPU when one is available.
# NOTE(review): the optimizer was built in the previous cell, before this move; PyTorch's
# Module.cuda() docs advise moving the model before constructing the optimizer — consider reordering.
if use_gpu:
    criterion = [metric.to("cuda") for metric in criterion]
    model = model.cuda()

In [15]:
# Print a layer-by-layer summary of the model for a (1, SIZE, SIZE) input.
# NOTE(review): guarded by use_gpu, presumably because torchsummary targets CUDA by default — confirm,
# or pass a device argument to enable it on CPU.
# NOTE(review): batchnorm2 is used twice in forward(), so it is listed (and its parameters counted)
# twice; the reported total is higher than the number of unique parameters.
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]             640
       BatchNorm2d-2           [-1, 64, 32, 32]             128
           Dropout-3           [-1, 64, 16, 16]               0
            Conv2d-4          [-1, 128, 16, 16]          73,856
       BatchNorm2d-5          [-1, 128, 16, 16]             256
           Dropout-6            [-1, 128, 8, 8]               0
   ConvTranspose2d-7             [-1, 64, 8, 8]          73,792
       BatchNorm2d-8             [-1, 64, 8, 8]             128
           Dropout-9             [-1, 64, 8, 8]               0
  ConvTranspose2d-10           [-1, 32, 16, 16]          18,464
      BatchNorm2d-11           [-1, 32, 16, 16]              64
          Dropout-12           [-1, 32, 16, 16]               0
  ConvTranspose2d-13            [-1, 2, 32, 32]             578
Total params: 167,906
Trainable params:

In [16]:
class AverageMeter:
    """Sample-weighted running average of a scalar (adapted from the PyTorch ImageNet example).

    Attributes:
        val: most recent value passed to ``update``.
        sum: weighted total ``sum(val * n)``.
        count: total weight ``sum(n)``.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Convert a (gray, ab) pair back to RGB and save both the gray and colorized images.

    Args:
        grayscale_input: tensor of shape (1, H, W); scaled by 100 to form the Lab L channel
            (assumes values in [0, 1]).
        ab_input: tensor of shape (2, H, W) with a/b channels stored as (v + 128) / 255.
        save_path: dict ``{'grayscale': '/path/', 'colorized': '/path/'}``. Each value is a
            prefix that ``save_name`` is appended to directly, so it should end with a separator.
        save_name: file name, e.g. ``'img-0-epoch-3.jpg'``.

    Nothing is computed or written unless both ``save_path`` and ``save_name`` are given.
    Returns None.
    '''
    if save_path is None or save_name is None:
        return
    # detach()/cpu() so GPU tensors or tensors that require grad are accepted as well.
    gray = grayscale_input.detach().cpu()
    ab = ab_input.detach().cpu()
    lab = torch.cat((gray, ab), 0).numpy().transpose((1, 2, 0)).astype(np.float64)  # (H, W, 3)
    lab[:, :, 0] *= 100                            # L: [0, 1] -> [0, 100]
    lab[:, :, 1:3] = lab[:, :, 1:3] * 255 - 128    # a/b: undo (v + 128) / 255
    color_image = lab2rgb(lab)
    # plt.imsave writes arrays directly; no pyplot figure is needed (the former plt.clf()
    # only produced stray empty figures in the notebook output).
    plt.imsave(arr=gray.squeeze().numpy(), fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
    plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch, max_images=10):
    '''Evaluate ``model`` on ``val_loader`` and return sample-weighted average metrics.

    Args:
        val_loader: iterable of ``(gray, ab)`` batches.
        model: network mapping gray -> ab; switched to eval mode here.
        criterion: ``[mse, psnr, ssim]`` callables, each ``(pred, target) -> scalar tensor``.
        save_images: if true, colorize and save up to ``max_images`` samples from the first batch.
        epoch: used in saved file names and as the TensorBoard step; a negative value skips logging.
        max_images: cap on the number of images saved per call (default 10).

    Returns:
        ``(mse_avg, psnr_avg, ssim_avg)`` — per-batch values averaged with batch-size weights.
    '''
    meters = [AverageMeter() for _ in range(3)]
    save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}

    model.eval()
    already_saved_images = False
    # Evaluation never needs gradients; guard here so callers need not remember torch.no_grad().
    with torch.no_grad():
        for gray, ab in val_loader:
            if use_gpu:
                gray, ab = gray.cuda(), ab.cuda()

            # Run model and record every metric
            output_ab = model(gray)
            for meter, metric in zip(meters, criterion):
                meter.update(metric(output_ab, ab).item(), gray.size(0))

            # Save images from the first batch only
            if save_images and not already_saved_images:
                already_saved_images = True
                for j in range(min(len(output_ab), max_images)):
                    save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                    to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    mse, psnr, ssim = meters
    # Each metric is printed as "last batch (running average)".
    print(f'Validate: MSE {mse.val:.8f} ({mse.avg:.8f}), PSNR {psnr.val:.8f} ({psnr.avg:.8f}), SSIM {ssim.val:.8f} ({ssim.avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", mse.avg, epoch)
        writer.add_scalar("PSNR/test", psnr.avg, epoch)
        writer.add_scalar("SSIM/test", ssim.avg, epoch)
    return mse.avg, psnr.avg, ssim.avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    '''Run one training epoch, optimizing ``criterion[0]`` (MSE) and tracking PSNR/SSIM.

    Args:
        train_loader: iterable of ``(gray, ab)`` batches.
        model: network mapping gray -> ab; switched to train mode here.
        criterion: ``[mse, psnr, ssim]``; only ``criterion[0]`` is back-propagated.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index for the logs and the TensorBoard step; a negative value skips logging.

    Returns:
        ``(mse_avg, psnr_avg, ssim_avg)`` over the epoch (existing callers may ignore it).
    '''
    print(f'Starting training epoch {epoch}')
    mse_meter, psnr_meter, ssim_meter = AverageMeter(), AverageMeter(), AverageMeter()

    model.train()

    for gray, ab in train_loader:
        if use_gpu:
            gray, ab = gray.cuda(), ab.cuda()

        optimizer.zero_grad()

        output_ab = model(gray)
        mse = criterion[0](output_ab, ab)
        # PSNR/SSIM are monitoring-only: evaluate them without building an autograd graph.
        with torch.no_grad():
            detached = output_ab.detach()
            psnr = criterion[1](detached, ab)
            ssim = criterion[2](detached, ab)

        mse.backward()
        optimizer.step()

        mse_meter.update(mse.item(), gray.size(0))
        psnr_meter.update(psnr.item(), gray.size(0))
        ssim_meter.update(ssim.item(), gray.size(0))

    # Each metric is printed as "last batch (running average)".
    print(f'Epoch: {epoch}, MSE {mse_meter.val:.8f} ({mse_meter.avg:.8f}), PSNR {psnr_meter.val:.8f} ({psnr_meter.avg:.8f}), SSIM {ssim_meter.val:.8f} ({ssim_meter.avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", mse_meter.avg, epoch)
        writer.add_scalar("PSNR/train", psnr_meter.avg, epoch)
        writer.add_scalar("SSIM/train", ssim_meter.avg, epoch)
    return mse_meter.avg, psnr_meter.avg, ssim_meter.avg

In [19]:
# Train model, checkpointing the best model per validation metric, with MSE-driven early stopping.
# MSE is minimized while PSNR and SSIM are maximized. Treating all three as "lower is better"
# (as before) saved PSNR/SSIM checkpoints whenever those metrics got WORSE.
metric_tags = ('MSELoss', 'PSNRLoss', 'SSIMLoss')  # checkpoint filename tags
higher_is_better = (False, True, True)
# Start from the worst possible value for each metric so the first epoch is always an improvement.
best_losses = [float('inf'), float('-inf'), float('-inf')]
best_epoch = -1

for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)

    # Save a checkpoint whenever a metric reaches a new best value
    for i, (tag, maximize) in enumerate(zip(metric_tags, higher_is_better)):
        improved = losses[i] > best_losses[i] if maximize else losses[i] < best_losses[i]
        if improved:
            best_losses[i] = losses[i]
            if i == 0:
                best_epoch = epoch  # early stopping tracks MSE only
            torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-{tag}-{losses[i]:.8f}.pth')

    # Early stopping: no MSE improvement for `patience` epochs
    if epoch - best_epoch >= patience:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break

    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.01061336 (0.10917074), PSNR 19.74147034 (15.27167198), SSIM 0.40097588 (0.18371201)
Finished training epoch 0
Validate: MSE 0.00713853 (0.00528875), PSNR 21.46390915 (22.79395141), SSIM 0.50021672 (0.61100447)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00644733 (0.00853924), PSNR 21.90619659 (20.74478412), SSIM 0.59223449 (0.49637633)
Finished training epoch 1
Validate: MSE 0.00352503 (0.00300880), PSNR 24.52836800 (25.29431458), SSIM 0.66796434 (0.75502631)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00494758 (0.00573407), PSNR 23.05607605 (22.43682124), SSIM 0.60730106 (0.59764148)
Finished training epoch 2
Validate: MSE 0.00347612 (0.00315941), PSNR 24.58904839 (25.11563808), SSIM 0.68400997 (0.76589682)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00450543 (0.00478475), PSNR 23.46263885 (23.23030405), SSIM 0.65196908 (0.63256551)
Finished training epoch 3
Validate: MSE 0.00347902 (0.0

Epoch: 30, MSE 0.00272967 (0.00316869), PSNR 25.63890076 (25.08669084), SSIM 0.72998136 (0.73502305)
Finished training epoch 30
Validate: MSE 0.00491584 (0.00523442), PSNR 23.08402634 (22.99780280), SSIM 0.65221345 (0.74101894)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00227824 (0.00300785), PSNR 26.42400169 (25.30361138), SSIM 0.75846899 (0.73958720)
Finished training epoch 31
Validate: MSE 0.00318456 (0.00315901), PSNR 24.96950912 (25.12829045), SSIM 0.69881940 (0.77220274)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00324102 (0.00308032), PSNR 24.89318466 (25.21487730), SSIM 0.71907222 (0.73621447)
Finished training epoch 32
Validate: MSE 0.00375976 (0.00317989), PSNR 24.24839973 (25.02152720), SSIM 0.68356788 (0.75090792)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00259180 (0.00310486), PSNR 25.86398506 (25.20414652), SSIM 0.74027544 (0.73120166)
Finished training epoch 33
Validate: MSE 0.00398657 (0.00399175), PSNR 

  return func(*args, **kwargs)


Validate: MSE 0.00625360 (0.00538084), PSNR 22.03869438 (22.73119887), SSIM 0.57452053 (0.65069530)
Finished validation.
Starting training epoch 45
Epoch: 45, MSE 0.00329165 (0.00279346), PSNR 24.82586098 (25.57829129), SSIM 0.70757931 (0.72364952)
Finished training epoch 45
Validate: MSE 0.00353227 (0.00312414), PSNR 24.51946259 (25.11822284), SSIM 0.69148719 (0.76643111)
Finished validation.
Starting training epoch 46
Epoch: 46, MSE 0.00566961 (0.00286844), PSNR 22.46446800 (25.47849873), SSIM 0.64809579 (0.71627903)
Finished training epoch 46
Validate: MSE 0.00330355 (0.00285622), PSNR 24.81018639 (25.50216593), SSIM 0.69109321 (0.75941125)
Finished validation.
Starting training epoch 47
Epoch: 47, MSE 0.00264774 (0.00296729), PSNR 25.77123833 (25.33102376), SSIM 0.73069650 (0.70720079)
Finished training epoch 47
Validate: MSE 0.00480194 (0.00451293), PSNR 23.18582916 (23.54180660), SSIM 0.58098727 (0.67236672)
Finished validation.
Starting training epoch 48
Epoch: 48, MSE 0.0030203

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00421030 (0.00389420), PSNR 23.75686836 (24.16720003), SSIM 0.63516229 (0.70587471)
Finished validation.
Starting training epoch 52
Epoch: 52, MSE 0.00257769 (0.00286195), PSNR 25.88769913 (25.48967129), SSIM 0.74702668 (0.71346870)
Finished training epoch 52
Validate: MSE 0.00371660 (0.00374808), PSNR 24.29854202 (24.41392728), SSIM 0.67891896 (0.76157947)
Finished validation.
Starting training epoch 53
Epoch: 53, MSE 0.00272753 (0.00285619), PSNR 25.64230919 (25.49241180), SSIM 0.71951735 (0.71334162)
Finished training epoch 53
Validate: MSE 0.00443386 (0.00425480), PSNR 23.53217697 (23.86803735), SSIM 0.64092094 (0.73225364)
Finished validation.
Starting training epoch 54
Epoch: 54, MSE 0.00255415 (0.00309391), PSNR 25.92754173 (25.21787405), SSIM 0.73088092 (0.69930584)
Finished training epoch 54
Validate: MSE 0.00327130 (0.00304184), PSNR 24.85280228 (25.24090875), SSIM 0.67780638 (0.74425237)
Finished validation.
Starting training epoch 55
Epoch: 55, MSE 0.0024480

  return func(*args, **kwargs)


Validate: MSE 0.00432891 (0.00415763), PSNR 23.63621521 (23.91156106), SSIM 0.64908600 (0.71683247)
Finished validation.
Starting training epoch 58
Epoch: 58, MSE 0.00442132 (0.00287059), PSNR 23.54448128 (25.47665957), SSIM 0.63427067 (0.71061106)
Finished training epoch 58
Validate: MSE 0.01232142 (0.01152172), PSNR 19.09339142 (19.50985233), SSIM 0.38287854 (0.48032154)
Finished validation.
Starting training epoch 59
Epoch: 59, MSE 0.00251644 (0.00283008), PSNR 25.99213028 (25.53504206), SSIM 0.74783152 (0.71728591)
Finished training epoch 59
Validate: MSE 0.00358396 (0.00340280), PSNR 24.45636177 (24.75679400), SSIM 0.65636861 (0.72311698)
Finished validation.
Starting training epoch 60
Epoch: 60, MSE 0.00252408 (0.00276242), PSNR 25.97896957 (25.62318483), SSIM 0.72953308 (0.72241760)
Finished training epoch 60
Validate: MSE 0.00411506 (0.00394814), PSNR 23.85623932 (24.16768221), SSIM 0.65710127 (0.73162007)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.0031098

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00438416 (0.00421287), PSNR 23.58113098 (23.83138667), SSIM 0.62770844 (0.69796178)
Finished validation.
Starting training epoch 67
Epoch: 67, MSE 0.00264960 (0.00277446), PSNR 25.76819992 (25.61366230), SSIM 0.73824465 (0.72013093)
Finished training epoch 67
Validate: MSE 0.00439697 (0.00392456), PSNR 23.56846237 (24.12420661), SSIM 0.62326962 (0.71676442)
Finished validation.
Starting training epoch 68
Epoch: 68, MSE 0.00220441 (0.00289203), PSNR 26.56707954 (25.44375712), SSIM 0.73879135 (0.71014350)
Finished training epoch 68
Validate: MSE 0.00378675 (0.00362869), PSNR 24.21733475 (24.52400433), SSIM 0.70389766 (0.77189869)
Finished validation.
Starting training epoch 69
Epoch: 69, MSE 0.00229010 (0.00293299), PSNR 26.40144730 (25.38725310), SSIM 0.75970256 (0.70582953)
Finished training epoch 69
Validate: MSE 0.00430383 (0.00427242), PSNR 23.66144562 (23.83161398), SSIM 0.64557612 (0.72021791)
Finished validation.
Starting training epoch 70
Epoch: 70, MSE 0.0029513

  return func(*args, **kwargs)


Validate: MSE 0.00451011 (0.00455612), PSNR 23.45812607 (23.50969118), SSIM 0.62518370 (0.69091481)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00222169 (0.00279463), PSNR 26.53316689 (25.57627417), SSIM 0.74488342 (0.71410613)
Finished training epoch 93
Validate: MSE 0.00461475 (0.00448134), PSNR 23.35851479 (23.64438932), SSIM 0.66900659 (0.75426209)
Finished validation.
Starting training epoch 94
Epoch: 94, MSE 0.00267849 (0.00272864), PSNR 25.72109795 (25.67609880), SSIM 0.71152467 (0.72162334)
Finished training epoch 94
Validate: MSE 0.00354920 (0.00334454), PSNR 24.49869347 (24.86728414), SSIM 0.68109685 (0.75469035)
Finished validation.
Starting training epoch 95
Epoch: 95, MSE 0.00249028 (0.00275448), PSNR 26.03752136 (25.63597694), SSIM 0.72856694 (0.71697267)
Finished training epoch 95
Validate: MSE 0.00347534 (0.00322925), PSNR 24.59002686 (24.97864104), SSIM 0.66486931 (0.73195994)
Finished validation.
Starting training epoch 96
Epoch: 96, MSE 0.0027483

  return func(*args, **kwargs)


Validate: MSE 0.00425649 (0.00395499), PSNR 23.70947838 (24.11393748), SSIM 0.66165596 (0.72885111)
Finished validation.
Starting training epoch 98
Epoch: 98, MSE 0.00312168 (0.00272368), PSNR 25.05611229 (25.69053761), SSIM 0.72820312 (0.71640181)
Finished training epoch 98
Validate: MSE 0.00699580 (0.00643354), PSNR 21.55162621 (22.06528841), SSIM 0.56576520 (0.66415475)
Finished validation.
Starting training epoch 99
Epoch: 99, MSE 0.00230797 (0.00278656), PSNR 26.36769485 (25.58878180), SSIM 0.74579394 (0.71308321)
Finished training epoch 99
Validate: MSE 0.00498311 (0.00407572), PSNR 23.02499008 (23.95394687), SSIM 0.61327505 (0.71023369)
Finished validation.
Starting training epoch 100
Epoch: 100, MSE 0.00294548 (0.00283948), PSNR 25.30843925 (25.51981383), SSIM 0.69813287 (0.70868009)
Finished training epoch 100
Validate: MSE 0.00444497 (0.00369622), PSNR 23.52131271 (24.39336611), SSIM 0.62353855 (0.71851146)
Finished validation.
Starting training epoch 101
Epoch: 101, MSE 0.00

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00736926 (0.00697535), PSNR 21.32576180 (21.62678948), SSIM 0.56677985 (0.64290021)
Finished validation.
Starting training epoch 102
Epoch: 102, MSE 0.00284406 (0.00272745), PSNR 25.46060753 (25.68141674), SSIM 0.72367924 (0.71462830)
Finished training epoch 102
Validate: MSE 0.00452520 (0.00404779), PSNR 23.44361877 (24.05495869), SSIM 0.63309324 (0.72523070)
Finished validation.
Starting training epoch 103
Epoch: 103, MSE 0.00274855 (0.00271512), PSNR 25.60896111 (25.69549546), SSIM 0.70817864 (0.71836528)
Finished training epoch 103


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00533776 (0.00550433), PSNR 22.72640800 (22.69389643), SSIM 0.60859418 (0.67256321)
Finished validation.
Starting training epoch 104
Epoch: 104, MSE 0.00305106 (0.00277217), PSNR 25.15549278 (25.61750340), SSIM 0.69674569 (0.70964836)
Finished training epoch 104
Validate: MSE 0.00567133 (0.00473872), PSNR 22.46315002 (23.29826387), SSIM 0.57144356 (0.63984838)
Finished validation.
Starting training epoch 105
Epoch: 105, MSE 0.00385482 (0.00270001), PSNR 24.13996315 (25.71556504), SSIM 0.69760096 (0.71868273)
Finished training epoch 105
Validate: MSE 0.00354288 (0.00325518), PSNR 24.50643539 (24.99586147), SSIM 0.68580461 (0.76680543)
Finished validation.
Starting training epoch 106
Epoch: 106, MSE 0.00242791 (0.00273581), PSNR 26.14767075 (25.66744743), SSIM 0.67722750 (0.71072705)
Finished training epoch 106
Validate: MSE 0.00406570 (0.00403845), PSNR 23.90864563 (24.10136379), SSIM 0.66060489 (0.74618029)
Finished validation.
Starting training epoch 107
Epoch: 107, MS

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00559034 (0.00503273), PSNR 22.52561951 (23.04436721), SSIM 0.58985698 (0.66263219)
Finished validation.
Starting training epoch 108
Epoch: 108, MSE 0.00212732 (0.00273727), PSNR 26.72166061 (25.65710078), SSIM 0.74497145 (0.71193781)
Finished training epoch 108
Validate: MSE 0.00550916 (0.00569953), PSNR 22.58914566 (22.58882430), SSIM 0.65191853 (0.72755786)
Finished validation.
Starting training epoch 109
Epoch: 109, MSE 0.00238063 (0.00266969), PSNR 26.23307037 (25.76566264), SSIM 0.69408238 (0.72114597)
Finished training epoch 109


  return func(*args, **kwargs)


Validate: MSE 0.00560306 (0.00570453), PSNR 22.51574707 (22.54617302), SSIM 0.66856718 (0.73320717)
Finished validation.
Starting training epoch 110
Epoch: 110, MSE 0.00259410 (0.00272716), PSNR 25.86013412 (25.68109935), SSIM 0.72887433 (0.71287850)
Finished training epoch 110
Validate: MSE 0.00563707 (0.00566618), PSNR 22.48946571 (22.60006051), SSIM 0.57031131 (0.65431018)
Finished validation.
Starting training epoch 111
Epoch: 111, MSE 0.00263399 (0.00277243), PSNR 25.79385185 (25.61747363), SSIM 0.71489847 (0.70697447)
Finished training epoch 111
Validate: MSE 0.00518445 (0.00558885), PSNR 22.85296822 (22.67429507), SSIM 0.69135654 (0.76021302)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [20]:
# Save the final weights (after early stopping or the last epoch), tagged with the last validation MSE/PSNR/SSIM.
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate the final model once more and save sample images. epoch=-1 skips TensorBoard logging,
# and the saved files are named "img-<j>-epoch--1.jpg".
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00518445 (0.00558885), PSNR 22.85296822 (22.67429500), SSIM 0.69135642 (0.76021302)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()