In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
                transforms.RandomVerticalFlip(),  
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .25

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        self.batchnorm5 = nn.BatchNorm2d(2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = y = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x + y)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm5(self.convtrans3(x + input)), negative_slope=0.1)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0.9)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 32, 32]           2,306
      BatchNorm2d-10            [-1, 2, 32, 32]               4
Total params: 2,662,278
Trainable params: 2,662,278
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.78
Params size (MB): 10.16
Estima

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00235885 (0.01042341), PSNR 26.27300262 (24.30010951), SSIM 0.78276461 (0.71723842)
Finished training epoch 0
Validate: MSE 0.00317596 (0.00275923), PSNR 24.98125458 (25.66937914), SSIM 0.70689124 (0.77535872)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00230724 (0.00276383), PSNR 26.36907387 (25.60730027), SSIM 0.78261352 (0.76921813)
Finished training epoch 1
Validate: MSE 0.00324621 (0.00276963), PSNR 24.88623428 (25.64948160), SSIM 0.70645392 (0.77597161)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00314887 (0.00276381), PSNR 25.01845360 (25.60807843), SSIM 0.76137704 (0.76920068)
Finished training epoch 2
Validate: MSE 0.00318277 (0.00276236), PSNR 24.97194862 (25.66433061), SSIM 0.70645636 (0.77407864)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00296145 (0.00275155), PSNR 25.28495026 (25.62525479), SSIM 0.74671787 (0.76936509)
Finished training epoch 3
Validate: MSE 0.00320415 (0.0

Epoch: 30, MSE 0.00324368 (0.00274652), PSNR 24.88961601 (25.63402430), SSIM 0.76585710 (0.76988708)
Finished training epoch 30
Validate: MSE 0.00328586 (0.00277535), PSNR 24.83351326 (25.64083206), SSIM 0.70722431 (0.77647367)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00314896 (0.00274795), PSNR 25.01832199 (25.63243118), SSIM 0.77711183 (0.76964513)
Finished training epoch 31
Validate: MSE 0.00324038 (0.00275333), PSNR 24.89404297 (25.67414033), SSIM 0.70806223 (0.77612548)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00257181 (0.00273832), PSNR 25.89760780 (25.64722533), SSIM 0.76115590 (0.77002095)
Finished training epoch 32
Validate: MSE 0.00313173 (0.00273684), PSNR 25.04215431 (25.70609523), SSIM 0.70846617 (0.77582368)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00233047 (0.00275406), PSNR 26.32556915 (25.62319534), SSIM 0.78422701 (0.76960755)
Finished training epoch 33
Validate: MSE 0.00317503 (0.00274919), PSNR 

Epoch: 60, MSE 0.00227564 (0.00256009), PSNR 26.42896652 (25.94061686), SSIM 0.76620930 (0.76466864)
Finished training epoch 60
Validate: MSE 0.00299092 (0.00256525), PSNR 25.24195480 (25.97591328), SSIM 0.70776576 (0.77086612)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00376656 (0.00255198), PSNR 24.24054718 (25.95366472), SSIM 0.72412831 (0.76470643)
Finished training epoch 61
Validate: MSE 0.00295399 (0.00253800), PSNR 25.29590607 (26.01862931), SSIM 0.70388567 (0.76782698)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00354556 (0.00255298), PSNR 24.50314713 (25.95129493), SSIM 0.71935976 (0.76457838)
Finished training epoch 62
Validate: MSE 0.00294927 (0.00253405), PSNR 25.30286026 (26.02529485), SSIM 0.70318097 (0.76725309)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00240451 (0.00254522), PSNR 26.18973541 (25.96659975), SSIM 0.76193511 (0.76447244)
Finished training epoch 63
Validate: MSE 0.00287744 (0.00252400), PSNR 

Epoch: 90, MSE 0.00270991 (0.00248591), PSNR 25.67044640 (26.06837488), SSIM 0.75659138 (0.76621256)
Finished training epoch 90
Validate: MSE 0.00296006 (0.00248224), PSNR 25.28698730 (26.11450776), SSIM 0.70696372 (0.77064252)
Finished validation.
Starting training epoch 91
Epoch: 91, MSE 0.00229723 (0.00247757), PSNR 26.38795280 (26.08136859), SSIM 0.77940673 (0.76617164)
Finished training epoch 91
Validate: MSE 0.00295243 (0.00246942), PSNR 25.29819870 (26.13522975), SSIM 0.70928836 (0.77192303)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.00224138 (0.00248699), PSNR 26.49485016 (26.06393666), SSIM 0.76761502 (0.76590773)
Finished training epoch 92
Validate: MSE 0.00285208 (0.00245802), PSNR 25.44838333 (26.15379327), SSIM 0.70707834 (0.76844352)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00297783 (0.00248529), PSNR 25.26100349 (26.06533263), SSIM 0.74034607 (0.76609172)
Finished training epoch 93
Validate: MSE 0.00286996 (0.00245578), PSNR 

Epoch: 120, MSE 0.00294667 (0.00245304), PSNR 25.30668259 (26.12498110), SSIM 0.75871438 (0.76731975)
Finished training epoch 120
Validate: MSE 0.00286568 (0.00242171), PSNR 25.42772293 (26.21843194), SSIM 0.70732176 (0.77209305)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 0.00218995 (0.00245567), PSNR 26.59565735 (26.12404433), SSIM 0.76148224 (0.76739127)
Finished training epoch 121
Validate: MSE 0.00282830 (0.00241426), PSNR 25.48474884 (26.23488925), SSIM 0.71313334 (0.77573179)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 0.00242794 (0.00245520), PSNR 26.14762115 (26.12269351), SSIM 0.78043622 (0.76763264)
Finished training epoch 122
Validate: MSE 0.00285086 (0.00245052), PSNR 25.45023918 (26.17027351), SSIM 0.71070188 (0.77122953)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00219203 (0.00245388), PSNR 26.59153366 (26.12536056), SSIM 0.77721250 (0.76757756)
Finished training epoch 123
Validate: MSE 0.00284351 (0.00240

Validate: MSE 0.00285914 (0.00241355), PSNR 25.43764877 (26.23524543), SSIM 0.71198350 (0.77800937)
Finished validation.
Starting training epoch 150
Epoch: 150, MSE 0.00233983 (0.00243131), PSNR 26.30815697 (26.16540809), SSIM 0.77309126 (0.76862504)
Finished training epoch 150
Validate: MSE 0.00279958 (0.00240369), PSNR 25.52907181 (26.25056988), SSIM 0.71009833 (0.77348758)
Finished validation.
Starting training epoch 151
Epoch: 151, MSE 0.00212588 (0.00243163), PSNR 26.72460365 (26.16430801), SSIM 0.76144183 (0.76851134)
Finished training epoch 151
Validate: MSE 0.00289737 (0.00240358), PSNR 25.37995529 (26.25452509), SSIM 0.71323693 (0.77678447)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00267537 (0.00243354), PSNR 25.72615814 (26.16273300), SSIM 0.73397112 (0.76852986)
Finished training epoch 152
Validate: MSE 0.00276889 (0.00239259), PSNR 25.57693672 (26.27084095), SSIM 0.71054912 (0.77445946)
Finished validation.
Starting training epoch 153
Epoch: 153, MS

Epoch: 179, MSE 0.00268271 (0.00242024), PSNR 25.71426201 (26.18475356), SSIM 0.74581319 (0.76924199)
Finished training epoch 179
Validate: MSE 0.00280620 (0.00239216), PSNR 25.51881027 (26.27381171), SSIM 0.71282971 (0.77332953)
Finished validation.
Starting training epoch 180
Epoch: 180, MSE 0.00289590 (0.00242054), PSNR 25.38216209 (26.18483518), SSIM 0.74552947 (0.76904563)
Finished training epoch 180
Validate: MSE 0.00284237 (0.00239634), PSNR 25.46318436 (26.26756924), SSIM 0.71348250 (0.77713823)
Finished validation.
Starting training epoch 181
Epoch: 181, MSE 0.00222526 (0.00241382), PSNR 26.52618599 (26.19356742), SSIM 0.76597464 (0.76943269)
Finished training epoch 181
Validate: MSE 0.00277631 (0.00238411), PSNR 25.56532288 (26.28652776), SSIM 0.71259010 (0.77421983)
Finished validation.
Starting training epoch 182
Epoch: 182, MSE 0.00264825 (0.00242322), PSNR 25.77040100 (26.17792964), SSIM 0.74933577 (0.76930538)
Finished training epoch 182
Validate: MSE 0.00281399 (0.00238

Validate: MSE 0.00281338 (0.00238911), PSNR 25.50771141 (26.28099137), SSIM 0.71496689 (0.77742076)
Finished validation.
Starting training epoch 209
Epoch: 209, MSE 0.00212946 (0.00240368), PSNR 26.71730423 (26.21385665), SSIM 0.76651919 (0.76977665)
Finished training epoch 209
Validate: MSE 0.00285032 (0.00239434), PSNR 25.45105553 (26.27103231), SSIM 0.71402085 (0.77822684)
Finished validation.
Starting training epoch 210
Epoch: 210, MSE 0.00271016 (0.00241762), PSNR 25.67004967 (26.19052446), SSIM 0.75165355 (0.76975145)
Finished training epoch 210
Validate: MSE 0.00278641 (0.00237579), PSNR 25.54955482 (26.30412658), SSIM 0.71601152 (0.77927274)
Finished validation.
Starting training epoch 211
Epoch: 211, MSE 0.00222780 (0.00241109), PSNR 26.52123070 (26.20348245), SSIM 0.75572348 (0.76986492)
Finished training epoch 211
Validate: MSE 0.00276698 (0.00237421), PSNR 25.57993317 (26.30471234), SSIM 0.71562660 (0.77783598)
Finished validation.
Starting training epoch 212
Epoch: 212, MS

Epoch: 238, MSE 0.00229570 (0.00240260), PSNR 26.39084625 (26.21919827), SSIM 0.78710335 (0.77026402)
Finished training epoch 238
Validate: MSE 0.00283164 (0.00237636), PSNR 25.47961807 (26.30475843), SSIM 0.71654463 (0.77930837)
Finished validation.
Starting training epoch 239
Epoch: 239, MSE 0.00253618 (0.00240439), PSNR 25.95820618 (26.21082142), SSIM 0.77164829 (0.77027847)
Finished training epoch 239
Validate: MSE 0.00283617 (0.00238584), PSNR 25.47267151 (26.28549802), SSIM 0.71390879 (0.77697555)
Finished validation.
Starting training epoch 240
Epoch: 240, MSE 0.00238205 (0.00240127), PSNR 26.23048401 (26.21667766), SSIM 0.77637476 (0.77022881)
Finished training epoch 240
Validate: MSE 0.00274518 (0.00237167), PSNR 25.61429024 (26.30856770), SSIM 0.71419913 (0.77737241)
Finished validation.
Starting training epoch 241
Epoch: 241, MSE 0.00214290 (0.00240130), PSNR 26.68997574 (26.21752976), SSIM 0.76922196 (0.77029179)
Finished training epoch 241
Validate: MSE 0.00282946 (0.00238

Validate: MSE 0.00277498 (0.00235528), PSNR 25.56740570 (26.34103790), SSIM 0.71705401 (0.77914221)
Finished validation.
Starting training epoch 268
Epoch: 268, MSE 0.00222865 (0.00239021), PSNR 26.51958656 (26.23826737), SSIM 0.77309740 (0.77082885)
Finished training epoch 268
Validate: MSE 0.00275317 (0.00235187), PSNR 25.60167122 (26.34765403), SSIM 0.71531057 (0.77802333)
Finished validation.
Starting training epoch 269
Epoch: 269, MSE 0.00205807 (0.00239563), PSNR 26.86540031 (26.22754919), SSIM 0.77525437 (0.77053238)
Finished training epoch 269
Validate: MSE 0.00272350 (0.00237395), PSNR 25.64873314 (26.30382728), SSIM 0.71487963 (0.77765602)
Finished validation.
Starting training epoch 270
Epoch: 270, MSE 0.00248166 (0.00238879), PSNR 26.05258179 (26.24096917), SSIM 0.76144809 (0.77082530)
Finished training epoch 270
Validate: MSE 0.00270857 (0.00236406), PSNR 25.67259979 (26.32100031), SSIM 0.71511751 (0.77553152)
Finished validation.
Starting training epoch 271
Epoch: 271, MS

Epoch: 297, MSE 0.00275414 (0.00238379), PSNR 25.60014343 (26.24914325), SSIM 0.74907339 (0.77110914)
Finished training epoch 297
Validate: MSE 0.00276326 (0.00238063), PSNR 25.58578110 (26.29412675), SSIM 0.71587926 (0.77757887)
Finished validation.
Starting training epoch 298
Epoch: 298, MSE 0.00246783 (0.00239186), PSNR 26.07685471 (26.23714819), SSIM 0.78330457 (0.77095083)
Finished training epoch 298
Validate: MSE 0.00277515 (0.00236300), PSNR 25.56713104 (26.32691906), SSIM 0.71599197 (0.77743110)
Finished validation.
Starting training epoch 299
Epoch: 299, MSE 0.00276134 (0.00238944), PSNR 25.58879471 (26.24090234), SSIM 0.76176894 (0.77093844)
Finished training epoch 299
Validate: MSE 0.00275099 (0.00234619), PSNR 25.60511398 (26.35747307), SSIM 0.71869540 (0.77891465)
Finished validation.
Starting training epoch 300
Epoch: 300, MSE 0.00249464 (0.00238672), PSNR 26.02992630 (26.24264026), SSIM 0.77883065 (0.77086305)
Finished training epoch 300
Validate: MSE 0.00269168 (0.00235

Validate: MSE 0.00278953 (0.00236329), PSNR 25.54468918 (26.32841008), SSIM 0.71880543 (0.77896971)
Finished validation.
Starting training epoch 327
Epoch: 327, MSE 0.00218564 (0.00238014), PSNR 26.60421371 (26.25950828), SSIM 0.78557479 (0.77128088)
Finished training epoch 327
Validate: MSE 0.00277340 (0.00234488), PSNR 25.56987190 (26.36114878), SSIM 0.71842927 (0.77951232)
Finished validation.
Starting training epoch 328
Epoch: 328, MSE 0.00294110 (0.00238041), PSNR 25.31489563 (26.25482015), SSIM 0.74692410 (0.77135436)
Finished training epoch 328
Validate: MSE 0.00268941 (0.00234745), PSNR 25.70342827 (26.35284934), SSIM 0.71721703 (0.77587295)
Finished validation.
Starting training epoch 329
Epoch: 329, MSE 0.00266265 (0.00237340), PSNR 25.74685478 (26.26857359), SSIM 0.77844965 (0.77137520)
Finished training epoch 329
Validate: MSE 0.00269942 (0.00234198), PSNR 25.68729019 (26.36347447), SSIM 0.71816146 (0.77837570)
Finished validation.
Starting training epoch 330
Epoch: 330, MS

Epoch: 356, MSE 0.00258667 (0.00237757), PSNR 25.87258720 (26.26032887), SSIM 0.75242388 (0.77150783)
Finished training epoch 356
Validate: MSE 0.00271612 (0.00233955), PSNR 25.66051483 (26.36874391), SSIM 0.71778649 (0.77922349)
Finished validation.
Starting training epoch 357
Epoch: 357, MSE 0.00233365 (0.00236977), PSNR 26.31963348 (26.27473052), SSIM 0.76964378 (0.77154474)
Finished training epoch 357
Validate: MSE 0.00269020 (0.00233402), PSNR 25.70215416 (26.37763475), SSIM 0.71622729 (0.77697991)
Finished validation.
Starting training epoch 358
Epoch: 358, MSE 0.00233389 (0.00237822), PSNR 26.31919479 (26.25996741), SSIM 0.75642157 (0.77123903)
Finished training epoch 358
Validate: MSE 0.00271294 (0.00235113), PSNR 25.66560173 (26.34712256), SSIM 0.71768379 (0.77667011)
Finished validation.
Starting training epoch 359
Epoch: 359, MSE 0.00246343 (0.00237616), PSNR 26.08459854 (26.26478483), SSIM 0.75554144 (0.77146891)
Finished training epoch 359
Validate: MSE 0.00274060 (0.00234

Validate: MSE 0.00277403 (0.00234828), PSNR 25.56888199 (26.35438809), SSIM 0.71750832 (0.77963056)
Finished validation.
Starting training epoch 386
Epoch: 386, MSE 0.00301507 (0.00237363), PSNR 25.20702934 (26.27238980), SSIM 0.75482011 (0.77149772)
Finished training epoch 386
Validate: MSE 0.00267005 (0.00233759), PSNR 25.73479843 (26.37053742), SSIM 0.71746349 (0.77750489)
Finished validation.
Starting training epoch 387
Epoch: 387, MSE 0.00269688 (0.00237011), PSNR 25.69138527 (26.27621929), SSIM 0.75029111 (0.77170787)
Finished training epoch 387
Validate: MSE 0.00264192 (0.00235179), PSNR 25.78080177 (26.34325485), SSIM 0.71411079 (0.77467871)
Finished validation.
Starting training epoch 388
Epoch: 388, MSE 0.00230082 (0.00237154), PSNR 26.38116646 (26.27135139), SSIM 0.77574116 (0.77151852)
Finished training epoch 388
Validate: MSE 0.00272568 (0.00235086), PSNR 25.64525604 (26.34635636), SSIM 0.71770263 (0.78033218)
Finished validation.
Starting training epoch 389
Epoch: 389, MS

Epoch: 415, MSE 0.00224790 (0.00237055), PSNR 26.48222733 (26.27507224), SSIM 0.76769346 (0.77167960)
Finished training epoch 415
Validate: MSE 0.00267844 (0.00232505), PSNR 25.72118378 (26.39521332), SSIM 0.72041953 (0.77837026)
Finished validation.
Starting training epoch 416
Epoch: 416, MSE 0.00229014 (0.00236570), PSNR 26.40138626 (26.28284980), SSIM 0.78776920 (0.77170547)
Finished training epoch 416
Validate: MSE 0.00269147 (0.00232006), PSNR 25.70010757 (26.40482596), SSIM 0.71964419 (0.77847361)
Finished validation.
Starting training epoch 417
Epoch: 417, MSE 0.00226103 (0.00237037), PSNR 26.45693779 (26.27437839), SSIM 0.78152728 (0.77180689)
Finished training epoch 417
Validate: MSE 0.00270200 (0.00234772), PSNR 25.68313789 (26.35342171), SSIM 0.71843320 (0.78115237)
Finished validation.
Starting training epoch 418
Epoch: 418, MSE 0.00211440 (0.00237315), PSNR 26.74812317 (26.26764427), SSIM 0.77327776 (0.77155796)
Finished training epoch 418
Validate: MSE 0.00271725 (0.00232

Validate: MSE 0.00272602 (0.00233266), PSNR 25.64470863 (26.38223598), SSIM 0.71800297 (0.77946325)
Finished validation.
Starting training epoch 445
Epoch: 445, MSE 0.00194329 (0.00236698), PSNR 27.11462021 (26.28211537), SSIM 0.78344303 (0.77177416)
Finished training epoch 445
Validate: MSE 0.00267237 (0.00232978), PSNR 25.73102570 (26.38652189), SSIM 0.71776736 (0.77704517)
Finished validation.
Starting training epoch 446
Epoch: 446, MSE 0.00187975 (0.00236046), PSNR 27.25898743 (26.29392276), SSIM 0.77480423 (0.77193039)
Finished training epoch 446
Validate: MSE 0.00270885 (0.00233730), PSNR 25.67214584 (26.37300934), SSIM 0.71788192 (0.77703976)
Finished validation.
Starting training epoch 447
Epoch: 447, MSE 0.00223501 (0.00235767), PSNR 26.50720406 (26.29606642), SSIM 0.75869632 (0.77195475)
Finished training epoch 447
Validate: MSE 0.00269762 (0.00232637), PSNR 25.69018364 (26.39146325), SSIM 0.71837246 (0.77766770)
Finished validation.
Starting training epoch 448
Epoch: 448, MS

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00271530 (0.00232422), PSNR 25.66182137 (26.39883887), SSIM 0.71675158 (0.77777563)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()