In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .75

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.dropout(x, p=p1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.dropout(x, p=p1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 16, 16]           2,306
Total params: 2,662,274
Trainable params: 2,662,274
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.75
Params size (MB): 10.16
Estimated Total Size (MB): 16.91
-------------------------------------

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.03388810 (0.93333788), PSNR 14.69952679 (8.53339034), SSIM 0.25305301 (0.08514557)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.03650992 (0.03360692), PSNR 14.37588978 (15.08183875), SSIM 0.21280637 (0.23963901)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00709825 (0.01570231), PSNR 21.48848343 (18.54908478), SSIM 0.53417128 (0.39284576)
Finished training epoch 1


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00720866 (0.00652036), PSNR 21.42145157 (22.01351016), SSIM 0.45609725 (0.52920665)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00326057 (0.00417200), PSNR 24.86706161 (23.90178636), SSIM 0.67344320 (0.61018161)
Finished training epoch 2
Validate: MSE 0.00353590 (0.00309852), PSNR 24.51499367 (25.15258752), SSIM 0.59476095 (0.67398117)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00262501 (0.00288258), PSNR 25.80868530 (25.42565603), SSIM 0.76429427 (0.71448121)
Finished training epoch 3
Validate: MSE 0.00318152 (0.00275678), PSNR 24.97365570 (25.66723143), SSIM 0.67360258 (0.74554738)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 0.00207683 (0.00275217), PSNR 26.82598305 (25.62253094), SSIM 0.79283869 (0.75709660)
Finished training epoch 4
Validate: MSE 0.00305720 (0.00273137), PSNR 25.14675903 (25.71452241), SSIM 0.69638532 (0.76404811)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.00312848 (0.002739

Validate: MSE 0.00316934 (0.00272543), PSNR 24.99031448 (25.72342664), SSIM 0.69513196 (0.75996570)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00245844 (0.00246387), PSNR 26.09340858 (26.10775116), SSIM 0.79304922 (0.76969094)
Finished training epoch 32
Validate: MSE 0.00288801 (0.00258393), PSNR 25.39401245 (25.95718291), SSIM 0.70519888 (0.77379118)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00245213 (0.00246685), PSNR 26.10457039 (26.10031175), SSIM 0.78593928 (0.77004434)
Finished training epoch 33
Validate: MSE 0.00282798 (0.00242419), PSNR 25.48523521 (26.21557606), SSIM 0.70143533 (0.77073457)
Finished validation.
Starting training epoch 34
Epoch: 34, MSE 0.00271499 (0.00245824), PSNR 25.66231728 (26.11555544), SSIM 0.78743827 (0.77000911)
Finished training epoch 34
Validate: MSE 0.00297100 (0.00249345), PSNR 25.27096558 (26.10166116), SSIM 0.70618582 (0.77379351)
Finished validation.
Starting training epoch 35
Epoch: 35, MSE 0.0023257

Validate: MSE 0.00276311 (0.00236539), PSNR 25.58600807 (26.32372177), SSIM 0.71140313 (0.77498749)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00220364 (0.00239462), PSNR 26.56859207 (26.22994659), SSIM 0.78424919 (0.76956548)
Finished training epoch 62
Validate: MSE 0.00294680 (0.00252757), PSNR 25.30648804 (26.04531189), SSIM 0.70130837 (0.76537809)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00234563 (0.00238729), PSNR 26.29741096 (26.24375493), SSIM 0.78331870 (0.77008149)
Finished training epoch 63
Validate: MSE 0.00269222 (0.00257215), PSNR 25.69889259 (25.96860827), SSIM 0.70807749 (0.75941638)
Finished validation.
Starting training epoch 64
Epoch: 64, MSE 0.00244396 (0.00238629), PSNR 26.11906052 (26.24375596), SSIM 0.77736330 (0.77009760)
Finished training epoch 64
Validate: MSE 0.00260702 (0.00241136), PSNR 25.83856010 (26.24034975), SSIM 0.71055877 (0.77228735)
Finished validation.
Starting training epoch 65
Epoch: 65, MSE 0.0029819

Validate: MSE 0.00273414 (0.00241470), PSNR 25.63179779 (26.24281407), SSIM 0.70777953 (0.77221652)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.00267876 (0.00235040), PSNR 25.72065544 (26.30939517), SSIM 0.76592779 (0.77076823)
Finished training epoch 92
Validate: MSE 0.00273876 (0.00242906), PSNR 25.62445450 (26.21598028), SSIM 0.71330684 (0.76888692)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00189850 (0.00235467), PSNR 27.21588898 (26.30326355), SSIM 0.80768168 (0.77001901)
Finished training epoch 93
Validate: MSE 0.00287796 (0.00236797), PSNR 25.40914536 (26.32572819), SSIM 0.71184200 (0.77309854)
Finished validation.
Starting training epoch 94
Epoch: 94, MSE 0.00212631 (0.00233967), PSNR 26.72372627 (26.33290585), SSIM 0.78545791 (0.77116605)
Finished training epoch 94
Validate: MSE 0.00309292 (0.00283060), PSNR 25.09631348 (25.56518235), SSIM 0.70620120 (0.76502200)
Finished validation.
Starting training epoch 95
Epoch: 95, MSE 0.0026809

Epoch: 121, MSE 0.00218987 (0.00232246), PSNR 26.59581375 (26.36280286), SSIM 0.76066780 (0.77138335)
Finished training epoch 121
Validate: MSE 0.00252467 (0.00240055), PSNR 25.97795296 (26.24593606), SSIM 0.71339041 (0.76472734)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 0.00206372 (0.00233662), PSNR 26.85349083 (26.33713269), SSIM 0.76561594 (0.77018680)
Finished training epoch 122
Validate: MSE 0.00305283 (0.00248239), PSNR 25.15297127 (26.11473941), SSIM 0.70794034 (0.76461683)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00254855 (0.00235072), PSNR 25.93707275 (26.31248685), SSIM 0.74865848 (0.76930563)
Finished training epoch 123
Validate: MSE 0.00268977 (0.00244162), PSNR 25.70284081 (26.16753114), SSIM 0.69399244 (0.75534593)
Finished validation.
Starting training epoch 124
Epoch: 124, MSE 0.00194008 (0.00234473), PSNR 27.12179375 (26.32300870), SSIM 0.78364438 (0.76966066)
Finished training epoch 124
Validate: MSE 0.00266922 (0.00230

Validate: MSE 0.00261357 (0.00230134), PSNR 25.82765961 (26.43701046), SSIM 0.71176445 (0.76905190)
Finished validation.
Starting training epoch 151
Epoch: 151, MSE 0.00240416 (0.00231717), PSNR 26.19037247 (26.37359680), SSIM 0.75427282 (0.77011840)
Finished training epoch 151
Validate: MSE 0.00267210 (0.00231959), PSNR 25.73146820 (26.40534491), SSIM 0.71109962 (0.77472239)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00236534 (0.00231546), PSNR 26.26105881 (26.37510389), SSIM 0.76008409 (0.76989860)
Finished training epoch 152
Validate: MSE 0.00278040 (0.00244216), PSNR 25.55893135 (26.17840305), SSIM 0.71346205 (0.76484740)
Finished validation.
Starting training epoch 153
Epoch: 153, MSE 0.00245854 (0.00232256), PSNR 26.09322548 (26.36406573), SSIM 0.77007568 (0.76976250)
Finished training epoch 153
Validate: MSE 0.00262891 (0.00235265), PSNR 25.80224800 (26.34862105), SSIM 0.71325397 (0.77104038)
Finished validation.
Starting training epoch 154
Epoch: 154, MS

Epoch: 180, MSE 0.00155077 (0.00229616), PSNR 28.09452057 (26.41350774), SSIM 0.79046839 (0.77074254)
Finished training epoch 180
Validate: MSE 0.00273082 (0.00240546), PSNR 25.63706589 (26.25712001), SSIM 0.70860034 (0.76947624)
Finished validation.
Starting training epoch 181
Epoch: 181, MSE 0.00221008 (0.00230411), PSNR 26.55591202 (26.39513535), SSIM 0.76969397 (0.77039911)
Finished training epoch 181
Validate: MSE 0.00288978 (0.00234289), PSNR 25.39135170 (26.35913493), SSIM 0.70352399 (0.76852851)
Finished validation.
Starting training epoch 182
Epoch: 182, MSE 0.00192289 (0.00231508), PSNR 27.16045952 (26.37504158), SSIM 0.77719289 (0.77009599)
Finished training epoch 182
Validate: MSE 0.00259780 (0.00229958), PSNR 25.85394287 (26.43655916), SSIM 0.70938426 (0.76481335)
Finished validation.
Starting training epoch 183
Epoch: 183, MSE 0.00240342 (0.00229174), PSNR 26.19170570 (26.42152433), SSIM 0.79107440 (0.77099550)
Finished training epoch 183
Validate: MSE 0.00272244 (0.00233

Validate: MSE 0.00286728 (0.00229630), PSNR 25.42530441 (26.45640153), SSIM 0.71190053 (0.77317270)
Finished validation.
Starting training epoch 210
Epoch: 210, MSE 0.00216165 (0.00229392), PSNR 26.65214920 (26.41739692), SSIM 0.77645123 (0.77054530)
Finished training epoch 210
Validate: MSE 0.00288547 (0.00247584), PSNR 25.39783287 (26.13687154), SSIM 0.71401626 (0.77137521)
Finished validation.
Starting training epoch 211
Epoch: 211, MSE 0.00224614 (0.00229847), PSNR 26.48562241 (26.40836252), SSIM 0.75859678 (0.77044988)
Finished training epoch 211
Validate: MSE 0.00275354 (0.00233602), PSNR 25.60107994 (26.37403663), SSIM 0.69604880 (0.76395401)
Finished validation.
Starting training epoch 212
Epoch: 212, MSE 0.00224133 (0.00230001), PSNR 26.49493408 (26.40700044), SSIM 0.78799438 (0.77045664)
Finished training epoch 212
Validate: MSE 0.00259884 (0.00228900), PSNR 25.85220909 (26.45782220), SSIM 0.70942211 (0.76853769)
Finished validation.
Starting training epoch 213
Epoch: 213, MS

  return func(*args, **kwargs)


Validate: MSE 0.00279272 (0.00229199), PSNR 25.53972626 (26.45894022), SSIM 0.71499741 (0.77053429)
Finished validation.
Starting training epoch 219
Epoch: 219, MSE 0.00188037 (0.00230203), PSNR 27.25757217 (26.39985139), SSIM 0.79891205 (0.77012164)
Finished training epoch 219
Validate: MSE 0.00310966 (0.00249185), PSNR 25.07286453 (26.10326673), SSIM 0.71334803 (0.76874077)
Finished validation.
Starting training epoch 220
Epoch: 220, MSE 0.00193522 (0.00228412), PSNR 27.13270569 (26.43387031), SSIM 0.78259051 (0.77069475)
Finished training epoch 220
Validate: MSE 0.00268858 (0.00228130), PSNR 25.70476532 (26.47615839), SSIM 0.70923972 (0.77398599)
Finished validation.
Starting training epoch 221
Epoch: 221, MSE 0.00239730 (0.00230266), PSNR 26.20277596 (26.40178096), SSIM 0.77502006 (0.76938998)
Finished training epoch 221
Validate: MSE 0.00276957 (0.00233898), PSNR 25.57587242 (26.37475802), SSIM 0.70689201 (0.77024667)
Finished validation.
Starting training epoch 222
Epoch: 222, MS

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00253776 (0.00228655), PSNR 25.95550156 (26.46911406), SSIM 0.71470141 (0.76996229)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()