In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .25

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 16, 16]           2,306
Total params: 2,662,274
Trainable params: 2,662,274
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.75
Params size (MB): 10.16
Estimated Total Size (MB): 16.91
-------------------------------------

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.07072462 (1.60808911), PSNR 11.50429344 (5.11470274), SSIM 0.04266262 (0.01149934)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.07407236 (0.07573690), PSNR 11.30343819 (11.23587111), SSIM 0.02681173 (0.03313784)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.03401009 (0.04447150), PSNR 14.68392277 (13.65153554), SSIM 0.13707161 (0.10442819)
Finished training epoch 1


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.03549461 (0.03509330), PSNR 14.49837494 (14.69999689), SSIM 0.12617514 (0.16606874)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.02347648 (0.02733863), PSNR 16.29367065 (15.64726524), SSIM 0.30162421 (0.23670072)
Finished training epoch 2


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02394632 (0.02349041), PSNR 16.20761108 (16.47217056), SSIM 0.21102938 (0.29320520)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.01873474 (0.02152711), PSNR 17.27352333 (16.68045031), SSIM 0.36386228 (0.32259502)
Finished training epoch 3


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01972169 (0.01996435), PSNR 17.05055809 (17.18555763), SSIM 0.26313692 (0.36000708)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.01481940 (0.01736999), PSNR 18.29169273 (17.61093357), SSIM 0.40339822 (0.37633833)
Finished training epoch 4


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01749817 (0.01634133), PSNR 17.57007408 (18.04652719), SSIM 0.29060411 (0.40036002)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.01174779 (0.01405777), PSNR 19.30043602 (18.52986926), SSIM 0.44111291 (0.41571487)
Finished training epoch 5


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01377232 (0.01308868), PSNR 18.60992622 (19.00236447), SSIM 0.31990287 (0.43465957)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 0.00978564 (0.01131359), PSNR 20.09410858 (19.47452397), SSIM 0.43538895 (0.44923264)
Finished training epoch 6


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01113847 (0.01062896), PSNR 19.53174210 (19.89176075), SSIM 0.34918386 (0.46528717)
Finished validation.
Starting training epoch 7
Epoch: 7, MSE 0.00862603 (0.00911827), PSNR 20.64189148 (20.41138073), SSIM 0.50171423 (0.48002992)
Finished training epoch 7


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00877907 (0.00846988), PSNR 20.56551361 (20.86263978), SSIM 0.37640035 (0.49591719)
Finished validation.
Starting training epoch 8
Epoch: 8, MSE 0.00627607 (0.00739454), PSNR 22.02312088 (21.32233459), SSIM 0.51387209 (0.50987789)
Finished training epoch 8


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00719243 (0.00671812), PSNR 21.43124390 (21.84304150), SSIM 0.42010704 (0.52986460)
Finished validation.
Starting training epoch 9
Epoch: 9, MSE 0.00565096 (0.00605290), PSNR 22.47877693 (22.19075912), SSIM 0.53359568 (0.53979124)
Finished training epoch 9


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00635990 (0.00604529), PSNR 21.96549606 (22.29216298), SSIM 0.43488824 (0.54504077)
Finished validation.
Starting training epoch 10
Epoch: 10, MSE 0.00401917 (0.00505727), PSNR 23.95863152 (22.97330541), SSIM 0.59099627 (0.56880311)
Finished training epoch 10


  return func(*args, **kwargs)


Validate: MSE 0.00514921 (0.00472895), PSNR 22.88259506 (23.33915038), SSIM 0.48183540 (0.58471767)
Finished validation.
Starting training epoch 11
Epoch: 11, MSE 0.00407717 (0.00432377), PSNR 23.89641571 (23.65478519), SSIM 0.56781644 (0.59668831)
Finished training epoch 11


  return func(*args, **kwargs)


Validate: MSE 0.00445731 (0.00406229), PSNR 23.50926971 (23.98490038), SSIM 0.50819743 (0.61092756)
Finished validation.
Starting training epoch 12
Epoch: 12, MSE 0.00353274 (0.00379324), PSNR 24.51888657 (24.22536791), SSIM 0.62378126 (0.62374339)
Finished training epoch 12


  return func(*args, **kwargs)


Validate: MSE 0.00391803 (0.00356458), PSNR 24.06932640 (24.54258712), SSIM 0.54748470 (0.63903980)
Finished validation.
Starting training epoch 13
Epoch: 13, MSE 0.00333262 (0.00341675), PSNR 24.77214432 (24.68081540), SSIM 0.66663539 (0.64875606)
Finished training epoch 13
Validate: MSE 0.00349689 (0.00323088), PSNR 24.56318092 (24.97587624), SSIM 0.57970476 (0.66485303)
Finished validation.
Starting training epoch 14
Epoch: 14, MSE 0.00326167 (0.00316175), PSNR 24.86560249 (25.02066247), SSIM 0.66913605 (0.67220098)
Finished training epoch 14
Validate: MSE 0.00345969 (0.00307252), PSNR 24.60962868 (25.18928112), SSIM 0.59620911 (0.68284313)
Finished validation.
Starting training epoch 15
Epoch: 15, MSE 0.00284909 (0.00298544), PSNR 25.45293045 (25.26949927), SSIM 0.71679562 (0.69359905)
Finished training epoch 15


  return func(*args, **kwargs)


