In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# Runtime environment switches (set at most one to True).
# NOTE: the Colab and Kaggle code paths have not been tested yet.
colab, kaggle = False, False
# Run identifier; on Colab/Kaggle it becomes the parent folder of all outputs.
test_number = '12_2'

In [4]:
# Default output locations, relative to the working directory.
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'

if colab:
    # Mount Google Drive so both the dataset and the outputs live on Drive.
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}',
        f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}',
        f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}',
    )
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'{test_number}/{color_imgs}',
        f'{test_number}/{gray_imgs}',
        f'{test_number}/{checkpoints}',
    )
else:
    # Local run: outputs keep their default relative paths.
    dataset = '../../datasets/cifar10/'

In [5]:
# Create the output folders (no-op when they already exist).
for _folder in (color_imgs, gray_imgs, checkpoints):
    os.makedirs(_folder, exist_ok=True)
del _folder

# Run parameters.
save_images = True        # write sample colorizations during validation
# Best-so-far validation metrics, ordered [MSE, PSNR, SSIM].
# NOTE(review): only MSE is lower-is-better; PSNR and SSIM are higher-is-better,
# so this 1e10 seed is only a meaningful starting "best" for the MSE slot.
best_losses = [1e10] * 3
best_epoch = -1           # epoch that produced the best validation MSE so far
patience = 50             # epochs without MSE improvement before early stopping
epochs = 500              # maximum number of training epochs
batch_size = 128
SIZE = 32                 # image side length in pixels

In [6]:
# TensorBoard logger used by train()/validate(). With no log_dir, PyTorch picks a
# new timestamped subfolder under ./runs, which the %tensorboard magic watches.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=None)

In [7]:
# True when PyTorch can see a CUDA device; echoed so the run log records it.
use_gpu: bool = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    '''Flat folder of images served as (grayscale input, ab target) pairs for colorization.

    Each item loads an image as RGB, applies the split's transforms, converts it
    to CIE Lab and returns:

    * ``img_gray`` - float tensor of shape (1, H, W) from ``skimage.color.rgb2gray``;
    * ``img_ab``   - float tensor of shape (2, H, W): the Lab a/b channels mapped
      with ``(ab + 128) / 255``.

    NOTE(review): the input is rgb2gray luminance, whereas ``to_rgb`` rebuilds the
    Lab L channel as ``gray * 100``; the two are close but not identical, so
    reconstructed lightness is approximate.

    Args:
        paths: directory whose regular files are the images (subfolders are ignored).
        split: ``'train'`` (adds random horizontal flips) or ``'val'``.
        size: output side length in pixels; ``None`` uses the global ``SIZE``.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    '''

    def __init__(self, paths, split='train', size=None):
        # Fail fast: an unknown split used to leave ``self.transforms`` unset and only
        # crash later with an AttributeError inside __getitem__.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE if size is None else size

        # Resize to exactly size x size. The RandomCrop of the same size is therefore
        # a no-op; it is kept so the pipeline matches the original one.
        steps = [
            transforms.Resize((self.size, self.size), transforms.InterpolationMode.BICUBIC),
            transforms.RandomCrop(self.size),
        ]
        if split == 'train':
            steps.append(transforms.RandomHorizontalFlip())
        self.transforms = transforms.Compose(steps)

        # Sorted so item indices (and hence the saved validation samples) are stable
        # across machines; os.listdir returns entries in arbitrary order.
        self.paths = sorted(
            os.path.join(paths, file) for file in os.listdir(paths)
            if os.path.isfile(os.path.join(paths, file))
        )

    def __getitem__(self, index):
        # Use the context manager so the file handle is closed right away;
        # convert() returns an independent in-memory copy.
        with Image.open(self.paths[index]) as raw:
            img = raw.convert("RGB")
        img_original = np.asarray(self.transforms(img))  # (H, W, 3) uint8

        img_lab = (rgb2lab(img_original) + 128) / 255
        img_ab = torch.from_numpy(img_lab[:, :, 1:3].transpose((2, 0, 1))).float()
        img_gray = torch.from_numpy(rgb2gray(img_original)).unsqueeze(0).float()
        return img_gray, img_ab

    def __len__(self):
        return len(self.paths)

