In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .25

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 16, 16]           2,306
Total params: 2,662,274
Trainable params: 2,662,274
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.75
Params size (MB): 10.16
Estimated Total Size (MB): 16.91
-------------------------------------

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.08327873 (1.41156435), PSNR 10.79465866 (4.66442880), SSIM 0.03115659 (0.01119883)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.09165615 (0.07927416), PSNR 10.37838268 (11.08877255), SSIM 0.02934813 (0.03353269)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.02617712 (0.04396464), PSNR 15.82078075 (13.78551768), SSIM 0.18370314 (0.09407498)
Finished training epoch 1


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.03138791 (0.02542029), PSNR 15.03237534 (16.10343558), SSIM 0.14476478 (0.17065025)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.01852365 (0.02168833), PSNR 17.32273293 (16.67078659), SSIM 0.28183851 (0.24049046)
Finished training epoch 2


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.02340576 (0.01828765), PSNR 16.30677223 (17.58876911), SSIM 0.25454295 (0.30432130)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.01326566 (0.01546416), PSNR 18.77271080 (18.12288764), SSIM 0.40604377 (0.35810766)
Finished training epoch 3


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01578336 (0.01243241), PSNR 18.01800537 (19.25136619), SSIM 0.34343523 (0.41383523)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.01077255 (0.01181445), PSNR 19.67681694 (19.28909226), SSIM 0.46079960 (0.43453321)
Finished training epoch 4


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01416712 (0.01062667), PSNR 18.48718262 (19.92237471), SSIM 0.37188411 (0.46193041)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.00907795 (0.00921550), PSNR 20.42012215 (20.36762797), SSIM 0.52579266 (0.48528625)
Finished training epoch 5


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.01121452 (0.00819643), PSNR 19.50219345 (21.01446596), SSIM 0.41728380 (0.50596116)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 0.00692929 (0.00732015), PSNR 21.59311485 (21.36732090), SSIM 0.52054709 (0.52382519)
Finished training epoch 6


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00873746 (0.00670247), PSNR 20.58614731 (21.86509345), SSIM 0.44659895 (0.54279761)
Finished validation.
Starting training epoch 7
Epoch: 7, MSE 0.00579280 (0.00587740), PSNR 22.37111473 (22.32131591), SSIM 0.54582489 (0.55692842)
Finished training epoch 7


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00720929 (0.00535719), PSNR 21.42107201 (22.80568924), SSIM 0.47073993 (0.57210340)
Finished validation.
Starting training epoch 8
Epoch: 8, MSE 0.00427757 (0.00482007), PSNR 23.68802452 (23.18315956), SSIM 0.60828185 (0.58626762)
Finished training epoch 8


  return func(*args, **kwargs)


Validate: MSE 0.00545140 (0.00414716), PSNR 22.63491631 (23.89537502), SSIM 0.51718628 (0.61368704)
Finished validation.
Starting training epoch 9
Epoch: 9, MSE 0.00448594 (0.00407090), PSNR 23.48146820 (23.91739266), SSIM 0.60612476 (0.61499954)
Finished training epoch 9


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00488051 (0.00385650), PSNR 23.11534309 (24.20150918), SSIM 0.52410769 (0.62742977)
Finished validation.
Starting training epoch 10
Epoch: 10, MSE 0.00345364 (0.00354180), PSNR 24.61722374 (24.52236394), SSIM 0.63536131 (0.64149206)
Finished training epoch 10


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00420656 (0.00332073), PSNR 23.76072502 (24.84407592), SSIM 0.55913550 (0.65549332)
Finished validation.
Starting training epoch 11
Epoch: 11, MSE 0.00269810 (0.00317508), PSNR 25.68941307 (24.99903232), SSIM 0.69224948 (0.66567271)
Finished training epoch 11


  return func(*args, **kwargs)


Validate: MSE 0.00362383 (0.00301866), PSNR 24.40832329 (25.25874289), SSIM 0.58526117 (0.67807574)
Finished validation.
Starting training epoch 12
Epoch: 12, MSE 0.00289155 (0.00293491), PSNR 25.38869476 (25.33984872), SSIM 0.70861322 (0.68827626)
Finished training epoch 12


  return func(*args, **kwargs)


