In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '11_1'

In [3]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = '/content/drive/MyDrive/MGU/outputs/color/'
    gray_imgs = '/content/drive/MyDrive/MGU/outputs/gray/'
    checkpoints = '/content/drive/MyDrive/MGU/checkpoints'
elif kaggle:
    os.makedirs(test_number, exist_ok=True)
    results = "results"
    os.makedirs(results, exist_ok=True)
    dataset = '/kaggle/input/cifar10/'
else:
    dataset = '../../datasets/cifar10/'

In [4]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [5]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [6]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [7]:
SIZE = 32
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
                transforms.ToTensor(),
                transforms.Normalize((0.4918019, 0.48263696, 0.44733968), (0.24705184, 0.2433789, 0.26162848)),
                transforms.ToPILImage(),
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
                transforms.ToTensor(),
                transforms.Normalize((0.4918019, 0.48263696, 0.44733968), (0.24705184, 0.2433789, 0.26162848)),
                transforms.ToPILImage(),
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [8]:
# Training
batch_size = 128
train_imagefolder = LabImageFolder(dataset+'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset+'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [9]:
kernel_size=3
stride_en=2
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 32

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        
        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base // 2)
        self.batchnorm2 = nn.BatchNorm2d(channels_base)
        self.batchnorm3 = nn.BatchNorm2d(channels_base * 2)
        
        
    def forward(self, input):
        # encoder
        x = F.relu(self.batchnorm2(self.conv1(input)))
        x = F.relu(self.batchnorm3(self.conv2(x)))
        
        # decoder
        x = F.relu(self.batchnorm2(self.convtrans1(x)))
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.relu(self.batchnorm1(self.convtrans2(x)))
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [10]:
model = Autoencoder()

In [11]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [13]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [14]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 16, 16]             320
       BatchNorm2d-2           [-1, 32, 16, 16]              64
            Conv2d-3             [-1, 64, 8, 8]          18,496
       BatchNorm2d-4             [-1, 64, 8, 8]             128
   ConvTranspose2d-5             [-1, 32, 8, 8]          18,464
       BatchNorm2d-6             [-1, 32, 8, 8]              64
   ConvTranspose2d-7           [-1, 16, 16, 16]           4,624
       BatchNorm2d-8           [-1, 16, 16, 16]              32
   ConvTranspose2d-9            [-1, 2, 16, 16]             290
Total params: 42,482
Trainable params: 42,482
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.29
Params size (MB): 0.16
Estimated Total Size (MB): 0.45
---------------------------------------------

