In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 16
p1 = .5

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             160
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 32, 16, 16]           4,640
       BatchNorm2d-4           [-1, 32, 16, 16]              64
   ConvTranspose2d-5             [-1, 16, 8, 8]           4,624
       BatchNorm2d-6             [-1, 16, 8, 8]              32
   ConvTranspose2d-7            [-1, 8, 16, 16]           1,160
       BatchNorm2d-8            [-1, 8, 16, 16]              16
   ConvTranspose2d-9            [-1, 2, 16, 16]             146
Total params: 10,874
Trainable params: 10,874
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.43
Params size (MB): 0.04
Estimated Total Size (MB): 0.47
---------------------------------------------

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00324499 (0.02247040), PSNR 24.88785934 (22.00917702), SSIM 0.76118404 (0.48987466)
Finished training epoch 0
Validate: MSE 0.00315523 (0.00281371), PSNR 25.00968361 (25.59363546), SSIM 0.68125319 (0.75447644)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00286631 (0.00275769), PSNR 25.42676544 (25.61855511), SSIM 0.75725430 (0.76679320)
Finished training epoch 1
Validate: MSE 0.00325461 (0.00274212), PSNR 24.87501335 (25.69207182), SSIM 0.70379198 (0.77233774)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00283083 (0.00275308), PSNR 25.48086166 (25.62235686), SSIM 0.75429595 (0.77038410)
Finished training epoch 2
Validate: MSE 0.00296922 (0.00275512), PSNR 25.27357101 (25.69056369), SSIM 0.69919163 (0.76793322)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00328439 (0.00274455), PSNR 24.83545303 (25.64275133), SSIM 0.74957931 (0.76993772)
Finished training epoch 3
Validate: MSE 0.00308958 (0.0

Epoch: 30, MSE 0.00294697 (0.00248508), PSNR 25.30624390 (26.07054819), SSIM 0.73658335 (0.76945938)
Finished training epoch 30
Validate: MSE 0.00281049 (0.00259950), PSNR 25.51217079 (25.93610497), SSIM 0.70365548 (0.76659014)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00251723 (0.00248915), PSNR 25.99076271 (26.06050119), SSIM 0.75417423 (0.76909924)
Finished training epoch 31
Validate: MSE 0.00295342 (0.00266221), PSNR 25.29673958 (25.84074959), SSIM 0.70152110 (0.76592611)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00200542 (0.00247742), PSNR 26.97794151 (26.08186983), SSIM 0.79134166 (0.76899952)
Finished training epoch 32
Validate: MSE 0.00282077 (0.00255173), PSNR 25.49631500 (26.00336532), SSIM 0.70364809 (0.76932644)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00214818 (0.00247885), PSNR 26.67929077 (26.07904555), SSIM 0.78088796 (0.76933719)
Finished training epoch 33
Validate: MSE 0.00313887 (0.00267479), PSNR 

Epoch: 60, MSE 0.00246093 (0.00245765), PSNR 26.08901024 (26.11895586), SSIM 0.76849401 (0.76881814)
Finished training epoch 60
Validate: MSE 0.00291449 (0.00264189), PSNR 25.35436821 (25.85310144), SSIM 0.70038819 (0.76319774)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00268208 (0.00246221), PSNR 25.71527481 (26.10850824), SSIM 0.76608425 (0.76874647)
Finished training epoch 61
Validate: MSE 0.00288888 (0.00252016), PSNR 25.39271164 (26.06325070), SSIM 0.70755333 (0.77127566)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00234440 (0.00246169), PSNR 26.29968834 (26.10998519), SSIM 0.78040391 (0.76897149)
Finished training epoch 62
Validate: MSE 0.00280666 (0.00248771), PSNR 25.51810074 (26.10070135), SSIM 0.70790541 (0.76717481)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00301629 (0.00246043), PSNR 25.20526314 (26.11102209), SSIM 0.73968869 (0.76913993)
Finished training epoch 63
Validate: MSE 0.00300014 (0.00250918), PSNR 

Epoch: 90, MSE 0.00218453 (0.00245310), PSNR 26.60640907 (26.12506835), SSIM 0.77657664 (0.76872848)
Finished training epoch 90
Validate: MSE 0.00282903 (0.00241733), PSNR 25.48362923 (26.22676790), SSIM 0.70401442 (0.76712889)
Finished validation.
Starting training epoch 91
Epoch: 91, MSE 0.00197978 (0.00244775), PSNR 27.03383636 (26.13396134), SSIM 0.77837789 (0.76863425)
Finished training epoch 91
Validate: MSE 0.00275537 (0.00246953), PSNR 25.59819794 (26.14224845), SSIM 0.70585197 (0.76817628)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.00220914 (0.00245201), PSNR 26.55777168 (26.12683108), SSIM 0.77758282 (0.76878739)
Finished training epoch 92
Validate: MSE 0.00274028 (0.00245954), PSNR 25.62205505 (26.15487900), SSIM 0.70295894 (0.76468492)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00203704 (0.00244794), PSNR 26.90999794 (26.13437261), SSIM 0.77927911 (0.76859682)
Finished training epoch 93
Validate: MSE 0.00291213 (0.00248313), PSNR 

Epoch: 120, MSE 0.00255856 (0.00245102), PSNR 25.92004395 (26.12869942), SSIM 0.77776784 (0.76878599)
Finished training epoch 120
Validate: MSE 0.00293536 (0.00260310), PSNR 25.32339096 (25.93602870), SSIM 0.70439088 (0.76756528)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 0.00253300 (0.00244774), PSNR 25.96365166 (26.13455743), SSIM 0.78635013 (0.76838002)
Finished training epoch 121
Validate: MSE 0.00289182 (0.00243359), PSNR 25.38828468 (26.20300126), SSIM 0.70578587 (0.77266114)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 0.00260665 (0.00244695), PSNR 25.83916855 (26.13369521), SSIM 0.78541934 (0.76877755)
Finished training epoch 122
Validate: MSE 0.00302040 (0.00244943), PSNR 25.19935036 (26.17399901), SSIM 0.70456052 (0.77193059)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00239985 (0.00244813), PSNR 26.19816399 (26.13311751), SSIM 0.77352053 (0.76867654)
Finished training epoch 123
Validate: MSE 0.00276181 (0.00250

Validate: MSE 0.00286043 (0.00245960), PSNR 25.43569183 (26.16065558), SSIM 0.70900834 (0.77164004)
Finished validation.
Starting training epoch 150
Epoch: 150, MSE 0.00276591 (0.00243194), PSNR 25.58161354 (26.16060867), SSIM 0.75677860 (0.76870017)
Finished training epoch 150
Validate: MSE 0.00277903 (0.00241138), PSNR 25.56105804 (26.24061654), SSIM 0.70715606 (0.76916878)
Finished validation.
Starting training epoch 151
Epoch: 151, MSE 0.00250124 (0.00243511), PSNR 26.01845169 (26.15516075), SSIM 0.74552947 (0.76884046)
Finished training epoch 151
Validate: MSE 0.00294347 (0.00277314), PSNR 25.31140709 (25.64384225), SSIM 0.69928336 (0.75666065)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00205741 (0.00245173), PSNR 26.86679077 (26.12751512), SSIM 0.79393715 (0.76851652)
Finished training epoch 152
Validate: MSE 0.00283367 (0.00244590), PSNR 25.47650146 (26.18518845), SSIM 0.70716244 (0.77239047)
Finished validation.
Starting training epoch 153
Epoch: 153, MS

Epoch: 179, MSE 0.00214661 (0.00244137), PSNR 26.68246651 (26.14780516), SSIM 0.78614956 (0.76860192)
Finished training epoch 179
Validate: MSE 0.00281273 (0.00240682), PSNR 25.50871277 (26.24440872), SSIM 0.69880605 (0.76592589)
Finished validation.
Starting training epoch 180
Epoch: 180, MSE 0.00197260 (0.00243560), PSNR 27.04960823 (26.15791311), SSIM 0.78510296 (0.76886583)
Finished training epoch 180
Validate: MSE 0.00298564 (0.00241404), PSNR 25.24962997 (26.23801430), SSIM 0.70315331 (0.77000274)
Finished validation.
Starting training epoch 181
Epoch: 181, MSE 0.00236076 (0.00243496), PSNR 26.26948738 (26.15746910), SSIM 0.76577055 (0.76851812)
Finished training epoch 181
Validate: MSE 0.00265186 (0.00240377), PSNR 25.76449394 (26.25220737), SSIM 0.70794767 (0.76857372)
Finished validation.
Starting training epoch 182
Epoch: 182, MSE 0.00262434 (0.00244375), PSNR 25.80979156 (26.13965181), SSIM 0.79084694 (0.76883239)
Finished training epoch 182
Validate: MSE 0.00291256 (0.00241

Validate: MSE 0.00296321 (0.00243563), PSNR 25.28237534 (26.20222072), SSIM 0.70945030 (0.77392168)
Finished validation.
Starting training epoch 209
Epoch: 209, MSE 0.00263918 (0.00243354), PSNR 25.78531075 (26.16106018), SSIM 0.76128471 (0.76874445)
Finished training epoch 209
Validate: MSE 0.00284288 (0.00244142), PSNR 25.46241760 (26.17956221), SSIM 0.70158184 (0.76427788)
Finished validation.
Starting training epoch 210
Epoch: 210, MSE 0.00232965 (0.00244066), PSNR 26.32708359 (26.14395309), SSIM 0.77864599 (0.76837025)
Finished training epoch 210
Validate: MSE 0.00279833 (0.00241489), PSNR 25.53101349 (26.22975392), SSIM 0.70753157 (0.76760273)
Finished validation.
Starting training epoch 211
Epoch: 211, MSE 0.00203605 (0.00243190), PSNR 26.91210556 (26.16042420), SSIM 0.77711332 (0.76880962)
Finished training epoch 211
Validate: MSE 0.00279067 (0.00266888), PSNR 25.54291725 (25.83976249), SSIM 0.70452905 (0.76744447)
Finished validation.
Starting training epoch 212
Epoch: 212, MS

Epoch: 238, MSE 0.00300420 (0.00244201), PSNR 25.22270203 (26.14523236), SSIM 0.75081968 (0.76874772)
Finished training epoch 238
Validate: MSE 0.00287863 (0.00240800), PSNR 25.40813637 (26.24801569), SSIM 0.70432037 (0.76982036)
Finished validation.
Starting training epoch 239
Epoch: 239, MSE 0.00213573 (0.00243328), PSNR 26.70453453 (26.15810385), SSIM 0.79308927 (0.76886306)
Finished training epoch 239
Validate: MSE 0.00281921 (0.00242039), PSNR 25.49873161 (26.22120561), SSIM 0.70605117 (0.76790226)
Finished validation.
Starting training epoch 240
Epoch: 240, MSE 0.00283637 (0.00243616), PSNR 25.47237015 (26.15379358), SSIM 0.75195009 (0.76847712)
Finished training epoch 240
Validate: MSE 0.00296092 (0.00256271), PSNR 25.28573799 (25.98633313), SSIM 0.71023345 (0.77017344)
Finished validation.
Starting training epoch 241
Epoch: 241, MSE 0.00201565 (0.00243766), PSNR 26.95585060 (26.15121816), SSIM 0.78930008 (0.76878849)
Finished training epoch 241
Validate: MSE 0.00287826 (0.00241

Validate: MSE 0.00301592 (0.00257921), PSNR 25.20579720 (25.96093135), SSIM 0.70349485 (0.76855830)
Finished validation.
Starting training epoch 268
Epoch: 268, MSE 0.00265591 (0.00243668), PSNR 25.75786400 (26.15549118), SSIM 0.76224339 (0.76855256)
Finished training epoch 268
Validate: MSE 0.00282994 (0.00240117), PSNR 25.48222351 (26.25976042), SSIM 0.70657468 (0.77059285)
Finished validation.
Starting training epoch 269
Epoch: 269, MSE 0.00244787 (0.00243362), PSNR 26.11211586 (26.15851236), SSIM 0.75780243 (0.76880588)
Finished training epoch 269
Validate: MSE 0.00285192 (0.00241844), PSNR 25.44862938 (26.22616546), SSIM 0.70921946 (0.77025976)
Finished validation.
Starting training epoch 270
Epoch: 270, MSE 0.00240466 (0.00243580), PSNR 26.18946648 (26.15655807), SSIM 0.78336340 (0.76863637)
Finished training epoch 270
Validate: MSE 0.00281780 (0.00240444), PSNR 25.50089645 (26.25049680), SSIM 0.70388758 (0.76734242)
Finished validation.
Starting training epoch 271
Epoch: 271, MS

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00276967 (0.00241368), PSNR 25.57571602 (26.23340635), SSIM 0.70624006 (0.76591180)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()