In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 64

class Autoencoder(nn.Module):
    def __init__(self, input_size=128):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x)), negative_slope=0.1)
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]             640
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3          [-1, 128, 16, 16]          73,856
       BatchNorm2d-4          [-1, 128, 16, 16]             256
   ConvTranspose2d-5             [-1, 64, 8, 8]          73,792
       BatchNorm2d-6             [-1, 64, 8, 8]             128
   ConvTranspose2d-7           [-1, 32, 16, 16]          18,464
       BatchNorm2d-8           [-1, 32, 16, 16]              64
   ConvTranspose2d-9            [-1, 2, 16, 16]             578
Total params: 167,906
Trainable params: 167,906
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 1.69
Params size (MB): 0.64
Estimated Total Size (MB): 2.34
-------------------------------------------

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00395997 (0.07210925), PSNR 24.02308083 (20.20020934), SSIM 0.49807110 (0.35565215)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00501600 (0.00447828), PSNR 22.99642754 (23.52445583), SSIM 0.39760792 (0.47833114)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00309138 (0.00405904), PSNR 25.09847260 (23.94357850), SSIM 0.59131271 (0.53415837)
Finished training epoch 1


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00385019 (0.00347840), PSNR 24.14517403 (24.62630303), SSIM 0.51230961 (0.58389975)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00289960 (0.00347437), PSNR 25.37662315 (24.61602449), SSIM 0.61944377 (0.60804741)
Finished training epoch 2
Validate: MSE 0.00337290 (0.00312337), PSNR 24.71997070 (25.11795999), SSIM 0.56866556 (0.63829048)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00269104 (0.00320652), PSNR 25.70080376 (24.96199064), SSIM 0.66482544 (0.65170853)
Finished training epoch 3
Validate: MSE 0.00388769 (0.00326970), PSNR 24.10308075 (24.91045326), SSIM 0.59631449 (0.66599934)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.00267871 (0.00314213), PSNR 25.72073555 (25.05128213), SSIM 0.66098118 (0.67827899)
Finished training epoch 4
Validate: MSE 0.00330104 (0.00290342), PSNR 24.81348801 (25.43324514), SSIM 0.62230325 (0.69341623)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.00292001 (0.003029

  return func(*args, **kwargs)


Validate: MSE 0.00327998 (0.00297949), PSNR 24.84129143 (25.31228501), SSIM 0.66141278 (0.72780231)
Finished validation.
Starting training epoch 9
Epoch: 9, MSE 0.00269933 (0.00276349), PSNR 25.68743896 (25.60534546), SSIM 0.76567888 (0.73822524)
Finished training epoch 9
Validate: MSE 0.00285841 (0.00280631), PSNR 25.43875885 (25.58213898), SSIM 0.67584896 (0.74032918)
Finished validation.
Starting training epoch 10
Epoch: 10, MSE 0.00273486 (0.00275805), PSNR 25.63064575 (25.62151401), SSIM 0.75195754 (0.74422943)
Finished training epoch 10
Validate: MSE 0.00292311 (0.00271189), PSNR 25.34154892 (25.75492724), SSIM 0.68662393 (0.75041824)
Finished validation.
Starting training epoch 11
Epoch: 11, MSE 0.00246637 (0.00266259), PSNR 26.07941437 (25.76788906), SSIM 0.74519849 (0.75055592)
Finished training epoch 11
Validate: MSE 0.00290970 (0.00258494), PSNR 25.36151314 (25.94286631), SSIM 0.68905759 (0.75525674)
Finished validation.
Starting training epoch 12
Epoch: 12, MSE 0.00237463 (

Validate: MSE 0.00262169 (0.00275778), PSNR 25.81418610 (25.72322855), SSIM 0.72042394 (0.77622112)
Finished validation.
Starting training epoch 39
Epoch: 39, MSE 0.00234243 (0.00233108), PSNR 26.30333328 (26.34457754), SSIM 0.78154355 (0.77923449)
Finished training epoch 39
Validate: MSE 0.00284047 (0.00236012), PSNR 25.46609306 (26.33928867), SSIM 0.71370220 (0.77914752)
Finished validation.
Starting training epoch 40
Epoch: 40, MSE 0.00251949 (0.00233241), PSNR 25.98686409 (26.34500534), SSIM 0.78894722 (0.77936020)
Finished training epoch 40
Validate: MSE 0.00315662 (0.00247656), PSNR 25.00777245 (26.13408976), SSIM 0.71260405 (0.77869310)
Finished validation.
Starting training epoch 41
Epoch: 41, MSE 0.00197817 (0.00234598), PSNR 27.03735924 (26.31982561), SSIM 0.78623569 (0.77917809)
Finished training epoch 41
Validate: MSE 0.00267472 (0.00287245), PSNR 25.72722435 (25.54289083), SSIM 0.71871513 (0.77504333)
Finished validation.
Starting training epoch 42
Epoch: 42, MSE 0.0021695

Validate: MSE 0.00276450 (0.00287301), PSNR 25.58383942 (25.53063292), SSIM 0.70999098 (0.76835879)
Finished validation.
Starting training epoch 69
Epoch: 69, MSE 0.00253793 (0.00223625), PSNR 25.95519638 (26.52721134), SSIM 0.78609091 (0.78100926)
Finished training epoch 69
Validate: MSE 0.00323067 (0.00269503), PSNR 24.90707779 (25.77745921), SSIM 0.71068740 (0.77343868)
Finished validation.
Starting training epoch 70
Epoch: 70, MSE 0.00241158 (0.00223247), PSNR 26.17697716 (26.53363396), SSIM 0.76937932 (0.78086412)
Finished training epoch 70
Validate: MSE 0.00262688 (0.00242124), PSNR 25.80560112 (26.22941648), SSIM 0.72366631 (0.77626724)
Finished validation.
Starting training epoch 71
Epoch: 71, MSE 0.00274641 (0.00222755), PSNR 25.61234856 (26.54483004), SSIM 0.76646650 (0.78115500)
Finished training epoch 71
Validate: MSE 0.00257116 (0.00253017), PSNR 25.89870644 (26.04597506), SSIM 0.71646130 (0.77277455)
Finished validation.
Starting training epoch 72
Epoch: 72, MSE 0.0022190

Validate: MSE 0.00303160 (0.00296512), PSNR 25.18328476 (25.36263571), SSIM 0.71395183 (0.77011189)
Finished validation.
Starting training epoch 99
Epoch: 99, MSE 0.00223055 (0.00217917), PSNR 26.51586914 (26.63991569), SSIM 0.78549105 (0.78094985)
Finished training epoch 99
Validate: MSE 0.00260496 (0.00224190), PSNR 25.84198761 (26.55057738), SSIM 0.71832609 (0.77693113)
Finished validation.
Starting training epoch 100
Epoch: 100, MSE 0.00206239 (0.00218036), PSNR 26.85628319 (26.64003666), SSIM 0.77704108 (0.78098069)
Finished training epoch 100
Validate: MSE 0.00312289 (0.00259669), PSNR 25.05443573 (25.94465460), SSIM 0.71191597 (0.77384095)
Finished validation.
Starting training epoch 101
Epoch: 101, MSE 0.00196780 (0.00218664), PSNR 27.06018448 (26.62398222), SSIM 0.78089535 (0.78086407)
Finished training epoch 101
Validate: MSE 0.00259736 (0.00227442), PSNR 25.85467339 (26.49650586), SSIM 0.71926886 (0.77945364)
Finished validation.
Starting training epoch 102
Epoch: 102, MSE 0

Epoch: 128, MSE 0.00213914 (0.00214430), PSNR 26.69760704 (26.71160600), SSIM 0.76612371 (0.78057959)
Finished training epoch 128
Validate: MSE 0.00276589 (0.00274453), PSNR 25.58165359 (25.69355621), SSIM 0.70968771 (0.75955276)
Finished validation.
Starting training epoch 129
Epoch: 129, MSE 0.00211948 (0.00214802), PSNR 26.73769569 (26.70264181), SSIM 0.77415633 (0.78048064)
Finished training epoch 129
Validate: MSE 0.00277450 (0.00248435), PSNR 25.56815720 (26.11696945), SSIM 0.71690750 (0.77380745)
Finished validation.
Starting training epoch 130
Epoch: 130, MSE 0.00172662 (0.00214740), PSNR 27.62801743 (26.70332184), SSIM 0.79682982 (0.78057451)
Finished training epoch 130
Validate: MSE 0.00289158 (0.00228778), PSNR 25.38864136 (26.47853296), SSIM 0.70773768 (0.77909665)
Finished validation.
Starting training epoch 131
Epoch: 131, MSE 0.00217908 (0.00214811), PSNR 26.61726761 (26.70257119), SSIM 0.77415907 (0.78056183)
Finished training epoch 131
Validate: MSE 0.00289257 (0.00266

Validate: MSE 0.00261498 (0.00220897), PSNR 25.82531357 (26.61793859), SSIM 0.71153951 (0.77493642)
Finished validation.
Starting training epoch 158
Epoch: 158, MSE 0.00208110 (0.00212597), PSNR 26.81706047 (26.74809620), SSIM 0.77537620 (0.78010127)
Finished training epoch 158
Validate: MSE 0.00276874 (0.00229993), PSNR 25.57717705 (26.44749345), SSIM 0.71439123 (0.77438674)
Finished validation.
Starting training epoch 159
Epoch: 159, MSE 0.00247653 (0.00211702), PSNR 26.06156540 (26.76563849), SSIM 0.79103971 (0.78006452)
Finished training epoch 159
Validate: MSE 0.00241404 (0.00232848), PSNR 26.17255211 (26.38116976), SSIM 0.71495628 (0.77052195)
Finished validation.
Starting training epoch 160
Epoch: 160, MSE 0.00211515 (0.00212182), PSNR 26.74659157 (26.75555908), SSIM 0.78704357 (0.78012126)
Finished training epoch 160
Validate: MSE 0.00262444 (0.00224453), PSNR 25.80963326 (26.54644471), SSIM 0.71112007 (0.77378457)
Finished validation.
Starting training epoch 161
Epoch: 161, MS

Epoch: 187, MSE 0.00217208 (0.00211021), PSNR 26.63123322 (26.77823200), SSIM 0.77131355 (0.77972474)
Finished training epoch 187
Validate: MSE 0.00280929 (0.00244100), PSNR 25.51403427 (26.19836664), SSIM 0.71029592 (0.77298358)
Finished validation.
Starting training epoch 188
Epoch: 188, MSE 0.00182863 (0.00209721), PSNR 27.37873077 (26.80433343), SSIM 0.79653537 (0.78002954)
Finished training epoch 188
Validate: MSE 0.00277699 (0.00223602), PSNR 25.56426048 (26.56500318), SSIM 0.70921886 (0.77448267)
Finished validation.
Starting training epoch 189
Epoch: 189, MSE 0.00200568 (0.00209878), PSNR 26.97739029 (26.80230600), SSIM 0.76627284 (0.77974143)
Finished training epoch 189
Validate: MSE 0.00279298 (0.00230590), PSNR 25.53932762 (26.43758451), SSIM 0.70920247 (0.77371221)
Finished validation.
Starting training epoch 190
Epoch: 190, MSE 0.00165112 (0.00210567), PSNR 27.82221031 (26.78845883), SSIM 0.80311143 (0.77977097)
Finished training epoch 190
Validate: MSE 0.00251534 (0.00221

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00257919 (0.00231392), PSNR 25.88517189 (26.41912558), SSIM 0.71349913 (0.77266914)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()