In [15]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [16]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:        
#         writer.add_scalars(f'loss/test', {
#             'MSE': _loss[0].avg,
#             'PSNR': _loss[1].avg,
#             'SSIM': _loss[2].avg,
#         }, epoch)
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [17]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
#         writer.add_scalars(f'loss/train', {
#             'MSE': _loss[0].avg,
#             'PSNR': _loss[1].avg,
#             'SSIM': _loss[2].avg,
#         }, epoch)
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [18]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.02566485 (0.03873827), PSNR 15.90661144 (15.35927671), SSIM 0.10305224 (0.08811778)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.03201771 (0.03400757), PSNR 14.94609642 (14.68681052), SSIM 0.06716947 (0.08486122)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.02353197 (0.02384786), PSNR 16.28341675 (16.22734642), SSIM 0.13987377 (0.12677702)
Finished training epoch 1
Validate: MSE 0.02520782 (0.02688161), PSNR 15.98464584 (15.71029350), SSIM 0.07523306 (0.09620079)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.02120981 (0.02316422), PSNR 16.73463249 (16.35308281), SSIM 0.17143460 (0.15204800)
Finished training epoch 2
Validate: MSE 0.02503398 (0.02671190), PSNR 16.01469994 (15.73808901), SSIM 0.07767942 (0.09736707)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.02360381 (0.02275074), PSNR 16.27017784 (16.43143164), SSIM 0.17178890 (0.17229363)
Finished training epoch 3
Validate: MSE 0.02461858 (0.02627271), PSNR 16.08736992 (15.81008468), SSIM 0.07871528 (0.09845925)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.02286714 (0.022500

  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02594925 (0.02941873), PSNR 15.85875130 (15.32038445), SSIM 0.07917213 (0.09475616)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.02165746 (0.02124013), PSNR 16.64392471 (16.72973233), SSIM 0.22991347 (0.22480078)
Finished training epoch 31


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02704304 (0.02981466), PSNR 15.67944431 (15.26004689), SSIM 0.08233679 (0.09990387)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.02104961 (0.02123196), PSNR 16.76755714 (16.73118348), SSIM 0.22873756 (0.22539796)
Finished training epoch 32


  return func(*args, **kwargs)


Validate: MSE 0.02581636 (0.02861374), PSNR 15.88104916 (15.43911388), SSIM 0.08262496 (0.09967734)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.02229561 (0.02120021), PSNR 16.51780510 (16.73778693), SSIM 0.22425231 (0.22586003)
Finished training epoch 33


  return func(*args, **kwargs)


Validate: MSE 0.02464824 (0.02711694), PSNR 16.08213997 (15.67231243), SSIM 0.08216868 (0.10073699)
Finished validation.
Starting training epoch 34
Epoch: 34, MSE 0.02269872 (0.02118256), PSNR 16.43998528 (16.74141876), SSIM 0.21956141 (0.22628816)
Finished training epoch 34


  return func(*args, **kwargs)


Validate: MSE 0.02392060 (0.02625977), PSNR 16.21227837 (15.81173759), SSIM 0.08119638 (0.10131671)
Finished validation.
Starting training epoch 35
Epoch: 35, MSE 0.02169269 (0.02118601), PSNR 16.63686371 (16.74055224), SSIM 0.22552219 (0.22627860)
Finished training epoch 35


  return func(*args, **kwargs)


Validate: MSE 0.02501258 (0.02750171), PSNR 16.01841545 (15.61076291), SSIM 0.08420247 (0.10374346)
Finished validation.
Starting training epoch 36
Epoch: 36, MSE 0.02086696 (0.02117318), PSNR 16.80540848 (16.74336997), SSIM 0.21834640 (0.22659716)
Finished training epoch 36


  return func(*args, **kwargs)


Validate: MSE 0.02390341 (0.02608036), PSNR 16.21540070 (15.84142129), SSIM 0.08148855 (0.10196694)
Finished validation.
Starting training epoch 37
Epoch: 37, MSE 0.02101424 (0.02117445), PSNR 16.77486229 (16.74309955), SSIM 0.23492730 (0.22652446)
Finished training epoch 37


  return func(*args, **kwargs)


Validate: MSE 0.02445424 (0.02723936), PSNR 16.11645889 (15.65349503), SSIM 0.08325971 (0.10205899)
Finished validation.
Starting training epoch 38
Epoch: 38, MSE 0.02056172 (0.02115179), PSNR 16.86940384 (16.74792079), SSIM 0.24071249 (0.22699692)
Finished training epoch 38


  return func(*args, **kwargs)


Validate: MSE 0.02475644 (0.02734844), PSNR 16.06311798 (15.63546675), SSIM 0.08194718 (0.10141085)
Finished validation.
Starting training epoch 39
Epoch: 39, MSE 0.02149695 (0.02114750), PSNR 16.67623138 (16.74865981), SSIM 0.22443871 (0.22730020)
Finished training epoch 39


  return func(*args, **kwargs)


Validate: MSE 0.02422091 (0.02646183), PSNR 16.15809441 (15.77827353), SSIM 0.08628722 (0.10623278)
Finished validation.
Starting training epoch 40
Epoch: 40, MSE 0.02070933 (0.02112498), PSNR 16.83833885 (16.75323931), SSIM 0.23918435 (0.22754211)
Finished training epoch 40


  return func(*args, **kwargs)


Validate: MSE 0.02417253 (0.02635364), PSNR 16.16677856 (15.79620945), SSIM 0.08160213 (0.10219303)
Finished validation.
Starting training epoch 41
Epoch: 41, MSE 0.02156560 (0.02112407), PSNR 16.66238213 (16.75348259), SSIM 0.21608813 (0.22773132)
Finished training epoch 41


  return func(*args, **kwargs)


Validate: MSE 0.02420137 (0.02628004), PSNR 16.16160011 (15.80826197), SSIM 0.08378289 (0.10445627)
Finished validation.
Starting training epoch 42
Epoch: 42, MSE 0.02152198 (0.02112378), PSNR 16.67117882 (16.75355166), SSIM 0.22544375 (0.22776445)
Finished training epoch 42


  return func(*args, **kwargs)


Validate: MSE 0.02362075 (0.02592947), PSNR 16.26706314 (15.86702534), SSIM 0.07841138 (0.09939863)
Finished validation.
Starting training epoch 43
Epoch: 43, MSE 0.02076984 (0.02112350), PSNR 16.82566833 (16.75358758), SSIM 0.21899395 (0.22780933)
Finished training epoch 43


  return func(*args, **kwargs)


Validate: MSE 0.02415203 (0.02681124), PSNR 16.17046356 (15.72269033), SSIM 0.07880256 (0.09788892)
Finished validation.
Starting training epoch 44
Epoch: 44, MSE 0.02099945 (0.02110973), PSNR 16.77791977 (16.75646470), SSIM 0.23230025 (0.22811946)
Finished training epoch 44


  return func(*args, **kwargs)


Validate: MSE 0.02447561 (0.02700225), PSNR 16.11266327 (15.69105117), SSIM 0.07885429 (0.09864960)
Finished validation.
Starting training epoch 45
Epoch: 45, MSE 0.02219500 (0.02109814), PSNR 16.53744698 (16.75879230), SSIM 0.22536902 (0.22818264)
Finished training epoch 45


  return func(*args, **kwargs)


Validate: MSE 0.02339380 (0.02544742), PSNR 16.30899239 (15.94862471), SSIM 0.08554246 (0.10522382)
Finished validation.
Starting training epoch 46
Epoch: 46, MSE 0.02168968 (0.02109152), PSNR 16.63746834 (16.76012383), SSIM 0.23307824 (0.22843000)
Finished training epoch 46


  return func(*args, **kwargs)


Validate: MSE 0.02332047 (0.02515598), PSNR 16.32262611 (15.99890248), SSIM 0.07544699 (0.09765058)
Finished validation.
Starting training epoch 47
Epoch: 47, MSE 0.02176658 (0.02109106), PSNR 16.62209702 (16.76017293), SSIM 0.22649534 (0.22841734)
Finished training epoch 47


  return func(*args, **kwargs)


Validate: MSE 0.02350503 (0.02559277), PSNR 16.28839111 (15.92379647), SSIM 0.07413311 (0.09574347)
Finished validation.
Starting training epoch 48
Epoch: 48, MSE 0.02041404 (0.02108178), PSNR 16.90071106 (16.76202361), SSIM 0.25741792 (0.22863854)
Finished training epoch 48


  return func(*args, **kwargs)


Validate: MSE 0.02367333 (0.02564347), PSNR 16.25740623 (15.91502849), SSIM 0.08203059 (0.10424026)
Finished validation.
Starting training epoch 49
Epoch: 49, MSE 0.02052274 (0.02106960), PSNR 16.87764549 (16.76475768), SSIM 0.23943552 (0.22878880)
Finished training epoch 49


  return func(*args, **kwargs)


Validate: MSE 0.02436030 (0.02690299), PSNR 16.13317299 (15.70713068), SSIM 0.07886285 (0.09898962)
Finished validation.
Starting training epoch 50
Epoch: 50, MSE 0.02127918 (0.02107689), PSNR 16.72045135 (16.76318610), SSIM 0.23487425 (0.22883900)
Finished training epoch 50


  return func(*args, **kwargs)


Validate: MSE 0.02381973 (0.02598117), PSNR 16.23063087 (15.85827000), SSIM 0.07892744 (0.09992226)
Finished validation.
Starting training epoch 51
Epoch: 51, MSE 0.02144646 (0.02107905), PSNR 16.68644142 (16.76273633), SSIM 0.22164920 (0.22867572)
Finished training epoch 51


  return func(*args, **kwargs)


Validate: MSE 0.02408200 (0.02598166), PSNR 16.18307304 (15.85807700), SSIM 0.08052842 (0.10313955)
Finished validation.
Starting training epoch 52
Epoch: 52, MSE 0.02219539 (0.02106992), PSNR 16.53737068 (16.76447171), SSIM 0.23276190 (0.22895311)
Finished training epoch 52


  return func(*args, **kwargs)


Validate: MSE 0.02362441 (0.02552225), PSNR 16.26638794 (15.93611489), SSIM 0.07574772 (0.09702773)
Finished validation.
Starting training epoch 53
Epoch: 53, MSE 0.02109775 (0.02107301), PSNR 16.75763702 (16.76392824), SSIM 0.22355235 (0.22899586)
Finished training epoch 53


  return func(*args, **kwargs)


Validate: MSE 0.02407856 (0.02640469), PSNR 16.18369484 (15.78831973), SSIM 0.07759871 (0.09763420)
Finished validation.
Starting training epoch 54
Epoch: 54, MSE 0.02027436 (0.02103482), PSNR 16.93052864 (16.77183194), SSIM 0.22303787 (0.22955822)
Finished training epoch 54


  return func(*args, **kwargs)


Validate: MSE 0.02431004 (0.02657746), PSNR 16.14214134 (15.76016684), SSIM 0.07283048 (0.09422694)
Finished validation.
Starting training epoch 55
Epoch: 55, MSE 0.02122312 (0.02104407), PSNR 16.73190689 (16.76978386), SSIM 0.23668654 (0.22946450)
Finished training epoch 55


  return func(*args, **kwargs)


Validate: MSE 0.02403568 (0.02659027), PSNR 16.19143486 (15.75831374), SSIM 0.07815764 (0.09815897)
Finished validation.
Starting training epoch 56
Epoch: 56, MSE 0.02012441 (0.02104795), PSNR 16.96276855 (16.76905125), SSIM 0.22720215 (0.22936639)
Finished training epoch 56


  return func(*args, **kwargs)


Validate: MSE 0.02416517 (0.02625642), PSNR 16.16810036 (15.81262411), SSIM 0.07664224 (0.09819770)
Finished validation.
Starting training epoch 57
Epoch: 57, MSE 0.02099045 (0.02105420), PSNR 16.77978134 (16.76792190), SSIM 0.22668676 (0.22945151)
Finished training epoch 57
Validate: MSE 0.02426228 (0.02645759), PSNR 16.15068436 (15.77980975), SSIM 0.07864384 (0.09928114)
Finished validation.
Starting training epoch 58
Epoch: 58, MSE 0.02117386 (0.02103473), PSNR 16.74199867 (16.77184128), SSIM 0.22842216 (0.22971015)
Finished training epoch 58


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02484144 (0.02688950), PSNR 16.04823303 (15.70877227), SSIM 0.07858576 (0.10009469)
Finished validation.
Starting training epoch 59
Epoch: 59, MSE 0.02188213 (0.02103624), PSNR 16.59910393 (16.77154317), SSIM 0.22334018 (0.22966958)
Finished training epoch 59
Validate: MSE 0.02423912 (0.02610137), PSNR 16.15483093 (15.83856118), SSIM 0.07593790 (0.09816191)
Finished validation.
Starting training epoch 60
Epoch: 60, MSE 0.02062809 (0.02103897), PSNR 16.85540962 (16.77089665), SSIM 0.23947458 (0.22976989)
Finished training epoch 60


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02452396 (0.02688938), PSNR 16.10409355 (15.70926002), SSIM 0.07610521 (0.09645122)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.02081448 (0.02103821), PSNR 16.81634331 (16.77109528), SSIM 0.24042431 (0.22957757)
Finished training epoch 61


  return func(*args, **kwargs)


Validate: MSE 0.02700669 (0.02978562), PSNR 15.68528557 (15.26516653), SSIM 0.07521623 (0.09350903)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.02212720 (0.02101697), PSNR 16.55073547 (16.77553496), SSIM 0.20924366 (0.23016480)
Finished training epoch 62


  return func(*args, **kwargs)


Validate: MSE 0.02538322 (0.02797948), PSNR 15.95453358 (15.53700650), SSIM 0.07765365 (0.09695007)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.02138253 (0.02102846), PSNR 16.69940758 (16.77322264), SSIM 0.22982010 (0.23005176)
Finished training epoch 63


  return func(*args, **kwargs)


Validate: MSE 0.02640006 (0.02884952), PSNR 15.78394985 (15.40316702), SSIM 0.07128498 (0.09161545)
Finished validation.
Starting training epoch 64
Epoch: 64, MSE 0.02037482 (0.02101873), PSNR 16.90906143 (16.77512522), SSIM 0.24839453 (0.23007349)
Finished training epoch 64


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02733078 (0.03063387), PSNR 15.63347912 (15.14365759), SSIM 0.07469243 (0.09206939)
Finished validation.
Starting training epoch 65
Epoch: 65, MSE 0.02085571 (0.02101094), PSNR 16.80775070 (16.77682246), SSIM 0.22549777 (0.23025555)
Finished training epoch 65


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02755257 (0.03009762), PSNR 15.59837818 (15.21935607), SSIM 0.06944743 (0.08879074)
Finished validation.
Starting training epoch 66
Epoch: 66, MSE 0.02249884 (0.02101668), PSNR 16.47839737 (16.77569932), SSIM 0.22269008 (0.23012692)
Finished training epoch 66


  return func(*args, **kwargs)


Validate: MSE 0.02589822 (0.02909548), PSNR 15.86730099 (15.36793197), SSIM 0.07327515 (0.08916299)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.02589822 (0.02909548), PSNR 15.86730099 (15.36793197), SSIM 0.07327514 (0.08916299)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()