In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
# Load the TensorBoard notebook extension and embed a viewer on ./runs.
# SummaryWriter() with no arguments writes its event files to runs/<datetime>_<hostname>,
# so runs logged by this notebook show up here.
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# Execution-environment switches (the Colab and Kaggle setups are not tested yet).
colab = kaggle = False
# Run identifier; used as the output sub-folder when running on Colab or Kaggle.
test_number = '12_2'

In [4]:
# Output locations; relative to the working directory unless a platform branch below
# prefixes them.
color_imgs, gray_imgs, checkpoints = 'outputs/color/', 'outputs/gray/', 'checkpoints'

if colab:
    # Dataset and outputs live on Google Drive; outputs go under MGU/<test_number>/.
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'/content/drive/MyDrive/MGU/{test_number}/{sub}'
        for sub in (color_imgs, gray_imgs, checkpoints)
    )
elif kaggle:
    # Dataset from the Kaggle input mount; outputs go under <test_number>/.
    dataset = '/kaggle/input/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'{test_number}/{sub}'
        for sub in (color_imgs, gray_imgs, checkpoints)
    )
else:
    # Local run: dataset two directory levels up; outputs stay unprefixed.
    dataset = '../../datasets/cifar10/'

In [5]:
# Create the output folders (no-op when they already exist).
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)

# Training schedule
epochs = 500
batch_size = 128
patience = 50  # epochs without a new best validation MSE before early stopping may trigger
SIZE = 32      # every image is resized to SIZE x SIZE

# Checkpoint bookkeeping: best validation [MSE, PSNR, SSIM] so far, and the epoch of the best MSE
best_losses = [1e10] * 3
best_epoch = -1
save_images = True  # write sample colorizations during validation

In [6]:
# TensorBoard logger used by train() and validate(). With no arguments it writes
# to runs/<datetime>_<hostname>, i.e. under the directory the %tensorboard magic watches.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Detect CUDA once; this flag gates the .cuda() moves and the model summary later on.
use_gpu = torch.cuda.is_available()
print(f'{use_gpu}')

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    '''Flat image folder yielding (grayscale, ab) tensor pairs for colorization.

    Every regular file directly inside ``paths`` is loaded as an image
    (subdirectories are skipped). Files are listed in sorted order so item
    indices are reproducible across machines and filesystems.

    Each item is resized to ``SIZE`` x ``SIZE`` and returned as:
      * img_gray: float tensor (1, SIZE, SIZE) produced by skimage ``rgb2gray``.
      * img_ab:   float tensor (2, SIZE, SIZE) holding the Lab a/b channels mapped
                  by (ab + 128) / 255, which places typical sRGB colours in [0, 1].

    Args:
        paths: directory containing the images.
        split: 'train' (adds random horizontal flips) or 'val'.

    Raises:
        ValueError: if ``split`` is neither 'train' nor 'val'.
    '''
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                # RandomCrop(SIZE) is a no-op right after resizing to exactly SIZE x SIZE;
                # it only matters if the resize target is ever made larger.
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),  # no-op for the same reason as above
            ])
        else:
            # Fail fast: otherwise self.transforms stays unset and the error only
            # surfaces later inside __getitem__ as an AttributeError.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        # os.listdir order is arbitrary; sort it so indices are deterministic.
        self.paths = [
            os.path.join(paths, file)
            for file in sorted(os.listdir(paths))
            if os.path.isfile(os.path.join(paths, file))
        ]

    def __getitem__(self, index):
        '''Return (img_gray, img_ab) for the image at ``index``.'''
        # Close the file handle immediately; convert() returns a new image whose
        # pixel data stays valid after the source file is closed.
        with Image.open(self.paths[index]) as src:
            img = src.convert("RGB")
        img_original = np.asarray(self.transforms(img))  # (H, W, 3) RGB array
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]  # only a/b are targets; the L channel is unused
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        # NOTE(review): the network input is rgb2gray luminance, whereas to_rgb later
        # treats it as Lab L / 100; the two are close but not identical.
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab

    def __len__(self):
        '''Number of image files found in the folder.'''
        return len(self.paths)