Validate: MSE 0.00336910 (0.00282643), PSNR 24.72486305 (25.54649506), SSIM 0.60685742 (0.69971727)
Finished validation.
Starting training epoch 13
Epoch: 13, MSE 0.00272980 (0.00278547), PSNR 25.63868904 (25.57014605), SSIM 0.73394442 (0.70821523)
Finished training epoch 13
Validate: MSE 0.00368224 (0.00286456), PSNR 24.33887672 (25.49092531), SSIM 0.62909061 (0.72083154)
Finished validation.
Starting training epoch 14
Epoch: 14, MSE 0.00226623 (0.00269786), PSNR 26.44695663 (25.71155614), SSIM 0.74129897 (0.72452503)
Finished training epoch 14


  return func(*args, **kwargs)


Validate: MSE 0.00406546 (0.00351352), PSNR 23.90890503 (24.63951240), SSIM 0.64347064 (0.71797514)
Finished validation.
Starting training epoch 15
Epoch: 15, MSE 0.00332611 (0.00262087), PSNR 24.78062630 (25.83576274), SSIM 0.74022800 (0.73810725)
Finished training epoch 15
Validate: MSE 0.00301585 (0.00255700), PSNR 25.20590591 (25.98579915), SSIM 0.66532636 (0.74587521)
Finished validation.
Starting training epoch 16
Epoch: 16, MSE 0.00267474 (0.00258617), PSNR 25.72717667 (25.89654401), SSIM 0.73929989 (0.74818902)
Finished training epoch 16
Validate: MSE 0.00342128 (0.00278739), PSNR 24.65811539 (25.61195264), SSIM 0.67378259 (0.74688625)
Finished validation.
Starting training epoch 17
Epoch: 17, MSE 0.00243358 (0.00258128), PSNR 26.13754082 (25.90413746), SSIM 0.76573598 (0.75525377)
Finished training epoch 17
Validate: MSE 0.00363139 (0.00313587), PSNR 24.39927101 (25.13157484), SSIM 0.68140787 (0.75113333)
Finished validation.
Starting training epoch 18
Epoch: 18, MSE 0.0028184

  return func(*args, **kwargs)


Validate: MSE 0.00320096 (0.00261417), PSNR 24.94720078 (25.88745270), SSIM 0.69084597 (0.76524384)
Finished validation.
Starting training epoch 21
Epoch: 21, MSE 0.00278325 (0.00253182), PSNR 25.55447769 (25.98888943), SSIM 0.76136285 (0.76611014)
Finished training epoch 21
Validate: MSE 0.00315539 (0.00266560), PSNR 25.00946999 (25.81691840), SSIM 0.69602323 (0.76422516)
Finished validation.
Starting training epoch 22
Epoch: 22, MSE 0.00263916 (0.00254042), PSNR 25.78533936 (25.97181633), SSIM 0.78725046 (0.76657467)
Finished training epoch 22
Validate: MSE 0.00461108 (0.00417867), PSNR 23.36197281 (23.90708240), SSIM 0.68336153 (0.74853590)
Finished validation.
Starting training epoch 23
Epoch: 23, MSE 0.00281146 (0.00254122), PSNR 25.51067352 (25.97104064), SSIM 0.75999784 (0.76683220)
Finished training epoch 23
Validate: MSE 0.00315815 (0.00270139), PSNR 25.00567436 (25.75696445), SSIM 0.69246459 (0.76202718)
Finished validation.
Starting training epoch 24
Epoch: 24, MSE 0.0025914

  return func(*args, **kwargs)


Validate: MSE 0.00295599 (0.00296573), PSNR 25.29296875 (25.36271628), SSIM 0.70073497 (0.76789003)
Finished validation.
Starting training epoch 26
Epoch: 26, MSE 0.00256953 (0.00251765), PSNR 25.90146637 (26.01077316), SSIM 0.75638771 (0.76821781)
Finished training epoch 26
Validate: MSE 0.00353139 (0.00277689), PSNR 24.52053833 (25.62114070), SSIM 0.68553889 (0.76305651)
Finished validation.
Starting training epoch 27
Epoch: 27, MSE 0.00252859 (0.00250380), PSNR 25.97121429 (26.03570506), SSIM 0.78616560 (0.76858965)
Finished training epoch 27
Validate: MSE 0.00323070 (0.00294430), PSNR 24.90703201 (25.38272813), SSIM 0.69746232 (0.76679665)
Finished validation.
Starting training epoch 28
Epoch: 28, MSE 0.00260592 (0.00251669), PSNR 25.84038925 (26.01545135), SSIM 0.74139726 (0.76864475)
Finished training epoch 28
Validate: MSE 0.00283168 (0.00245264), PSNR 25.47955513 (26.15884510), SSIM 0.70177466 (0.77016770)
Finished validation.
Starting training epoch 29
Epoch: 29, MSE 0.0020827

