In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
                transforms.RandomVerticalFlip(),  
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .25

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base * 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base * 4, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base * 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm4 = nn.BatchNorm2d(channels_base)
        self.batchnorm5 = nn.BatchNorm2d(2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = y = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = y1 = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(torch.cat((x, y1), 1))), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm5(self.convtrans3(torch.cat((x, y), 1))), negative_slope=0.1)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 512, 8, 8]       2,359,808
       BatchNorm2d-6            [-1, 512, 8, 8]           1,024
   ConvTranspose2d-7          [-1, 256, 16, 16]       2,359,552
       BatchNorm2d-8          [-1, 256, 16, 16]             512
   ConvTranspose2d-9            [-1, 2, 32, 32]           9,218
      BatchNorm2d-10            [-1, 2, 32, 32]               4
Total params: 5,914,374
Trainable params: 5,914,374
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 7.53
Params size (MB): 22.56
Estima

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00260683 (0.02539646), PSNR 25.83887672 (19.67075554), SSIM 0.78443354 (0.56356671)
Finished training epoch 0
Validate: MSE 0.00323098 (0.00276707), PSNR 24.90665627 (25.65438153), SSIM 0.70478636 (0.77540276)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00246911 (0.00276762), PSNR 26.07458878 (25.60189381), SSIM 0.78921789 (0.76919917)
Finished training epoch 1
Validate: MSE 0.00318588 (0.00277101), PSNR 24.96770096 (25.65076536), SSIM 0.70452988 (0.77523164)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00322645 (0.00276198), PSNR 24.91275024 (25.61069731), SSIM 0.72757071 (0.76927202)
Finished training epoch 2
Validate: MSE 0.00318999 (0.00276560), PSNR 24.96210098 (25.65894791), SSIM 0.70469582 (0.77527850)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00249500 (0.00276814), PSNR 26.02929306 (25.60197807), SSIM 0.77563041 (0.76946240)
Finished training epoch 3
Validate: MSE 0.00322156 (0.0

Epoch: 30, MSE 0.00291339 (0.00235073), PSNR 25.35600471 (26.31438294), SSIM 0.76715982 (0.77743220)
Finished training epoch 30
Validate: MSE 0.00265491 (0.00231735), PSNR 25.75950050 (26.41041068), SSIM 0.71889114 (0.78483069)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00269248 (0.00233820), PSNR 25.69847870 (26.33595785), SSIM 0.76613688 (0.77770094)
Finished training epoch 31
Validate: MSE 0.00281381 (0.00238107), PSNR 25.50705528 (26.29763209), SSIM 0.72004682 (0.78410021)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00251219 (0.00233814), PSNR 25.99946594 (26.33598994), SSIM 0.76145953 (0.77750429)
Finished training epoch 32
Validate: MSE 0.00280579 (0.00285976), PSNR 25.51945114 (25.51620258), SSIM 0.71812296 (0.77358646)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00199560 (0.00232942), PSNR 26.99925232 (26.35170484), SSIM 0.79133844 (0.77796140)
Finished training epoch 33
Validate: MSE 0.00285687 (0.00242639), PSNR 

Epoch: 60, MSE 0.00227186 (0.00226223), PSNR 26.43619156 (26.47839053), SSIM 0.77077407 (0.77906271)
Finished training epoch 60
Validate: MSE 0.00263477 (0.00223266), PSNR 25.79256439 (26.57757120), SSIM 0.72594863 (0.78770677)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00215007 (0.00226652), PSNR 26.67547226 (26.47266222), SSIM 0.78251797 (0.77885184)
Finished training epoch 61
Validate: MSE 0.00251097 (0.00261924), PSNR 26.00157928 (25.86266407), SSIM 0.71897894 (0.77231281)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00247680 (0.00226528), PSNR 26.06108284 (26.47257541), SSIM 0.74821430 (0.77882617)
Finished training epoch 62
Validate: MSE 0.00268737 (0.00226434), PSNR 25.70672798 (26.51079151), SSIM 0.72894073 (0.78669783)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00217201 (0.00226138), PSNR 26.63138962 (26.48122088), SSIM 0.78294593 (0.77905771)
Finished training epoch 63
Validate: MSE 0.00274635 (0.00226255), PSNR 

Epoch: 90, MSE 0.00236379 (0.00222946), PSNR 26.26390839 (26.54121924), SSIM 0.77005738 (0.77904119)
Finished training epoch 90
Validate: MSE 0.00251160 (0.00225987), PSNR 26.00049019 (26.51596912), SSIM 0.72208494 (0.78030426)
Finished validation.
Starting training epoch 91
Epoch: 91, MSE 0.00231877 (0.00222728), PSNR 26.34742928 (26.54578114), SSIM 0.77724272 (0.77912835)
Finished training epoch 91
Validate: MSE 0.00275025 (0.00237574), PSNR 25.60627937 (26.30776241), SSIM 0.72151434 (0.77918742)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.00190106 (0.00222543), PSNR 27.21003723 (26.54906991), SSIM 0.79190999 (0.77928280)
Finished training epoch 92
Validate: MSE 0.00225696 (0.00230527), PSNR 26.46475410 (26.43200174), SSIM 0.73808908 (0.77862561)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00271729 (0.00221786), PSNR 25.65864563 (26.56225859), SSIM 0.77278721 (0.77896081)
Finished training epoch 93
Validate: MSE 0.00287160 (0.00241320), PSNR 

Epoch: 120, MSE 0.00266992 (0.00219696), PSNR 25.73502159 (26.60226550), SSIM 0.74893963 (0.77876291)
Finished training epoch 120
Validate: MSE 0.00255529 (0.00222316), PSNR 25.92560196 (26.59212418), SSIM 0.72610056 (0.78559034)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 0.00209582 (0.00219016), PSNR 26.78645134 (26.61720786), SSIM 0.77067530 (0.77929994)
Finished training epoch 121
Validate: MSE 0.00264993 (0.00227790), PSNR 25.76766205 (26.48370021), SSIM 0.71984506 (0.77940638)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 0.00233113 (0.00219981), PSNR 26.32432556 (26.60267062), SSIM 0.77415669 (0.77859267)
Finished training epoch 122
Validate: MSE 0.00270606 (0.00244583), PSNR 25.67662621 (26.18677269), SSIM 0.72205788 (0.78363660)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00204272 (0.00219628), PSNR 26.89790344 (26.60370996), SSIM 0.78705156 (0.77868940)
Finished training epoch 123
Validate: MSE 0.00250812 (0.00219

Validate: MSE 0.00265179 (0.00223887), PSNR 25.76461601 (26.56657846), SSIM 0.71579581 (0.77997964)
Finished validation.
Starting training epoch 150
Epoch: 150, MSE 0.00277798 (0.00217459), PSNR 25.56270027 (26.65338241), SSIM 0.75293970 (0.77850740)
Finished training epoch 150
Validate: MSE 0.00269075 (0.00230807), PSNR 25.70125961 (26.43260402), SSIM 0.72476077 (0.78353999)
Finished validation.
Starting training epoch 151
Epoch: 151, MSE 0.00256003 (0.00217234), PSNR 25.91755295 (26.65235800), SSIM 0.77262890 (0.77875551)
Finished training epoch 151
Validate: MSE 0.00244085 (0.00224250), PSNR 26.12459373 (26.55196451), SSIM 0.72691911 (0.77952711)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00190379 (0.00217208), PSNR 27.20381737 (26.65291182), SSIM 0.78639758 (0.77861806)
Finished training epoch 152
Validate: MSE 0.00234894 (0.00229746), PSNR 26.29127693 (26.43908638), SSIM 0.72912186 (0.77479057)
Finished validation.
Starting training epoch 153
Epoch: 153, MS

Epoch: 179, MSE 0.00209473 (0.00215882), PSNR 26.78871346 (26.67817154), SSIM 0.77292883 (0.77801895)
Finished training epoch 179
Validate: MSE 0.00254484 (0.00221347), PSNR 25.94339371 (26.60962403), SSIM 0.72465229 (0.78226190)
Finished validation.
Starting training epoch 180
Epoch: 180, MSE 0.00202767 (0.00215408), PSNR 26.93002701 (26.68909526), SSIM 0.77327418 (0.77817723)
Finished training epoch 180
Validate: MSE 0.00235974 (0.00229493), PSNR 26.27134895 (26.45097215), SSIM 0.72629136 (0.77119864)
Finished validation.
Starting training epoch 181
Epoch: 181, MSE 0.00171306 (0.00215732), PSNR 27.66226578 (26.68177991), SSIM 0.79762632 (0.77812564)
Finished training epoch 181
Validate: MSE 0.00254902 (0.00218053), PSNR 25.93627357 (26.67577904), SSIM 0.72467703 (0.78397535)
Finished validation.
Starting training epoch 182
Epoch: 182, MSE 0.00224073 (0.00214659), PSNR 26.49609756 (26.70384279), SSIM 0.77862322 (0.77818127)
Finished training epoch 182
Validate: MSE 0.00238826 (0.00222

Validate: MSE 0.00243268 (0.00220878), PSNR 26.13914108 (26.61880024), SSIM 0.72469938 (0.78127278)
Finished validation.
Starting training epoch 209
Epoch: 209, MSE 0.00236325 (0.00213373), PSNR 26.26490211 (26.73023478), SSIM 0.78451204 (0.77799770)
Finished training epoch 209
Validate: MSE 0.00253965 (0.00236091), PSNR 25.95225525 (26.33330903), SSIM 0.72469521 (0.78007340)
Finished validation.
Starting training epoch 210
Epoch: 210, MSE 0.00199151 (0.00214013), PSNR 27.00816727 (26.71643857), SSIM 0.78745353 (0.77789236)
Finished training epoch 210
Validate: MSE 0.00240553 (0.00217292), PSNR 26.18789482 (26.68936621), SSIM 0.72866869 (0.78453570)
Finished validation.
Starting training epoch 211
Epoch: 211, MSE 0.00221291 (0.00213487), PSNR 26.55036926 (26.72891137), SSIM 0.79303777 (0.77790903)
Finished training epoch 211
Validate: MSE 0.00280823 (0.00220772), PSNR 25.51566505 (26.62194533), SSIM 0.72187638 (0.78635950)
Finished validation.
Starting training epoch 212
Epoch: 212, MS

Epoch: 238, MSE 0.00228814 (0.00212139), PSNR 26.40518188 (26.75475387), SSIM 0.75408828 (0.77766165)
Finished training epoch 238
Validate: MSE 0.00230937 (0.00225570), PSNR 26.36507034 (26.52202293), SSIM 0.72260541 (0.76903306)
Finished validation.
Starting training epoch 239
Epoch: 239, MSE 0.00243103 (0.00213254), PSNR 26.14208794 (26.73341871), SSIM 0.76160526 (0.77729756)
Finished training epoch 239
Validate: MSE 0.00241024 (0.00221624), PSNR 26.17938995 (26.60358842), SSIM 0.72956002 (0.77780735)
Finished validation.
Starting training epoch 240
Epoch: 240, MSE 0.00188545 (0.00212043), PSNR 27.24584961 (26.75573596), SSIM 0.79377878 (0.77745786)
Finished training epoch 240
Validate: MSE 0.00268649 (0.00250782), PSNR 25.70815277 (26.07310499), SSIM 0.71979856 (0.77719217)
Finished validation.
Starting training epoch 241
Epoch: 241, MSE 0.00202096 (0.00211304), PSNR 26.94441605 (26.77146665), SSIM 0.77335709 (0.77764927)
Finished training epoch 241
Validate: MSE 0.00254027 (0.00220

Validate: MSE 0.00255931 (0.00218538), PSNR 25.91876602 (26.66841800), SSIM 0.72487253 (0.78353888)
Finished validation.
Starting training epoch 268
Epoch: 268, MSE 0.00186136 (0.00210879), PSNR 27.30170631 (26.78304815), SSIM 0.77797461 (0.77708134)
Finished training epoch 268
Validate: MSE 0.00259397 (0.00219928), PSNR 25.86034393 (26.63882216), SSIM 0.71231198 (0.77763079)
Finished validation.
Starting training epoch 269
Epoch: 269, MSE 0.00172643 (0.00211276), PSNR 27.62851715 (26.77569776), SSIM 0.78803045 (0.77683141)
Finished training epoch 269
Validate: MSE 0.00247587 (0.00219005), PSNR 26.06271744 (26.65697925), SSIM 0.72803897 (0.78375342)
Finished validation.
Starting training epoch 270
Epoch: 270, MSE 0.00199054 (0.00210899), PSNR 27.01029015 (26.77857440), SSIM 0.80169928 (0.77714838)
Finished training epoch 270
Validate: MSE 0.00251557 (0.00219505), PSNR 25.99362946 (26.64769247), SSIM 0.72438687 (0.78235997)
Finished validation.
Starting training epoch 271
Epoch: 271, MS

Epoch: 297, MSE 0.00218292 (0.00210352), PSNR 26.60961723 (26.79186693), SSIM 0.77275425 (0.77666949)
Finished training epoch 297
Validate: MSE 0.00239932 (0.00219881), PSNR 26.19911194 (26.63474717), SSIM 0.71276999 (0.77329984)
Finished validation.
Starting training epoch 298
Epoch: 298, MSE 0.00211611 (0.00209938), PSNR 26.74461555 (26.80270651), SSIM 0.77366745 (0.77694511)
Finished training epoch 298
Validate: MSE 0.00237789 (0.00220027), PSNR 26.23807335 (26.63078490), SSIM 0.72039121 (0.77355696)
Finished validation.
Starting training epoch 299
Epoch: 299, MSE 0.00184186 (0.00209117), PSNR 27.34742165 (26.82013418), SSIM 0.77214319 (0.77663967)
Finished training epoch 299
Validate: MSE 0.00232623 (0.00226132), PSNR 26.33346367 (26.51147663), SSIM 0.72266424 (0.77251262)
Finished validation.
Starting training epoch 300
Epoch: 300, MSE 0.00216203 (0.00209328), PSNR 26.65137863 (26.81540815), SSIM 0.76259506 (0.77675773)
Finished training epoch 300
Validate: MSE 0.00239420 (0.00218

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00240837 (0.00222493), PSNR 26.18276405 (26.58702870), SSIM 0.72831553 (0.78031262)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()