In [9]:
# Datasets over every file in <dataset>/train and <dataset>/val.
train_imagefolder = LabImageFolder(dataset + 'train')
val_imagefolder = LabImageFolder(dataset + 'val', split='val')

# Training batches are reshuffled every epoch; validation batches are not shuffled.
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
# Layer hyper-parameters shared by every convolution in the model.
kernel_size = 3
stride_en = 1           # encoder conv stride (downsampling is done by average pooling)
stride_de = 1           # decoder conv stride (upsampling is done by interpolation)
padding = 1
scale_factor = 2        # decoder upsampling factor
padding_mode = 'zeros'
channels_base = 512     # width of the first encoder layer

class Autoencoder(nn.Module):
    '''Convolutional encoder/decoder mapping a 1-channel image to 2 output channels.

    Encoder: two stages of conv -> batch-norm -> leaky ReLU(0.1) -> 3x3/stride-2
    average pool, widening 1 -> C -> 2C channels (C = channels_base).
    Decoder: transposed conv -> batch-norm -> leaky ReLU -> x2 upsample (2C -> C),
    transposed conv -> batch-norm -> leaky ReLU (C -> C/2), then a final transposed
    conv to 2 channels followed by another x2 upsample (F.interpolate, default mode).

    For an (N, 1, H, W) input with H and W multiples of 4 the output is (N, 2, H, W).
    ``input_size`` is accepted for backward compatibility but is not used.
    '''
    def __init__(self, input_size=128):
        super().__init__()
        conv_kw = dict(kernel_size=kernel_size, padding=padding, padding_mode=padding_mode)
        wide, narrow = channels_base * 2, channels_base // 2

        # encoder convolutions
        self.conv1 = nn.Conv2d(1, channels_base, stride=stride_en, **conv_kw)
        self.conv2 = nn.Conv2d(channels_base, wide, stride=stride_en, **conv_kw)

        # decoder transposed convolutions
        self.convtrans1 = nn.ConvTranspose2d(wide, channels_base, stride=stride_de, **conv_kw)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, narrow, stride=stride_de, **conv_kw)
        self.convtrans3 = nn.ConvTranspose2d(narrow, 2, stride=stride_de, **conv_kw)

        # batch norms, in the order they are applied in forward()
        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(wide)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(narrow)

    @staticmethod
    def _act(t):
        # leaky ReLU used after every batch-norm
        return F.leaky_relu(t, negative_slope=0.1)

    @staticmethod
    def _pool(t):
        # halves the spatial size (rounding up)
        return F.avg_pool2d(t, kernel_size=3, stride=2, padding=1)

    def forward(self, input):
        # encoder
        h = self._pool(self._act(self.batchnorm1(self.conv1(input))))
        h = self._pool(self._act(self.batchnorm2(self.conv2(h))))

        # decoder
        h = F.interpolate(self._act(self.batchnorm3(self.convtrans1(h))), scale_factor=scale_factor)
        h = self._act(self.batchnorm4(self.convtrans2(h)))
        return F.interpolate(self.convtrans3(h), scale_factor=scale_factor)

In [11]:
# Instantiate the colorization network (input_size stays at its default; the model does not use it).
model = Autoencoder()

In [12]:
# Metrics as [MSE, PSNR, SSIM]; only MSE (index 0) is back-propagated in train().
# data_range=1.0 matches the ab targets, which LabImageFolder scales to roughly [0, 1].
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
# Adam over all model parameters.
# NOTE(review): lr=1e-2 is ten times Adam's default of 1e-3; the large epoch-to-epoch
# swings in validation MSE in the logs may be related — consider lowering or scheduling it.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# Move the model and every metric object to the GPU when one is available.
if use_gpu:
    model = model.cuda()
    criterion = [metric.to("cuda") for metric in criterion]