In [9]:
# Training split: reshuffled every epoch.
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(
    train_imagefolder,
    batch_size=batch_size,
    shuffle=True,
)

# Validation split: fixed order, so the first batch (used for saved samples) is stable.
val_imagefolder = LabImageFolder(dataset + 'val', 'val')
val_loader = torch.utils.data.DataLoader(
    val_imagefolder,
    batch_size=batch_size,
    shuffle=False,
)

In [10]:
# Layer hyper-parameters shared by all (transposed) convolutions.
kernel_size = 3
stride_en = 1          # encoder conv stride (downsampling is done by average pooling)
stride_de = 1          # decoder conv stride (upsampling is done by interpolation)
padding = 1
scale_factor = 2       # upsampling factor applied after decoder stages
padding_mode = 'zeros'
channels_base = 128


class Autoencoder(nn.Module):
    '''Convolutional encoder-decoder mapping a 1-channel image to 2 output channels.

    Spatial flow: H -> H/2 -> H/4 in the encoder (3x3, stride-2 average pooling),
    then H/4 -> H/2 -> H in the decoder (interpolation), so the output matches the
    input height/width when both are divisible by 4.

    Args:
        input_size: accepted for backward compatibility; not used by any layer.
    '''

    def __init__(self, input_size=128):
        super().__init__()
        enc_kw = dict(kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        dec_kw = dict(kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        wide = channels_base * 2
        narrow = channels_base // 2

        # Encoder convolutions: 1 -> C -> 2C channels.
        self.conv1 = nn.Conv2d(1, channels_base, **enc_kw)
        self.conv2 = nn.Conv2d(channels_base, wide, **enc_kw)

        # Decoder transposed convolutions: 2C -> C -> C/2 -> 2 channels.
        self.convtrans1 = nn.ConvTranspose2d(wide, channels_base, **dec_kw)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, narrow, **dec_kw)
        self.convtrans3 = nn.ConvTranspose2d(narrow, 2, **dec_kw)

        # Batch norms following conv1, conv2, convtrans1 and convtrans2 respectively.
        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(wide)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(narrow)

    @staticmethod
    def _act(t):
        '''Leaky ReLU with slope 0.1, applied after every normalized layer.'''
        return F.leaky_relu(t, negative_slope=0.1)

    def forward(self, input):
        # Encoder: conv -> BN -> activation, then halve the resolution.
        h = self._act(self.batchnorm1(self.conv1(input)))
        h = F.avg_pool2d(h, kernel_size=3, stride=2, padding=1)
        h = self._act(self.batchnorm2(self.conv2(h)))
        h = F.avg_pool2d(h, kernel_size=3, stride=2, padding=1)

        # Decoder: transposed conv -> BN -> activation, then upsample.
        h = self._act(self.batchnorm3(self.convtrans1(h)))
        h = F.interpolate(h, scale_factor=scale_factor)
        h = self._act(self.batchnorm4(self.convtrans2(h)))
        # The final 2-channel projection has no normalization or activation.
        return F.interpolate(self.convtrans3(h), scale_factor=scale_factor)

In [11]:
model = Autoencoder(input_size=128)  # input_size is the constructor default and is unused

In [12]:
# Metrics evaluated every batch, ordered [MSE, PSNR, SSIM]. data_range=1.0 presumably
# because the ab targets are scaled by (ab + 128) / 255 into roughly [0, 1].
criterion = [
    MeanSquaredError(),
    PeakSignalNoiseRatio(data_range=1.0),
    StructuralSimilarityIndexMeasure(data_range=1.0),
]

In [13]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)  # Adam with an explicit 1e-2 learning rate

In [14]:
# Move the metric modules and the model to the GPU when one is available.
# NOTE(review): the optimizer was constructed before .cuda(); PyTorch recommends moving
# the model first. This presumably still works because Module.cuda() moves parameters
# in place - confirm if the optimizer setup changes.
if use_gpu:
    criterion = [metric.to("cuda") for metric in criterion]
    model = model.cuda()