Validate: MSE 0.00317943 (0.00289506), PSNR 24.97650528 (25.45767310), SSIM 0.62808359 (0.70583982)
Finished validation.
Starting training epoch 16
Epoch: 16, MSE 0.00289681 (0.00286967), PSNR 25.38080025 (25.44029425), SSIM 0.72761053 (0.71252149)
Finished training epoch 16
Validate: MSE 0.00310218 (0.00280921), PSNR 25.08332825 (25.58516345), SSIM 0.65666610 (0.72128322)
Finished validation.
Starting training epoch 17
Epoch: 17, MSE 0.00243182 (0.00279571), PSNR 26.14069366 (25.55891651), SSIM 0.75314438 (0.72842307)
Finished training epoch 17
Validate: MSE 0.00307271 (0.00276232), PSNR 25.12479019 (25.65671175), SSIM 0.66912407 (0.73507667)
Finished validation.
Starting training epoch 18
Epoch: 18, MSE 0.00287867 (0.00274622), PSNR 25.40807343 (25.63413995), SSIM 0.72968900 (0.74129820)
Finished training epoch 18
Validate: MSE 0.00302075 (0.00271497), PSNR 25.19885254 (25.73333123), SSIM 0.68399227 (0.74794175)
Finished validation.
Starting training epoch 19
Epoch: 19, MSE 0.0034670

  return func(*args, **kwargs)


Validate: MSE 0.00297936 (0.00295007), PSNR 25.25877380 (25.42040400), SSIM 0.70095307 (0.76349159)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00202285 (0.00249170), PSNR 26.94036484 (26.05742336), SSIM 0.79184937 (0.77080289)
Finished training epoch 33
Validate: MSE 0.00297518 (0.00259303), PSNR 25.26486015 (25.93190522), SSIM 0.70252109 (0.76964974)
Finished validation.
Starting training epoch 34
Epoch: 34, MSE 0.00213707 (0.00247271), PSNR 26.70180511 (26.08938129), SSIM 0.77699697 (0.77136796)
Finished training epoch 34
Validate: MSE 0.00387242 (0.00379146), PSNR 24.12017822 (24.34542046), SSIM 0.68652093 (0.74232696)
Finished validation.
Starting training epoch 35
Epoch: 35, MSE 0.00228152 (0.00244629), PSNR 26.41776276 (26.13766094), SSIM 0.77306241 (0.77225883)
Finished training epoch 35
Validate: MSE 0.00284572 (0.00250252), PSNR 25.45807076 (26.08296546), SSIM 0.70487809 (0.76959140)
Finished validation.
Starting training epoch 36
Epoch: 36, MSE 0.0024188

Validate: MSE 0.00274966 (0.00244929), PSNR 25.60720253 (26.19342536), SSIM 0.71149433 (0.77427861)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00231399 (0.00231162), PSNR 26.35639000 (26.38566203), SSIM 0.77253848 (0.77568401)
Finished training epoch 63
Validate: MSE 0.00248195 (0.00255029), PSNR 26.05207253 (26.02191525), SSIM 0.71682549 (0.76701060)
Finished validation.
Starting training epoch 64
Epoch: 64, MSE 0.00265851 (0.00231238), PSNR 25.75361824 (26.38462674), SSIM 0.75694102 (0.77582355)
Finished training epoch 64
Validate: MSE 0.00285713 (0.00238966), PSNR 25.44069481 (26.28504410), SSIM 0.71203244 (0.77425201)
Finished validation.
Starting training epoch 65
Epoch: 65, MSE 0.00254376 (0.00229875), PSNR 25.94524384 (26.40866336), SSIM 0.75440294 (0.77589314)
Finished training epoch 65
Validate: MSE 0.00283700 (0.00233625), PSNR 25.47140503 (26.37759109), SSIM 0.71102923 (0.76919147)
Finished validation.
Starting training epoch 66
Epoch: 66, MSE 0.0022671