Validate: MSE 0.00251547 (0.00230858), PSNR 25.99381065 (26.42210941), SSIM 0.71611786 (0.77471505)
Finished validation.
Starting training epoch 56
Epoch: 56, MSE 0.00240966 (0.00229152), PSNR 26.18043900 (26.42115568), SSIM 0.76841623 (0.77645819)
Finished training epoch 56
Validate: MSE 0.00331487 (0.00267321), PSNR 24.79532623 (25.81261281), SSIM 0.70252383 (0.77042217)
Finished validation.
Starting training epoch 57
Epoch: 57, MSE 0.00260748 (0.00226141), PSNR 25.83778954 (26.47894981), SSIM 0.75789952 (0.77716951)
Finished training epoch 57
Validate: MSE 0.00281069 (0.00232505), PSNR 25.51186943 (26.39964201), SSIM 0.71293312 (0.77354200)
Finished validation.
Starting training epoch 58
Epoch: 58, MSE 0.00230372 (0.00226874), PSNR 26.37569618 (26.46468799), SSIM 0.75898641 (0.77688253)
Finished training epoch 58
Validate: MSE 0.00274775 (0.00232695), PSNR 25.61022568 (26.40042575), SSIM 0.71578777 (0.77513224)
Finished validation.
Starting training epoch 59
Epoch: 59, MSE 0.0022010

Validate: MSE 0.00257143 (0.00227912), PSNR 25.89825439 (26.48612681), SSIM 0.71421230 (0.77459559)
Finished validation.
Starting training epoch 86
Epoch: 86, MSE 0.00172948 (0.00216325), PSNR 27.62083244 (26.67195299), SSIM 0.77403849 (0.77597166)
Finished training epoch 86
Validate: MSE 0.00283951 (0.00247480), PSNR 25.46755791 (26.12292038), SSIM 0.70529652 (0.76649206)
Finished validation.
Starting training epoch 87
Epoch: 87, MSE 0.00219615 (0.00213601), PSNR 26.58338356 (26.72512149), SSIM 0.77074587 (0.77698266)
Finished training epoch 87
Validate: MSE 0.00269899 (0.00246931), PSNR 25.68798828 (26.12706580), SSIM 0.70866299 (0.76505907)
Finished validation.
Starting training epoch 88
Epoch: 88, MSE 0.00192346 (0.00213961), PSNR 27.15915680 (26.71736634), SSIM 0.78367889 (0.77653607)
Finished training epoch 88
Validate: MSE 0.00261191 (0.00236535), PSNR 25.83041191 (26.31267763), SSIM 0.70943612 (0.76740961)
Finished validation.
Starting training epoch 89
Epoch: 89, MSE 0.0019218

Epoch: 115, MSE 0.00205080 (0.00206356), PSNR 26.88077354 (26.87405537), SSIM 0.79356074 (0.77493732)
Finished training epoch 115
Validate: MSE 0.00268470 (0.00223059), PSNR 25.71104622 (26.58017771), SSIM 0.70730358 (0.77114557)
Finished validation.
Starting training epoch 116
Epoch: 116, MSE 0.00218770 (0.00203172), PSNR 26.60012436 (26.94361907), SSIM 0.78493917 (0.77589516)
Finished training epoch 116
Validate: MSE 0.00262071 (0.00225222), PSNR 25.81580162 (26.52603053), SSIM 0.70185769 (0.76649484)
Finished validation.
Starting training epoch 117
Epoch: 117, MSE 0.00180135 (0.00204732), PSNR 27.44402504 (26.91067619), SSIM 0.78071767 (0.77530359)
Finished training epoch 117
Validate: MSE 0.00271565 (0.00228474), PSNR 25.66126060 (26.47051876), SSIM 0.70046651 (0.76740549)
Finished validation.
Starting training epoch 118
Epoch: 118, MSE 0.00177562 (0.00202678), PSNR 27.50649261 (26.95241865), SSIM 0.78223705 (0.77598742)
Finished training epoch 118
Validate: MSE 0.00271229 (0.00228

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00268597 (0.00230487), PSNR 25.70898819 (26.42969883), SSIM 0.70112538 (0.76225404)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()