In [15]:
# Layer-by-layer summary for a single-channel SIZE x SIZE input; only run on GPU.
if use_gpu:
    from torchsummary import summary
    summary(model, input_size=(1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 32, 32]           1,280
       BatchNorm2d-2          [-1, 128, 32, 32]             256
            Conv2d-3          [-1, 256, 16, 16]         295,168
       BatchNorm2d-4          [-1, 256, 16, 16]             512
   ConvTranspose2d-5            [-1, 128, 8, 8]         295,040
       BatchNorm2d-6            [-1, 128, 8, 8]             256
   ConvTranspose2d-7           [-1, 64, 16, 16]          73,792
       BatchNorm2d-8           [-1, 64, 16, 16]             128
   ConvTranspose2d-9            [-1, 2, 16, 16]           1,154
Total params: 667,586
Trainable params: 667,586
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 3.38
Params size (MB): 2.55
Estimated Total Size (MB): 5.93
-------------------------------------------

In [16]:
class AverageMeter(object):
    '''Running, weight-aware average of a scalar (pattern from the PyTorch ImageNet example).

    Attributes:
        val:   most recent value passed to update().
        sum:   running sum of value * weight.
        count: running sum of weights.
        avg:   sum / count (0 until the first update).
    '''

    def __init__(self):
        self.reset()

    def reset(self):
        '''Forget everything accumulated so far.'''
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        '''Record ``val`` with weight ``n`` (e.g. a batch size) and refresh ``avg``.'''
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Rebuild an RGB image from a grayscale channel and ab channels, optionally saving it.

    Args:
        grayscale_input: tensor of shape (1, H, W), expected in [0, 1]; treated as Lab L / 100.
        ab_input: tensor of shape (2, H, W) holding a/b scaled as (ab + 128) / 255.
        save_path: dict {'grayscale': '/path/', 'colorized': '/path/'}; each value is a
            prefix to which ``save_name`` is appended directly.
        save_name: output file name.

    Returns:
        The RGB image of shape (H, W, 3) as returned by ``lab2rgb``. Files are written
        only when both ``save_path`` and ``save_name`` are given.
    '''
    # No plt.clf(): plt.imsave does not draw on the current figure, and clf() only
    # created an empty figure that Jupyter then displayed as "<Figure ... with 0 Axes>".
    # detach()/cpu() make the helper safe for GPU tensors or tensors that track grads.
    gray = grayscale_input.detach().cpu()
    ab = ab_input.detach().cpu()

    # Stack to (3, H, W), move channels last for skimage, and undo the dataset scaling.
    lab = torch.cat((gray, ab), 0).numpy().transpose((1, 2, 0)).astype(np.float64)
    lab[:, :, 0:1] = lab[:, :, 0:1] * 100        # L back to [0, 100]
    lab[:, :, 1:3] = lab[:, :, 1:3] * 255 - 128  # invert (ab + 128) / 255
    color_image = lab2rgb(lab)

    if save_path is not None and save_name is not None:
        plt.imsave(arr=gray.squeeze().numpy(), fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))
    return color_image

In [17]:
def validate(val_loader, model, criterion, save_images, epoch, max_images=10):
    '''Evaluate ``model`` on ``val_loader`` and return sample-weighted average metrics.

    Args:
        val_loader: iterable of (gray, ab) batches.
        model: network mapping gray batches to ab predictions.
        criterion: [MSE, PSNR, SSIM] metric callables, each called as ``metric(pred, target)``.
        save_images: if True, colorizations from the first batch are written to the
            ``gray_imgs`` / ``color_imgs`` folders.
        epoch: epoch index used in saved file names and TensorBoard tags; a negative
            value skips TensorBoard logging.
        max_images: maximum number of images saved from the first batch.

    Returns:
        (mse_avg, psnr_avg, ssim_avg), each averaged over batches weighted by batch size.
    '''
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    # Own no_grad context so validation never builds autograd graphs, even when the
    # caller does not wrap the call (nesting no_grad is harmless).
    with torch.no_grad():
        for gray, ab in val_loader:
            if use_gpu:
                gray, ab = gray.cuda(), ab.cuda()

            output_ab = model(gray)  # predicted ab channels for the batch
            batch_n = gray.size(0)
            for meter, metric in zip(_loss, criterion):
                meter.update(metric(output_ab, ab).item(), batch_n)

            # Save sample colorizations from the first batch only.
            if save_images and not already_saved_images:
                already_saved_images = True
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                for j in range(min(len(output_ab), max_images)):
                    save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                    to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    '''Run one training epoch, optimizing MSE and tracking PSNR/SSIM.

    Args:
        train_loader: iterable of (gray, ab) batches.
        model: network mapping gray batches to ab predictions.
        criterion: [MSE, PSNR, SSIM] metric callables; only criterion[0] (MSE) is
            back-propagated, the other two are monitored.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index for printing and TensorBoard; a negative value skips logging.
    '''
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.train()

    for gray, ab in train_loader:
        if use_gpu:
            gray, ab = gray.cuda(), ab.cuda()

        optimizer.zero_grad()

        output_ab = model(gray)
        mse = criterion[0](output_ab, ab)
        # PSNR and SSIM are only monitored, never back-propagated: evaluate them on a
        # detached prediction without autograd, so no unused graph is built per batch.
        with torch.no_grad():
            prediction = output_ab.detach()
            psnr = criterion[1](prediction, ab)
            ssim = criterion[2](prediction, ab)

        mse.backward()
        optimizer.step()

        batch_n = gray.size(0)
        _loss[0].update(mse.item(), batch_n)
        _loss[1].update(psnr.item(), batch_n)
        _loss[2].update(ssim.item(), batch_n)

    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model: each epoch trains once and then validates. A checkpoint is written
# whenever a validation metric reaches a new best value, and training stops early
# once MSE has not improved for `patience` epochs (never before `min_epochs`).
min_epochs = 100

# best_losses is ordered [MSE, PSNR, SSIM]. MSE is lower-is-better, but PSNR and SSIM
# are higher-is-better. If those two slots still hold the large lower-is-better seed
# (>= 1e10), re-seed them with -inf so the first result is recorded and later
# *increases* count as improvements.
for _slot in (1, 2):
    if best_losses[_slot] >= 1e10:
        best_losses[_slot] = float('-inf')
del _slot

for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)

    # Lowest MSE so far; this epoch also becomes the early-stopping reference.
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    # Highest PSNR so far (higher is better, hence `>`).
    if losses[1] > best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    # Highest SSIM so far (higher is better, hence `>`).
    if losses[2] > best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')

    if epoch - best_epoch >= patience and epoch >= min_epochs:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break

    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00629300 (0.32895766), PSNR 22.01142120 (17.13292887), SSIM 0.35981512 (0.22733096)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00711792 (0.00668482), PSNR 21.47646713 (21.76924555), SSIM 0.28251266 (0.34594226)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00422540 (0.00553903), PSNR 23.74131775 (22.60150604), SSIM 0.43954182 (0.40080481)
Finished training epoch 1


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00514265 (0.00455466), PSNR 22.88812828 (23.43734610), SSIM 0.37517017 (0.45108672)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00423759 (0.00427830), PSNR 23.72881317 (23.70519824), SSIM 0.52953392 (0.48261855)
Finished training epoch 2
Validate: MSE 0.00439893 (0.00411593), PSNR 23.56653214 (23.89295430), SSIM 0.43119660 (0.50852199)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00378140 (0.00376994), PSNR 24.22346878 (24.25412929), SSIM 0.53469861 (0.53794718)
Finished training epoch 3


  return func(*args, **kwargs)


Validate: MSE 0.00441257 (0.00371417), PSNR 23.55307961 (24.33607802), SSIM 0.48082352 (0.56092571)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.00324305 (0.00353890), PSNR 24.89046097 (24.53135319), SSIM 0.58112079 (0.57726259)
Finished training epoch 4
Validate: MSE 0.00381673 (0.00329149), PSNR 24.18308830 (24.86797097), SSIM 0.51795316 (0.59644723)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.00348696 (0.00336277), PSNR 24.57552719 (24.75436299), SSIM 0.62215775 (0.60815034)
Finished training epoch 5
Validate: MSE 0.00397684 (0.00355392), PSNR 24.00461578 (24.53082845), SSIM 0.53698170 (0.61344580)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 0.00281569 (0.00327284), PSNR 25.50414467 (24.87112835), SSIM 0.65401661 (0.63147016)
Finished training epoch 6
Validate: MSE 0.00335941 (0.00300089), PSNR 24.73737335 (25.28238046), SSIM 0.56657594 (0.64753498)
Finished validation.
Starting training epoch 7
Epoch: 7, MSE 0.00345117 (0.003304

Validate: MSE 0.00308024 (0.00425120), PSNR 25.11415863 (23.86843370), SSIM 0.70933759 (0.75770408)
Finished validation.
Starting training epoch 34
Epoch: 34, MSE 0.00279835 (0.00240524), PSNR 25.53098297 (26.21173927), SSIM 0.76764309 (0.77695859)
Finished training epoch 34
Validate: MSE 0.00285197 (0.00254037), PSNR 25.44854736 (26.03145391), SSIM 0.70996749 (0.77648998)
Finished validation.
Starting training epoch 35
Epoch: 35, MSE 0.00215803 (0.00238229), PSNR 26.65943336 (26.25191722), SSIM 0.78550977 (0.77776620)
Finished training epoch 35
Validate: MSE 0.00291275 (0.00347177), PSNR 25.35697174 (24.75941996), SSIM 0.71057498 (0.76726979)
Finished validation.
Starting training epoch 36
Epoch: 36, MSE 0.00262476 (0.00237115), PSNR 25.80910683 (26.27435653), SSIM 0.76548910 (0.77810791)
Finished training epoch 36
Validate: MSE 0.00285430 (0.00268885), PSNR 25.44500923 (25.78474737), SSIM 0.70903307 (0.76698012)
Finished validation.
Starting training epoch 37
Epoch: 37, MSE 0.0022149

Validate: MSE 0.00268408 (0.00232315), PSNR 25.71203804 (26.39965843), SSIM 0.72065258 (0.77797019)
Finished validation.
Starting training epoch 64
Epoch: 64, MSE 0.00206012 (0.00226053), PSNR 26.86106682 (26.48173115), SSIM 0.79828799 (0.78153459)
Finished training epoch 64
Validate: MSE 0.00249069 (0.00236953), PSNR 26.03680611 (26.30799915), SSIM 0.72232056 (0.78051738)
Finished validation.
Starting training epoch 65
Epoch: 65, MSE 0.00189285 (0.00224376), PSNR 27.22883606 (26.51347824), SSIM 0.78520399 (0.78184517)
Finished training epoch 65
Validate: MSE 0.00301126 (0.00244878), PSNR 25.21252060 (26.18703732), SSIM 0.71587491 (0.77876337)
Finished validation.
Starting training epoch 66
Epoch: 66, MSE 0.00186748 (0.00224098), PSNR 27.28742981 (26.51670774), SSIM 0.80128956 (0.78191633)
Finished training epoch 66
Validate: MSE 0.00282784 (0.00238616), PSNR 25.48545456 (26.28813764), SSIM 0.71686226 (0.77643570)
Finished validation.
Starting training epoch 67
Epoch: 67, MSE 0.0020989

Validate: MSE 0.00277331 (0.00273711), PSNR 25.57001305 (25.74706112), SSIM 0.71356571 (0.77074434)
Finished validation.
Starting training epoch 94
Epoch: 94, MSE 0.00181117 (0.00215518), PSNR 27.42041779 (26.68853027), SSIM 0.78958219 (0.78252324)
Finished training epoch 94
Validate: MSE 0.00257913 (0.00234318), PSNR 25.88527489 (26.36929081), SSIM 0.72193450 (0.77797312)
Finished validation.
Starting training epoch 95
Epoch: 95, MSE 0.00247453 (0.00215693), PSNR 26.06507683 (26.68431836), SSIM 0.77188873 (0.78236626)
Finished training epoch 95
Validate: MSE 0.00261869 (0.00245953), PSNR 25.81915283 (26.18338792), SSIM 0.71861362 (0.77603534)
Finished validation.
Starting training epoch 96
Epoch: 96, MSE 0.00201656 (0.00215825), PSNR 26.95388985 (26.68143346), SSIM 0.77983236 (0.78238489)
Finished training epoch 96
Validate: MSE 0.00276551 (0.00230981), PSNR 25.58224487 (26.42786558), SSIM 0.71678454 (0.77769775)
Finished validation.
Starting training epoch 97
Epoch: 97, MSE 0.0018982

Epoch: 123, MSE 0.00201248 (0.00209358), PSNR 26.96268463 (26.81116706), SSIM 0.78632355 (0.78193982)
Finished training epoch 123
Validate: MSE 0.00264689 (0.00226452), PSNR 25.77264023 (26.50958222), SSIM 0.72124231 (0.77842296)
Finished validation.
Starting training epoch 124
Epoch: 124, MSE 0.00175021 (0.00208927), PSNR 27.56910515 (26.82239276), SSIM 0.80813754 (0.78190461)
Finished training epoch 124
Validate: MSE 0.00293705 (0.00247868), PSNR 25.32088852 (26.13014233), SSIM 0.70982301 (0.77282008)
Finished validation.
Starting training epoch 125
Epoch: 125, MSE 0.00222835 (0.00208278), PSNR 26.52016830 (26.83830782), SSIM 0.77378559 (0.78194072)
Finished training epoch 125
Validate: MSE 0.00257100 (0.00243026), PSNR 25.89897156 (26.21748831), SSIM 0.71580780 (0.77411732)
Finished validation.
Starting training epoch 126
Epoch: 126, MSE 0.00222288 (0.00208455), PSNR 26.53084755 (26.83086251), SSIM 0.79426193 (0.78191160)
Finished training epoch 126
Validate: MSE 0.00264407 (0.00264

Validate: MSE 0.00258082 (0.00232145), PSNR 25.88242149 (26.40356624), SSIM 0.71486157 (0.76874223)
Finished validation.
Starting training epoch 153
Epoch: 153, MSE 0.00171274 (0.00203815), PSNR 27.66308022 (26.92868858), SSIM 0.77420032 (0.78125374)
Finished training epoch 153
Validate: MSE 0.00253108 (0.00247696), PSNR 25.96694183 (26.11827848), SSIM 0.71029317 (0.77164082)
Finished validation.
Starting training epoch 154
Epoch: 154, MSE 0.00193178 (0.00201829), PSNR 27.14041901 (26.97341914), SSIM 0.78534186 (0.78149168)
Finished training epoch 154
Validate: MSE 0.00254914 (0.00220019), PSNR 25.93606567 (26.63294536), SSIM 0.71512806 (0.77897084)
Finished validation.
Starting training epoch 155
Epoch: 155, MSE 0.00179403 (0.00201934), PSNR 27.46171188 (26.96789029), SSIM 0.80981874 (0.78164645)
Finished training epoch 155
Validate: MSE 0.00270620 (0.00239811), PSNR 25.67639923 (26.27214017), SSIM 0.71081299 (0.77530670)
Finished validation.
Starting training epoch 156
Epoch: 156, MS

Epoch: 182, MSE 0.00178428 (0.00197998), PSNR 27.48537254 (27.05471907), SSIM 0.79232037 (0.78077842)
Finished training epoch 182
Validate: MSE 0.00296667 (0.00248296), PSNR 25.27730560 (26.11653107), SSIM 0.70818734 (0.76781318)
Finished validation.
Starting training epoch 183
Epoch: 183, MSE 0.00198959 (0.00197613), PSNR 27.01236725 (27.06305259), SSIM 0.78366637 (0.78095038)
Finished training epoch 183
Validate: MSE 0.00248164 (0.00247569), PSNR 26.05261421 (26.12590862), SSIM 0.71636677 (0.76140716)
Finished validation.
Starting training epoch 184
Epoch: 184, MSE 0.00222789 (0.00197794), PSNR 26.52105331 (27.05943124), SSIM 0.76788074 (0.78092214)
Finished training epoch 184
Validate: MSE 0.00253825 (0.00228934), PSNR 25.95465851 (26.46022106), SSIM 0.70813304 (0.76502305)
Finished validation.
Starting training epoch 185
Epoch: 185, MSE 0.00195663 (0.00197502), PSNR 27.08491516 (27.06383222), SSIM 0.77925897 (0.78077506)
Finished training epoch 185
Validate: MSE 0.00259563 (0.00221

<Figure size 432x288 with 0 Axes>

In [20]:
last_ckpt = f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth'  # named after the final [MSE, PSNR, SSIM]
torch.save(model.state_dict(), last_ckpt)

In [21]:
# Final validation pass with the current weights. epoch=-1 skips TensorBoard logging
# and tags the saved sample images with "epoch--1".
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, epoch=-1)

Validate: MSE 0.00268372 (0.00230091), PSNR 25.71261787 (26.44286734), SSIM 0.70963109 (0.77121133)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()