Validate: MSE 0.00250261 (0.00240180), PSNR 26.01607704 (26.24881435), SSIM 0.71860367 (0.77228563)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00196277 (0.00222461), PSNR 27.07131386 (26.55003261), SSIM 0.80478507 (0.77665337)
Finished training epoch 93
Validate: MSE 0.00268036 (0.00256271), PSNR 25.71807098 (25.99945399), SSIM 0.71186507 (0.77337163)
Finished validation.
Starting training epoch 94
Epoch: 94, MSE 0.00230432 (0.00224544), PSNR 26.37456703 (26.51115161), SSIM 0.76909512 (0.77547853)
Finished training epoch 94
Validate: MSE 0.00272608 (0.00239366), PSNR 25.64460564 (26.26342319), SSIM 0.71098763 (0.77420170)
Finished validation.
Starting training epoch 95
Epoch: 95, MSE 0.00189402 (0.00222652), PSNR 27.22614670 (26.54888905), SSIM 0.78793728 (0.77613750)
Finished training epoch 95
Validate: MSE 0.00266714 (0.00274289), PSNR 25.73953247 (25.73293042), SSIM 0.70647961 (0.76791816)
Finished validation.
Starting training epoch 96
Epoch: 96, MSE 0.0026520

Epoch: 122, MSE 0.00197864 (0.00218208), PSNR 27.03633690 (26.63517798), SSIM 0.79355800 (0.77626405)
Finished training epoch 122
Validate: MSE 0.00267355 (0.00242367), PSNR 25.72912216 (26.23374184), SSIM 0.71754050 (0.77713594)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00236723 (0.00221898), PSNR 26.25760269 (26.56444439), SSIM 0.78969204 (0.77493167)
Finished training epoch 123
Validate: MSE 0.00308434 (0.00274860), PSNR 25.10836983 (25.70624048), SSIM 0.71479917 (0.77230780)
Finished validation.
Starting training epoch 124
Epoch: 124, MSE 0.00223528 (0.00219682), PSNR 26.50668526 (26.60438000), SSIM 0.75998789 (0.77553653)
Finished training epoch 124
Validate: MSE 0.00266450 (0.00239990), PSNR 25.74383736 (26.26414468), SSIM 0.70446813 (0.76576996)
Finished validation.
Starting training epoch 125
Epoch: 125, MSE 0.00188190 (0.00216333), PSNR 27.25403404 (26.67135324), SSIM 0.76904577 (0.77637950)
Finished training epoch 125
Validate: MSE 0.00266189 (0.00247

Validate: MSE 0.00257213 (0.00228112), PSNR 25.89706421 (26.48393003), SSIM 0.71351755 (0.77225279)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00233719 (0.00215663), PSNR 26.31305122 (26.68478596), SSIM 0.77501118 (0.77510294)
Finished training epoch 152
Validate: MSE 0.00287257 (0.00246014), PSNR 25.41729736 (26.17273691), SSIM 0.70735025 (0.77233085)
Finished validation.
Starting training epoch 153
Epoch: 153, MSE 0.00190559 (0.00214694), PSNR 27.19969940 (26.70277093), SSIM 0.76740164 (0.77558814)
Finished training epoch 153
Validate: MSE 0.00252419 (0.00230385), PSNR 25.97876930 (26.43427926), SSIM 0.71275467 (0.76641154)
Finished validation.
Starting training epoch 154
Epoch: 154, MSE 0.00195968 (0.00215345), PSNR 27.07813644 (26.69340973), SSIM 0.78287190 (0.77511078)
Finished training epoch 154
Validate: MSE 0.00253113 (0.00255122), PSNR 25.96684647 (26.03572877), SSIM 0.71383709 (0.76079172)
Finished validation.
Starting training epoch 155
Epoch: 155, MS

Epoch: 181, MSE 0.00217152 (0.00213857), PSNR 26.63235664 (26.71989969), SSIM 0.76209247 (0.77434780)
Finished training epoch 181
Validate: MSE 0.00277754 (0.00257184), PSNR 25.56338882 (25.96881353), SSIM 0.71663028 (0.76503933)
Finished validation.
Starting training epoch 182
Epoch: 182, MSE 0.00310277 (0.00211523), PSNR 25.08250427 (26.76886502), SSIM 0.77135164 (0.77510127)
Finished training epoch 182
Validate: MSE 0.00257863 (0.00225230), PSNR 25.88611412 (26.52870741), SSIM 0.70902109 (0.77042543)
Finished validation.
Starting training epoch 183
Epoch: 183, MSE 0.00237804 (0.00212069), PSNR 26.23781586 (26.75729332), SSIM 0.77891171 (0.77510679)
Finished training epoch 183
Validate: MSE 0.00268771 (0.00229510), PSNR 25.70618057 (26.45280059), SSIM 0.72100818 (0.77189879)
Finished validation.
Starting training epoch 184
Epoch: 184, MSE 0.00227655 (0.00210788), PSNR 26.42722130 (26.78127018), SSIM 0.74932027 (0.77510258)
Finished training epoch 184
Validate: MSE 0.00276604 (0.00225

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00257613 (0.00246471), PSNR 25.89031792 (26.17519067), SSIM 0.70988852 (0.76764289)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()