In [15]:
# Print a layer-by-layer summary for a single (1, SIZE, SIZE) grayscale input.
# NOTE(review): guarded by use_gpu, presumably because torchsummary's summary()
# defaults to device="cuda" — confirm before enabling it on CPU.
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 512, 32, 32]           5,120
       BatchNorm2d-2          [-1, 512, 32, 32]           1,024
            Conv2d-3         [-1, 1024, 16, 16]       4,719,616
       BatchNorm2d-4         [-1, 1024, 16, 16]           2,048
   ConvTranspose2d-5            [-1, 512, 8, 8]       4,719,104
       BatchNorm2d-6            [-1, 512, 8, 8]           1,024
   ConvTranspose2d-7          [-1, 256, 16, 16]       1,179,904
       BatchNorm2d-8          [-1, 256, 16, 16]             512
   ConvTranspose2d-9            [-1, 2, 16, 16]           4,610
Total params: 10,632,962
Trainable params: 10,632,962
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 13.50
Params size (MB): 40.56
Estimated Total Size (MB): 54.07
----------------------------------

In [16]:
class AverageMeter(object):
    '''Track the most recent value and the sample-weighted running mean.

    Adapted from the PyTorch ImageNet example.
    '''
    def __init__(self):
        self.reset()

    def reset(self):
        '''Forget every value recorded so far.'''
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        '''Record ``val`` as the mean over a batch of ``n`` samples.'''
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Save a grayscale input and its colorization as image files.

    The normalisation applied by LabImageFolder is undone first: the grayscale
    channel is scaled by 100 and used as Lab L, and the ab channels are mapped back
    with ab * 255 - 128. skimage ``lab2rgb`` then converts the result to RGB.

    Args:
        grayscale_input: CPU tensor (1, H, W); expected to lie in [0, 1].
        ab_input: CPU tensor (2, H, W) in the normalised ab space.
        save_path: dict {'grayscale': '/path/', 'colorized': '/path/'}. Each value
            is used as a plain string prefix, so it must end with a separator.
        save_name: file name appended to both prefixes.

    Nothing is computed or written unless both ``save_path`` and ``save_name`` are
    given. This helper never displays anything and does not touch pyplot figures.
    '''
    if save_path is None or save_name is None:
        return

    # Stack to a (3, H, W) Lab array, then move channels last for skimage.
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy()
    color_image = color_image.transpose((1, 2, 0))
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
    color_image = lab2rgb(color_image.astype(np.float64))

    grayscale_input = grayscale_input.squeeze().numpy()
    # Pin the gray colormap to [0, 1]; without vmin/vmax imsave stretches every
    # image to its own min/max, so the saved "input" would not match what the model saw.
    plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name),
               cmap='gray', vmin=0, vmax=1)
    plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch, max_saved_images=10):
    '''Evaluate ``model`` on ``val_loader`` and report MSE, PSNR and SSIM.

    Args:
        val_loader: iterable of (gray, ab) batches.
        model: network mapping gray -> predicted ab; it is put in eval mode.
        criterion: [MSE, PSNR, SSIM] metric objects, called as metric(pred, target).
        save_images: if True, the first ``max_saved_images`` predictions of the
            first batch are written via ``to_rgb`` into ``gray_imgs`` / ``color_imgs``.
        epoch: used in saved file names; values >= 0 are also logged to TensorBoard.
        max_saved_images: upper bound on the number of images saved (default 10).

    Returns:
        (mse, psnr, ssim), each averaged over all samples (weighted by batch size).
    '''
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    # Only each metric's per-batch return value is used; clear any internal state
    # accumulated by earlier calls so it cannot grow across passes.
    for metric in criterion:
        metric.reset()

    already_saved_images = False
    # Self-contained no-grad: safe even if the caller forgets torch.no_grad().
    with torch.no_grad():
        for gray, ab in val_loader:
            if use_gpu:
                gray, ab = gray.cuda(), ab.cuda()

            output_ab = model(gray)
            n = gray.size(0)
            for meter, metric in zip(_loss, criterion):
                meter.update(metric(output_ab, ab).item(), n)

            # Save samples from the first batch only
            if save_images and not already_saved_images:
                already_saved_images = True
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                for j in range(min(len(output_ab), max_saved_images)):
                    save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                    to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    # "val" is the last batch's value, the parenthesised number the epoch average.
    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    '''Run one training epoch that optimises MSE while tracking PSNR and SSIM.

    Args:
        train_loader: iterable of (gray, ab) batches.
        model: network mapping gray -> predicted ab; it is put in train mode.
        criterion: [MSE, PSNR, SSIM] metric objects; only criterion[0] is
            back-propagated.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index; values >= 0 are logged to TensorBoard.
    '''
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.train()
    # Only each metric's per-batch return value is used; clear any internal state
    # accumulated by earlier calls so it cannot grow across epochs.
    for metric in criterion:
        metric.reset()

    for gray, ab in train_loader:
        if use_gpu:
            gray, ab = gray.cuda(), ab.cuda()

        optimizer.zero_grad()

        output_ab = model(gray)
        mse = criterion[0](output_ab, ab)
        # PSNR/SSIM are reported only, never back-propagated: compute them on a
        # detached output so no autograd graph is built for them.
        with torch.no_grad():
            prediction = output_ab.detach()
            psnr = criterion[1](prediction, ab)
            ssim = criterion[2](prediction, ab)

        mse.backward()
        optimizer.step()

        n = gray.size(0)
        _loss[0].update(mse.item(), n)
        _loss[1].update(psnr.item(), n)
        _loss[2].update(ssim.item(), n)

    # "val" is the last batch's value, the parenthesised number the epoch average.
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model: each iteration runs one training epoch followed by one validation pass.
# A checkpoint is written whenever a validation metric reaches a new best:
# MSE is lower-is-better, while PSNR and SSIM are higher-is-better.
# best_losses is seeded with 1e10 for every slot, which is only a valid starting point
# for a minimised metric, so the PSNR/SSIM slots are reseeded to -inf while they still
# hold that sentinel.
best_losses[1:] = [b if b < 1e10 else float('-inf') for b in best_losses[1:]]

for epoch in range(epochs):
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)

    # MSE: lower is better; it also drives early stopping through best_epoch.
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    # PSNR: higher is better.
    if losses[1] > best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    # SSIM: higher is better.
    if losses[2] > best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')

    # Early stop once MSE has not improved for `patience` epochs, but never before epoch 100.
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break

    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.02858108 (5.65832478), PSNR 15.43921280 (8.93980723), SSIM 0.10657670 (0.05411384)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.03054378 (0.02794322), PSNR 15.15077209 (15.55997611), SSIM 0.08432081 (0.10807018)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.01511877 (0.01962274), PSNR 18.20483398 (17.17206833), SSIM 0.18303493 (0.14915802)
Finished training epoch 1


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01525505 (0.01358815), PSNR 18.16586494 (18.69094517), SSIM 0.14700198 (0.18688398)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.01009373 (0.01132364), PSNR 19.95948029 (19.49165746), SSIM 0.22849782 (0.21970849)
Finished training epoch 2


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01094548 (0.00971033), PSNR 19.60765266 (20.14082439), SSIM 0.20095310 (0.24750442)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00813374 (0.00838369), PSNR 20.89709663 (20.78772081), SSIM 0.28138861 (0.27383651)
Finished training epoch 3


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00884703 (0.00756665), PSNR 20.53202438 (21.23545187), SSIM 0.24012280 (0.29702682)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.00739877 (0.00691792), PSNR 21.30840492 (21.61827820), SSIM 0.31824270 (0.31711515)
Finished training epoch 4


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00727465 (0.00641527), PSNR 21.38188171 (21.94625591), SSIM 0.28134650 (0.33620821)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.00720613 (0.00618423), PSNR 21.42297935 (22.11118513), SSIM 0.37092480 (0.35308502)
Finished training epoch 5


  return func(*args, **kwargs)


Validate: MSE 0.00664397 (0.00592414), PSNR 21.77572060 (22.31347506), SSIM 0.30788144 (0.37236849)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 0.00881453 (0.00624644), PSNR 20.54800987 (22.11017289), SSIM 0.37585852 (0.38089309)
Finished training epoch 6


  return func(*args, **kwargs)


Validate: MSE 0.00922541 (0.00817986), PSNR 20.35014343 (20.91369651), SSIM 0.30634475 (0.37876525)
Finished validation.
Starting training epoch 7
Epoch: 7, MSE 0.01852869 (0.00785155), PSNR 17.32155228 (21.55232789), SSIM 0.33212543 (0.39939237)
Finished training epoch 7


  return func(*args, **kwargs)


Validate: MSE 0.02162447 (0.02185578), PSNR 16.65054512 (16.63548684), SSIM 0.28392166 (0.34597655)
Finished validation.
Starting training epoch 8
Epoch: 8, MSE 0.00520881 (0.01022475), PSNR 22.83261871 (20.68607933), SSIM 0.42604336 (0.40927545)
Finished training epoch 8


  return func(*args, **kwargs)


Validate: MSE 0.00565814 (0.00512595), PSNR 22.47325897 (22.92541998), SSIM 0.38287240 (0.44870532)
Finished validation.
Starting training epoch 9
Epoch: 9, MSE 0.02992703 (0.01029452), PSNR 15.23936272 (20.56874382), SSIM 0.33274189 (0.42706727)
Finished training epoch 9
Validate: MSE 0.02303023 (0.03116031), PSNR 16.37701607 (15.09525177), SSIM 0.32065189 (0.35144751)
Finished validation.
Starting training epoch 10
Epoch: 10, MSE 0.01006711 (0.01143119), PSNR 19.97095299 (20.33056709), SSIM 0.47094736 (0.44009989)
Finished training epoch 10
Validate: MSE 0.01114194 (0.01229014), PSNR 19.53038979 (19.18085431), SSIM 0.36960632 (0.43337671)
Finished validation.
Starting training epoch 11
Epoch: 11, MSE 0.00647668 (0.00948027), PSNR 21.88647461 (21.19101279), SSIM 0.49951273 (0.47092323)
Finished training epoch 11


  return func(*args, **kwargs)


Validate: MSE 0.00712489 (0.00632582), PSNR 21.47221947 (22.01261869), SSIM 0.42079926 (0.49442928)
Finished validation.
Starting training epoch 12
Epoch: 12, MSE 0.01051524 (0.00862491), PSNR 19.78180504 (21.13921722), SSIM 0.44410172 (0.48539187)
Finished training epoch 12


  return func(*args, **kwargs)


Validate: MSE 0.00553602 (0.00515936), PSNR 22.56802177 (22.89606950), SSIM 0.45629799 (0.52414111)
Finished validation.
Starting training epoch 13
Epoch: 13, MSE 0.00755024 (0.00747630), PSNR 21.22039223 (21.63809623), SSIM 0.50791883 (0.50836898)
Finished training epoch 13
Validate: MSE 0.00572443 (0.00666982), PSNR 22.42267799 (21.77717753), SSIM 0.46872473 (0.52260174)
Finished validation.
Starting training epoch 14
Epoch: 14, MSE 0.00419644 (0.00732059), PSNR 23.77119064 (21.75501217), SSIM 0.57660496 (0.52527902)
Finished training epoch 14
Validate: MSE 0.00425630 (0.00413882), PSNR 23.70967865 (23.89067606), SSIM 0.49666604 (0.56870887)
Finished validation.
Starting training epoch 15
Epoch: 15, MSE 0.00340152 (0.00634669), PSNR 24.68326759 (22.28200511), SSIM 0.58619821 (0.54530444)
Finished training epoch 15
Validate: MSE 0.00409028 (0.00389309), PSNR 23.88246346 (24.14419532), SSIM 0.52339089 (0.59249459)
Finished validation.
Starting training epoch 16
Epoch: 16, MSE 0.0057456

  return func(*args, **kwargs)


Validate: MSE 0.00714262 (0.00739753), PSNR 21.46142578 (21.35645566), SSIM 0.49058458 (0.54881324)
Finished validation.
Starting training epoch 17
Epoch: 17, MSE 0.00417274 (0.00544769), PSNR 23.79578590 (22.90163091), SSIM 0.61041558 (0.58058396)
Finished training epoch 17


  return func(*args, **kwargs)


Validate: MSE 0.00459054 (0.00413850), PSNR 23.38136482 (23.86509373), SSIM 0.55101973 (0.61973855)
Finished validation.
Starting training epoch 18
Epoch: 18, MSE 0.00525288 (0.00520972), PSNR 22.79602051 (23.11833534), SSIM 0.58852690 (0.59781030)
Finished training epoch 18


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00683356 (0.00673391), PSNR 21.65353203 (21.74952275), SSIM 0.49376842 (0.56913093)
Finished validation.
Starting training epoch 19
Epoch: 19, MSE 0.00367970 (0.00475252), PSNR 24.34187698 (23.38855543), SSIM 0.63156486 (0.61380395)
Finished training epoch 19


  return func(*args, **kwargs)


Validate: MSE 0.00487777 (0.00453759), PSNR 23.11778831 (23.48589159), SSIM 0.55316019 (0.62616629)
Finished validation.
Starting training epoch 20
Epoch: 20, MSE 0.00524301 (0.00454417), PSNR 22.80419350 (23.54960614), SSIM 0.60954583 (0.62886446)
Finished training epoch 20
Validate: MSE 0.00640623 (0.00582543), PSNR 21.93397331 (22.40340490), SSIM 0.53998721 (0.61213523)
Finished validation.
Starting training epoch 21
Epoch: 21, MSE 0.00373289 (0.00384015), PSNR 24.27954865 (24.22609264), SSIM 0.66794264 (0.65326945)
Finished training epoch 21
Validate: MSE 0.00420914 (0.00370093), PSNR 23.75806427 (24.35240519), SSIM 0.60269821 (0.67062722)
Finished validation.
Starting training epoch 22
Epoch: 22, MSE 0.00374454 (0.00369370), PSNR 24.26601982 (24.39388534), SSIM 0.65046918 (0.66762037)
Finished training epoch 22
Validate: MSE 0.00415026 (0.00360335), PSNR 23.81924629 (24.48354671), SSIM 0.59519923 (0.67620932)
Finished validation.
Starting training epoch 23
Epoch: 23, MSE 0.0034550

  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00474118 (0.00406670), PSNR 23.24113655 (23.96244832), SSIM 0.62799299 (0.69984362)
Finished validation.
Starting training epoch 26
Epoch: 26, MSE 0.00321055 (0.00301428), PSNR 24.93420219 (25.24272123), SSIM 0.73240584 (0.72261278)
Finished training epoch 26
Validate: MSE 0.00855194 (0.00857477), PSNR 20.67935371 (20.71905896), SSIM 0.57207763 (0.63207033)
Finished validation.
Starting training epoch 27
Epoch: 27, MSE 0.00288238 (0.00283631), PSNR 25.40248680 (25.50615230), SSIM 0.77347010 (0.73591955)
Finished training epoch 27
Validate: MSE 0.00352338 (0.00379294), PSNR 24.53041077 (24.33077207), SSIM 0.65922898 (0.72887331)
Finished validation.
Starting training epoch 28
Epoch: 28, MSE 0.00329181 (0.00279114), PSNR 24.82564545 (25.57036095), SSIM 0.73284477 (0.74316322)
Finished training epoch 28
Validate: MSE 0.00388295 (0.00331312), PSNR 24.10837555 (24.86218895), SSIM 0.65925181 (0.73080270)
Finished validation.
Starting training epoch 29
Epoch: 29, MSE 0.0029345

  return func(*args, **kwargs)


Validate: MSE 0.01251063 (0.01655927), PSNR 19.02720833 (17.87922107), SSIM 0.55380827 (0.59211894)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00299545 (0.00285763), PSNR 25.23537827 (25.47709685), SSIM 0.76296085 (0.75425888)
Finished training epoch 31


  return func(*args, **kwargs)


Validate: MSE 0.00342766 (0.00289734), PSNR 24.65001488 (25.44297367), SSIM 0.69295800 (0.75989280)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00270856 (0.00283381), PSNR 25.67260933 (25.51655208), SSIM 0.76911497 (0.75695930)
Finished training epoch 32
Validate: MSE 0.00473454 (0.00551082), PSNR 23.24721909 (22.68098029), SSIM 0.65175390 (0.69621924)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00275826 (0.00278783), PSNR 25.59365273 (25.57955425), SSIM 0.75085068 (0.75932768)
Finished training epoch 33
Validate: MSE 0.00317181 (0.00265197), PSNR 24.98692894 (25.83050909), SSIM 0.69720531 (0.76536020)
Finished validation.
Starting training epoch 34
Epoch: 34, MSE 0.00289981 (0.00275853), PSNR 25.37630272 (25.62587985), SSIM 0.76435339 (0.76158940)
Finished training epoch 34
Validate: MSE 0.00373870 (0.00408584), PSNR 24.27279091 (24.06547731), SSIM 0.67715883 (0.74859685)
Finished validation.
Starting training epoch 35
Epoch: 35, MSE 0.0026005

Validate: MSE 0.00273091 (0.00264295), PSNR 25.63692665 (25.84473232), SSIM 0.71924102 (0.77896737)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00229629 (0.00235380), PSNR 26.38972664 (26.30896088), SSIM 0.77500153 (0.77967515)
Finished training epoch 62
Validate: MSE 0.00260248 (0.00233316), PSNR 25.84611893 (26.38419529), SSIM 0.72147036 (0.77888808)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00267059 (0.00231210), PSNR 25.73392487 (26.38212170), SSIM 0.76878154 (0.78056599)
Finished training epoch 63
Validate: MSE 0.00282553 (0.00247917), PSNR 25.48899460 (26.12153637), SSIM 0.71661633 (0.78193716)
Finished validation.
Starting training epoch 64
Epoch: 64, MSE 0.00343134 (0.00230284), PSNR 24.64536285 (26.39707768), SSIM 0.75884616 (0.78064494)
Finished training epoch 64
Validate: MSE 0.00264254 (0.00244852), PSNR 25.77977753 (26.18405704), SSIM 0.71966052 (0.77549868)
Finished validation.
Starting training epoch 65
Epoch: 65, MSE 0.0021040

  return func(*args, **kwargs)


Validate: MSE 0.00268941 (0.00262155), PSNR 25.70342827 (25.85559079), SSIM 0.72026145 (0.77950286)
Finished validation.
Starting training epoch 80
Epoch: 80, MSE 0.00238644 (0.00223396), PSNR 26.22249413 (26.53277786), SSIM 0.76814330 (0.78227347)
Finished training epoch 80
Validate: MSE 0.00279326 (0.00236910), PSNR 25.53888130 (26.32571219), SSIM 0.71911216 (0.78197482)
Finished validation.
Starting training epoch 81
Epoch: 81, MSE 0.00207410 (0.00223409), PSNR 26.83169746 (26.53144527), SSIM 0.78770894 (0.78203711)
Finished training epoch 81
Validate: MSE 0.00275550 (0.00338943), PSNR 25.59799576 (24.89286531), SSIM 0.71378767 (0.76830961)
Finished validation.
Starting training epoch 82
Epoch: 82, MSE 0.00200070 (0.00222976), PSNR 26.98817444 (26.54227636), SSIM 0.77090919 (0.78214668)
Finished training epoch 82
Validate: MSE 0.00295232 (0.00253998), PSNR 25.29836273 (26.03293678), SSIM 0.71775764 (0.77621784)
Finished validation.
Starting training epoch 83
Epoch: 83, MSE 0.0022246

Validate: MSE 0.00261688 (0.00250636), PSNR 25.82215309 (26.09728273), SSIM 0.71940905 (0.77700566)
Finished validation.
Starting training epoch 110
Epoch: 110, MSE 0.00212661 (0.00213966), PSNR 26.72312164 (26.71709263), SSIM 0.79663438 (0.78323755)
Finished training epoch 110
Validate: MSE 0.00277756 (0.00231933), PSNR 25.56336212 (26.40843295), SSIM 0.72321939 (0.78204916)
Finished validation.
Starting training epoch 111
Epoch: 111, MSE 0.00260059 (0.00213758), PSNR 25.84927750 (26.72267074), SSIM 0.78291094 (0.78337463)
Finished training epoch 111
Validate: MSE 0.00270230 (0.00229533), PSNR 25.68266487 (26.45063047), SSIM 0.72127885 (0.77797752)
Finished validation.
Starting training epoch 112
Epoch: 112, MSE 0.00232605 (0.00212751), PSNR 26.33381271 (26.74349640), SSIM 0.78324199 (0.78323316)
Finished training epoch 112
Validate: MSE 0.00253494 (0.00221864), PSNR 25.96032333 (26.59554579), SSIM 0.71906269 (0.77901709)
Finished validation.
Starting training epoch 113
Epoch: 113, MS

Epoch: 139, MSE 0.00192080 (0.00201934), PSNR 27.16517639 (26.97037326), SSIM 0.77962017 (0.78283470)
Finished training epoch 139
Validate: MSE 0.00307084 (0.00293168), PSNR 25.12743187 (25.41668368), SSIM 0.69946277 (0.75788557)
Finished validation.
Starting training epoch 140
Epoch: 140, MSE 0.00211451 (0.00199524), PSNR 26.74790001 (27.01794977), SSIM 0.79333884 (0.78304822)
Finished training epoch 140
Validate: MSE 0.00271656 (0.00228107), PSNR 25.65980911 (26.48446988), SSIM 0.71670580 (0.77287681)
Finished validation.
Starting training epoch 141
Epoch: 141, MSE 0.00196972 (0.00200099), PSNR 27.05595779 (27.01052725), SSIM 0.78825772 (0.78286367)
Finished training epoch 141
Validate: MSE 0.00278885 (0.00281123), PSNR 25.54574966 (25.62039337), SSIM 0.70480776 (0.75716371)
Finished validation.
Starting training epoch 142
Epoch: 142, MSE 0.00188972 (0.00199600), PSNR 27.23603439 (27.01895954), SSIM 0.76709533 (0.78253129)
Finished training epoch 142
Validate: MSE 0.00294812 (0.00241

Validate: MSE 0.00270365 (0.00233850), PSNR 25.68049049 (26.37376708), SSIM 0.70993149 (0.76736728)
Finished validation.
Starting training epoch 169
Epoch: 169, MSE 0.00209411 (0.00186538), PSNR 26.79000282 (27.31198046), SSIM 0.78460729 (0.78258931)
Finished training epoch 169
Validate: MSE 0.00288977 (0.00239783), PSNR 25.39136314 (26.25794886), SSIM 0.68653071 (0.76197111)
Finished validation.
Starting training epoch 170
Epoch: 170, MSE 0.00211194 (0.00185607), PSNR 26.75318146 (27.33417258), SSIM 0.75826633 (0.78266329)
Finished training epoch 170
Validate: MSE 0.00263840 (0.00309639), PSNR 25.78658295 (25.15980431), SSIM 0.70993078 (0.76032078)
Finished validation.
Starting training epoch 171
Epoch: 171, MSE 0.00201824 (0.00188749), PSNR 26.95027542 (27.26345342), SSIM 0.77450782 (0.78175957)
Finished training epoch 171
Validate: MSE 0.00259488 (0.00232265), PSNR 25.85881996 (26.40144259), SSIM 0.70674664 (0.76320868)
Finished validation.
Starting training epoch 172
Epoch: 172, MS

Epoch: 198, MSE 0.00165182 (0.00172438), PSNR 27.82035828 (27.64940628), SSIM 0.79060400 (0.78316548)
Finished training epoch 198
Validate: MSE 0.00265272 (0.00240146), PSNR 25.76308250 (26.25234928), SSIM 0.69487977 (0.75721569)
Finished validation.
Starting training epoch 199
Epoch: 199, MSE 0.00170829 (0.00173166), PSNR 27.67438316 (27.63229904), SSIM 0.77854097 (0.78312127)
Finished training epoch 199
Validate: MSE 0.00274528 (0.00235277), PSNR 25.61413574 (26.34373093), SSIM 0.69921041 (0.76140056)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [20]:
# Save the final weights, tagged with the metrics from the last validate() call in the
# training loop (`losses` = (MSE, PSNR, SSIM)).
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Final validation pass with the trained model. epoch=-1 skips TensorBoard logging
# and tags the saved sample images with "epoch--1".
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, epoch=-1)

Validate: MSE 0.00274528 (0.00235277), PSNR 25.61413574 (26.34373093), SSIM 0.69921052 (0.76140055)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()