In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
                transforms.RandomVerticalFlip(),  
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .25

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        self.batchnorm5 = nn.BatchNorm2d(2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = y = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x + y)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm5(self.convtrans3(x + input)), negative_slope=0.1)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 32, 32]           2,306
      BatchNorm2d-10            [-1, 2, 32, 32]               4
Total params: 2,662,278
Trainable params: 2,662,278
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.78
Params size (MB): 10.16
Estima

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [None]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00240498 (0.04317634), PSNR 26.18887901 (19.05293106), SSIM 0.77045929 (0.46797675)
Finished training epoch 0
Validate: MSE 0.00322293 (0.00276196), PSNR 24.91748619 (25.66230228), SSIM 0.70535600 (0.77504690)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00292737 (0.00276169), PSNR 25.33522797 (25.60989477), SSIM 0.78277564 (0.76857419)
Finished training epoch 1
Validate: MSE 0.00321934 (0.00276010), PSNR 24.92233276 (25.66538044), SSIM 0.70484686 (0.77504291)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00301752 (0.00274890), PSNR 25.20349503 (25.63245347), SSIM 0.75473601 (0.76871359)
Finished training epoch 2
Validate: MSE 0.00318627 (0.00275944), PSNR 24.96717834 (25.66791514), SSIM 0.70567769 (0.77488110)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00264109 (0.00275351), PSNR 25.78217506 (25.62361488), SSIM 0.77186835 (0.76867788)
Finished training epoch 3
Validate: MSE 0.00321380 (0.0

Epoch: 30, MSE 0.00281007 (0.00273936), PSNR 25.51282692 (25.64555071), SSIM 0.77044851 (0.76856964)
Finished training epoch 30
Validate: MSE 0.00313762 (0.00273484), PSNR 25.03398895 (25.70867198), SSIM 0.70481086 (0.77458122)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00295271 (0.00273730), PSNR 25.29778671 (25.65305305), SSIM 0.77853453 (0.76853222)
Finished training epoch 31
Validate: MSE 0.00315191 (0.00273315), PSNR 25.01426315 (25.71068351), SSIM 0.70510292 (0.77507565)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00270842 (0.00274213), PSNR 25.67283440 (25.64423948), SSIM 0.74247348 (0.76854827)
Finished training epoch 32
Validate: MSE 0.00314925 (0.00273291), PSNR 25.01792908 (25.71068530), SSIM 0.70531005 (0.77508565)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00232526 (0.00274355), PSNR 26.33528519 (25.63837784), SSIM 0.78053081 (0.76841232)
Finished training epoch 33
Validate: MSE 0.00317317 (0.00273289), PSNR 

Epoch: 60, MSE 0.00297168 (0.00268989), PSNR 25.26997566 (25.72702324), SSIM 0.75224763 (0.76674578)
Finished training epoch 60
Validate: MSE 0.00309299 (0.00268892), PSNR 25.09621239 (25.78072227), SSIM 0.70219994 (0.77459968)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00262585 (0.00270111), PSNR 25.80729294 (25.71079910), SSIM 0.78074300 (0.76644462)
Finished training epoch 61
Validate: MSE 0.00308796 (0.00268344), PSNR 25.10328293 (25.78981423), SSIM 0.70282710 (0.77387372)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00283914 (0.00269144), PSNR 25.46812820 (25.72466160), SSIM 0.76002228 (0.76647216)
Finished training epoch 62
Validate: MSE 0.00307387 (0.00268018), PSNR 25.12314987 (25.79576835), SSIM 0.70047778 (0.77336714)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00263440 (0.00269314), PSNR 25.79317856 (25.72098754), SSIM 0.76939487 (0.76602792)
Finished training epoch 63
Validate: MSE 0.00305296 (0.00267742), PSNR 

Epoch: 90, MSE 0.00212069 (0.00262792), PSNR 26.73522758 (25.83027939), SSIM 0.75244129 (0.76204711)
Finished training epoch 90
Validate: MSE 0.00297231 (0.00261546), PSNR 25.26905441 (25.89762883), SSIM 0.69860983 (0.76969474)
Finished validation.
Starting training epoch 91
Epoch: 91, MSE 0.00245268 (0.00263030), PSNR 26.10359001 (25.82180445), SSIM 0.78178614 (0.76185573)
Finished training epoch 91
Validate: MSE 0.00300592 (0.00261670), PSNR 25.22022057 (25.89440099), SSIM 0.69902039 (0.77049925)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.00248397 (0.00262669), PSNR 26.04852676 (25.82725176), SSIM 0.76913649 (0.76186753)
Finished training epoch 92
Validate: MSE 0.00297270 (0.00261377), PSNR 25.26848412 (25.90038548), SSIM 0.69863158 (0.77004850)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00309690 (0.00262158), PSNR 25.09072113 (25.84054596), SSIM 0.75299066 (0.76163960)
Finished training epoch 93
Validate: MSE 0.00298286 (0.00260854), PSNR 

Epoch: 120, MSE 0.00250773 (0.00259029), PSNR 26.00718307 (25.89199034), SSIM 0.76238126 (0.76030551)
Finished training epoch 120
Validate: MSE 0.00295459 (0.00256696), PSNR 25.29502296 (25.97411193), SSIM 0.69744074 (0.76754290)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 0.00233587 (0.00258613), PSNR 26.31550407 (25.89785608), SSIM 0.76833725 (0.76034294)
Finished training epoch 121
Validate: MSE 0.00294459 (0.00256314), PSNR 25.30974579 (25.97997660), SSIM 0.69513065 (0.76655879)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 0.00214270 (0.00258154), PSNR 26.69038582 (25.90437009), SSIM 0.78162318 (0.76023808)
Finished training epoch 122
Validate: MSE 0.00295699 (0.00256539), PSNR 25.29149818 (25.97622902), SSIM 0.69907522 (0.76748198)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00252631 (0.00258134), PSNR 25.97512817 (25.90523213), SSIM 0.76124084 (0.76043420)
Finished training epoch 123
Validate: MSE 0.00297448 (0.00256

Validate: MSE 0.00292602 (0.00253606), PSNR 25.33722687 (26.02461432), SSIM 0.70032322 (0.76812091)
Finished validation.
Starting training epoch 150
Epoch: 150, MSE 0.00231904 (0.00255509), PSNR 26.34691620 (25.94630918), SSIM 0.77941632 (0.76064570)
Finished training epoch 150
Validate: MSE 0.00292794 (0.00253538), PSNR 25.33438301 (26.02534868), SSIM 0.69933927 (0.76733126)
Finished validation.
Starting training epoch 151
Epoch: 151, MSE 0.00235693 (0.00255252), PSNR 26.27652931 (25.95168586), SSIM 0.77906072 (0.76057795)
Finished training epoch 151
Validate: MSE 0.00294881 (0.00253732), PSNR 25.30353737 (26.02237196), SSIM 0.70091027 (0.76885811)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00314699 (0.00255791), PSNR 25.02103996 (25.94683340), SSIM 0.76118171 (0.76053213)
Finished training epoch 152
Validate: MSE 0.00291669 (0.00253198), PSNR 25.35110092 (26.03136782), SSIM 0.69987667 (0.76804896)
Finished validation.
Starting training epoch 153
Epoch: 153, MS

Epoch: 179, MSE 0.00275069 (0.00253501), PSNR 25.60558891 (25.98275382), SSIM 0.74419439 (0.76123346)
Finished training epoch 179
Validate: MSE 0.00291854 (0.00251712), PSNR 25.34834862 (26.05548582), SSIM 0.70066679 (0.76816227)
Finished validation.
Starting training epoch 180
Epoch: 180, MSE 0.00276355 (0.00253415), PSNR 25.58531952 (25.98477450), SSIM 0.75393879 (0.76100672)
Finished training epoch 180
Validate: MSE 0.00293984 (0.00251597), PSNR 25.31676674 (26.05822194), SSIM 0.70171523 (0.76940753)
Finished validation.
Starting training epoch 181
Epoch: 181, MSE 0.00246316 (0.00253456), PSNR 26.08506584 (25.98354209), SSIM 0.74892461 (0.76104233)
Finished training epoch 181
Validate: MSE 0.00292955 (0.00251726), PSNR 25.33199310 (26.05548942), SSIM 0.70134878 (0.76905311)
Finished validation.
Starting training epoch 182
Epoch: 182, MSE 0.00262537 (0.00253005), PSNR 25.80808830 (25.99012590), SSIM 0.76227748 (0.76130486)
Finished training epoch 182
Validate: MSE 0.00291570 (0.00251

Validate: MSE 0.00292304 (0.00250418), PSNR 25.34165764 (26.07756530), SSIM 0.70362359 (0.76971148)
Finished validation.
Starting training epoch 209
Epoch: 209, MSE 0.00320559 (0.00251906), PSNR 24.94091988 (26.01150963), SSIM 0.75999719 (0.76176162)
Finished training epoch 209
Validate: MSE 0.00286728 (0.00249550), PSNR 25.42529678 (26.09144730), SSIM 0.70211601 (0.76895701)
Finished validation.
Starting training epoch 210
Epoch: 210, MSE 0.00295801 (0.00251634), PSNR 25.28999901 (26.01670397), SSIM 0.75164592 (0.76157753)
Finished training epoch 210
Validate: MSE 0.00289754 (0.00249827), PSNR 25.37970352 (26.08593340), SSIM 0.70169526 (0.76651380)
Finished validation.
Starting training epoch 211
Epoch: 211, MSE 0.00275052 (0.00251731), PSNR 25.60585594 (26.01478927), SSIM 0.75312626 (0.76140676)
Finished training epoch 211
Validate: MSE 0.00292123 (0.00250144), PSNR 25.34434319 (26.08190175), SSIM 0.70387822 (0.76921340)
Finished validation.
Starting training epoch 212
Epoch: 212, MS

Epoch: 238, MSE 0.00270902 (0.00250579), PSNR 25.67187119 (26.03032203), SSIM 0.75776339 (0.76199164)
Finished training epoch 238
Validate: MSE 0.00292011 (0.00248647), PSNR 25.34600830 (26.10664021), SSIM 0.70313716 (0.76852668)
Finished validation.
Starting training epoch 239
Epoch: 239, MSE 0.00198527 (0.00250768), PSNR 27.02181053 (26.02736018), SSIM 0.77484500 (0.76216226)
Finished training epoch 239
Validate: MSE 0.00291355 (0.00248583), PSNR 25.35576820 (26.10874071), SSIM 0.70482910 (0.77004800)
Finished validation.
Starting training epoch 240
Epoch: 240, MSE 0.00287369 (0.00250504), PSNR 25.41559982 (26.03276373), SSIM 0.74296224 (0.76204631)
Finished training epoch 240
Validate: MSE 0.00290022 (0.00248234), PSNR 25.37568474 (26.11439150), SSIM 0.70436811 (0.76952444)
Finished validation.
Starting training epoch 241
Epoch: 241, MSE 0.00224495 (0.00250897), PSNR 26.48793221 (26.02803908), SSIM 0.77415407 (0.76209442)
Finished training epoch 241
Validate: MSE 0.00290184 (0.00248

Validate: MSE 0.00291922 (0.00247732), PSNR 25.34732819 (26.12349247), SSIM 0.70384884 (0.77080181)
Finished validation.
Starting training epoch 268
Epoch: 268, MSE 0.00302474 (0.00250239), PSNR 25.19312286 (26.04051958), SSIM 0.73524094 (0.76238058)
Finished training epoch 268
Validate: MSE 0.00289098 (0.00247420), PSNR 25.38954353 (26.12816875), SSIM 0.70453304 (0.76938108)
Finished validation.
Starting training epoch 269
Epoch: 269, MSE 0.00239211 (0.00249824), PSNR 26.21217918 (26.04937489), SSIM 0.76539284 (0.76228066)
Finished training epoch 269
Validate: MSE 0.00290985 (0.00247498), PSNR 25.36129379 (26.12682411), SSIM 0.70693886 (0.77035485)
Finished validation.
Starting training epoch 270
Epoch: 270, MSE 0.00211491 (0.00249769), PSNR 26.74707985 (26.04991848), SSIM 0.77789748 (0.76252443)
Finished training epoch 270
Validate: MSE 0.00289192 (0.00247190), PSNR 25.38813972 (26.13241175), SSIM 0.70531493 (0.76962984)
Finished validation.
Starting training epoch 271
Epoch: 271, MS

Epoch: 297, MSE 0.00215233 (0.00249466), PSNR 26.67090797 (26.05448774), SSIM 0.77206212 (0.76263633)
Finished training epoch 297
Validate: MSE 0.00291774 (0.00246920), PSNR 25.34953117 (26.13740659), SSIM 0.70669645 (0.77091902)
Finished validation.
Starting training epoch 298
Epoch: 298, MSE 0.00221712 (0.00248922), PSNR 26.54209709 (26.06280499), SSIM 0.75614941 (0.76274835)
Finished training epoch 298
Validate: MSE 0.00287822 (0.00246380), PSNR 25.40876389 (26.14596481), SSIM 0.70773011 (0.77111002)
Finished validation.
Starting training epoch 299
Epoch: 299, MSE 0.00290461 (0.00248571), PSNR 25.36912155 (26.06991932), SSIM 0.76809132 (0.76278669)
Finished training epoch 299
Validate: MSE 0.00286932 (0.00246110), PSNR 25.42221260 (26.15041839), SSIM 0.70320088 (0.76985959)
Finished validation.
Starting training epoch 300
Epoch: 300, MSE 0.00350054 (0.00248938), PSNR 24.55864525 (26.06313621), SSIM 0.73774517 (0.76276386)
Finished training epoch 300
Validate: MSE 0.00289678 (0.00246

Validate: MSE 0.00286857 (0.00245690), PSNR 25.42334366 (26.15900911), SSIM 0.70710856 (0.77213707)
Finished validation.
Starting training epoch 327
Epoch: 327, MSE 0.00250591 (0.00247896), PSNR 26.01033783 (26.08040262), SSIM 0.75334972 (0.76326205)
Finished training epoch 327
Validate: MSE 0.00289987 (0.00245664), PSNR 25.37620544 (26.15857266), SSIM 0.70750594 (0.77084670)
Finished validation.
Starting training epoch 328
Epoch: 328, MSE 0.00275046 (0.00248974), PSNR 25.60594177 (26.06482798), SSIM 0.75605041 (0.76269185)
Finished training epoch 328
Validate: MSE 0.00294717 (0.00245959), PSNR 25.30594635 (26.15396971), SSIM 0.70564079 (0.77162486)
Finished validation.
Starting training epoch 329
Epoch: 329, MSE 0.00259999 (0.00248269), PSNR 25.85027885 (26.07542829), SSIM 0.75359190 (0.76305329)
Finished training epoch 329
Validate: MSE 0.00287916 (0.00245348), PSNR 25.40733528 (26.16369751), SSIM 0.70512652 (0.77033366)
Finished validation.
Starting training epoch 330
Epoch: 330, MS

In [None]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [None]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

In [None]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()