In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
# Load the TensorBoard notebook extension and open a dashboard on the `runs` log directory.
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# Runtime environment switches (the Colab and Kaggle code paths have not been tested yet).
colab, kaggle = False, False
# Experiment identifier; used as the output folder name on Colab/Kaggle.
test_number = '12_2'

In [4]:
# Output locations relative to the working directory; on Colab/Kaggle they are nested
# under a per-experiment folder (named by test_number).
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'

if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'/content/drive/MyDrive/MGU/{test_number}/{sub}'
        for sub in (color_imgs, gray_imgs, checkpoints)
    )
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'{test_number}/{sub}' for sub in (color_imgs, gray_imgs, checkpoints)
    )
else:
    # Local run: dataset lives two levels up, outputs stay relative.
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
# Output folders; exist_ok=True makes re-running this cell harmless.
for _dir in (color_imgs, gray_imgs, checkpoints):
    os.makedirs(_dir, exist_ok=True)

save_images = True  # passed to validate(), which then writes sample colorizations
# Best validation [MSE, PSNR, SSIM] seen so far, and the epoch of the best MSE.
# NOTE(review): 1e10 works as a "not yet set" value only for lower-is-better metrics;
# PSNR and SSIM are higher-is-better.
best_losses = [1e10] * 3
best_epoch = -1
# Training schedule
epochs, patience = 500, 50
batch_size = 128
SIZE = 32  # images are resized to SIZE x SIZE

In [6]:
from torch.utils.tensorboard import SummaryWriter
# With no arguments SummaryWriter logs to ./runs/<datetime>_<hostname> (PyTorch default),
# i.e. inside the `runs` directory watched by the %tensorboard cell.
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    """Image-folder dataset yielding (gray, ab) tensor pairs for colorization.

    Every regular file directly inside `paths` is treated as an image. Each item is
    converted to RGB, transformed (augmented for 'train'), and split into:
      * gray: skimage rgb2gray of the image as a float tensor of shape (1, SIZE, SIZE)
      * ab:   the a*/b* channels of rgb2lab(image), mapped with (v + 128) / 255,
              as a float tensor of shape (2, SIZE, SIZE)

    Args:
        paths: directory containing the image files (a single directory path).
        split: 'train' (random flips, small rotation, colour jitter) or 'val'
            (resize only).

    Raises:
        ValueError: if `split` is neither 'train' nor 'val'.
    """

    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                # Crop size equals the resized size, so this keeps the whole image.
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
            ])
        else:
            # Fail fast: otherwise self.transforms is never set and the error only
            # surfaces later as an AttributeError inside __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        # Sorted so item order (and hence which validation images get saved) is
        # reproducible; os.listdir returns entries in arbitrary order.
        self.paths = sorted(
            os.path.join(paths, name)
            for name in os.listdir(paths)
            if os.path.isfile(os.path.join(paths, name))
        )

    def __getitem__(self, index):
        """Return the (gray, ab) float tensors for the image at `index`."""
        with Image.open(self.paths[index]) as img_file:
            img = img_file.convert("RGB")
        img_original = np.asarray(self.transforms(img))
        # All Lab channels are mapped with (v + 128) / 255; only a*/b* are used below.
        img_lab = (rgb2lab(img_original) + 128) / 255
        img_ab = torch.from_numpy(img_lab[:, :, 1:3].transpose((2, 0, 1))).float()
        # NOTE(review): the network input is skimage's rgb2gray luminance, not the Lab L*
        # channel computed above — confirm this is intended.
        img_gray = torch.from_numpy(rgb2gray(img_original)).unsqueeze(0).float()
        return img_gray, img_ab

    def __len__(self):
        """Number of image files found in the directory."""
        return len(self.paths)

In [9]:
# Training data: random augmentation, reshuffled every epoch
train_imagefolder = LabImageFolder(dataset + 'train', split='train')
train_loader = torch.utils.data.DataLoader(
    train_imagefolder,
    batch_size=batch_size,
    shuffle=True,
)
# Validation data: no augmentation, not shuffled
val_imagefolder = LabImageFolder(dataset + 'val', split='val')
val_loader = torch.utils.data.DataLoader(
    val_imagefolder,
    batch_size=batch_size,
    shuffle=False,
)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .25  # dropout probability applied after the hidden layers

class Autoencoder(nn.Module):
    """Conv encoder/decoder predicting 2 colour channels (ab) from a 1-channel input.

    With C = channels_base and an (N, 1, H, W) input (H, W such that the pooled and
    upsampled shapes line up, e.g. divisible by 4; the torchsummary output shows
    H = W = 32):
      encoder: conv1 -> (N, C, H, W) -> max-pool -> (N, C, H/2, W/2)   [kept as skip `y`]
               conv2 -> (N, 2C, H/2, W/2) -> max-pool -> (N, 2C, H/4, W/4)
      decoder: convtrans1 -> (N, C, H/4, W/4) -> upsample x2 -> + y
               convtrans2 -> (N, C/2, H/2, W/2) -> upsample x2 -> + input (broadcast over channels)
               convtrans3 -> (N, 2, H, W)
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        self.batchnorm5 = nn.BatchNorm2d(2)

    def _dropout(self, x):
        """Dropout with probability p1, active only while the module is in training mode."""
        # F.dropout defaults to training=True, which would keep dropping activations after
        # model.eval(); tying it to self.training makes validation/inference deterministic.
        return F.dropout(x, p=p1, training=self.training)

    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = self._dropout(x)
        x = y = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)  # y: skip connection
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = self._dropout(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = self._dropout(x)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x + y)), negative_slope=0.1)
        x = self._dropout(x)
        x = F.interpolate(x, scale_factor=scale_factor)
        # The 1-channel input is broadcast-added to every decoder channel.
        x = F.leaky_relu(self.batchnorm5(self.convtrans3(x + input)), negative_slope=0.1)

        return x

In [11]:
# Fresh, untrained instance of the colorization autoencoder defined above.
model = Autoencoder()

In [12]:
# Metrics on (predicted ab, target ab); index 0 (MSE) is the loss that gets backpropagated.
criterion = [
    MeanSquaredError(),
    PeakSignalNoiseRatio(data_range=1.0),
    StructuralSimilarityIndexMeasure(data_range=1.0),
]

In [13]:
# SGD with momentum and a small L2 penalty over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-5,
)

In [14]:
# Move the metric modules and the model to CUDA when a GPU is available.
if use_gpu:
    criterion = [metric.to("cuda") for metric in criterion]
    model = model.cuda()

In [15]:
if use_gpu: 
    # NOTE(review): presumably guarded because torchsummary's summary() runs on
    # device='cuda' by default, and the model is only moved to CUDA when use_gpu is True.
    from torchsummary import summary
    # Prints per-layer output shapes and parameter counts for a (1, SIZE, SIZE) input.
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 32, 32]           2,306
      BatchNorm2d-10            [-1, 2, 32, 32]               4
Total params: 2,662,278
Trainable params: 2,662,278
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.78
Params size (MB): 10.16
Estima

In [16]:
class AverageMeter:
    """Running, weight-aware mean of a scalar (pattern from the PyTorch ImageNet example).

    Attributes:
        val:   the most recent value passed to update()
        sum:   sum of value * weight over all updates
        count: total weight seen so far
        avg:   sum / count (0 until the first update)
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` (e.g. a batch mean and its batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Rebuild an RGB image from a gray input plus ab channels and save both images.

    Args:
        grayscale_input: CPU tensor with the gray network input, treated as L*/100.
            Assumed shape (1, H, W) — TODO confirm; it is concatenated with ab_input on dim 0.
        ab_input: CPU tensor of a*/b* channels in the dataset's (v + 128) / 255 scaling.
            Assumed shape (2, H, W).
        save_path: dict {'grayscale': '/path/', 'colorized': '/path/'}; each value is used
            as a plain string prefix, so it should end with a path separator.
        save_name: file name appended to both prefixes.

    Files are written only when both save_path and save_name are given; nothing is displayed.
    '''
    # plt.imsave writes arrays straight to files and does not use pyplot's current figure,
    # so no figure is created or cleared here (a plt.clf() call would only create an empty
    # figure that the inline backend prints as "<Figure ... with 0 Axes>").
    lab = torch.cat((grayscale_input, ab_input), 0).numpy()  # combine channels
    lab = lab.transpose((1, 2, 0))  # channels-last, as skimage expects
    lab[:, :, 0:1] = lab[:, :, 0:1] * 100          # gray (presumably in [0, 1]) -> L* scale
    lab[:, :, 1:3] = lab[:, :, 1:3] * 255 - 128    # undo the dataset's (v + 128) / 255
    color_image = lab2rgb(lab.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None:
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    '''Evaluate `model` on the validation set and return its mean metrics.

    Args:
        val_loader: iterable of (gray, ab) batches.
        model: network mapping `gray` to predicted `ab`; it is switched to eval mode here
            (and left in it).
        criterion: three metric callables [MSE, PSNR, SSIM], each called as
            metric(prediction, target) and returning a scalar tensor.
        save_images: if True, up to 10 predictions from the first batch are written via
            to_rgb (grayscale into `gray_imgs`, colorized into `color_imgs`).
        epoch: used in saved file names and as the TensorBoard step; a negative value
            skips TensorBoard logging.

    Returns:
        Tuple (MSE, PSNR, SSIM), each averaged over batches weighted by batch size.
    '''
    meters = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    images_pending = save_images
    # No gradients are needed here; disabling autograd locally keeps this function safe
    # even when a caller forgets to wrap it in torch.no_grad().
    with torch.no_grad():
        for gray, ab in val_loader:
            if use_gpu:
                gray, ab = gray.cuda(), ab.cuda()

            output_ab = model(gray)
            n_samples = gray.size(0)
            # Metrics are evaluated in order MSE, PSNR, SSIM and weighted by batch size.
            for meter, metric in zip(meters, criterion):
                meter.update(metric(output_ab, ab).item(), n_samples)

            # Save sample colorizations from the first batch only (at most 10 images).
            if images_pending:
                images_pending = False
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                for j in range(min(len(output_ab), 10)):
                    save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                    to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    # `.val` is the last batch's value, `.avg` the running mean over the whole set.
    print(f'Validate: MSE {meters[0].val:.8f} ({meters[0].avg:.8f}), PSNR {meters[1].val:.8f} ({meters[1].avg:.8f}), SSIM {meters[2].val:.8f} ({meters[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        for tag, meter in zip(("MSE/test", "PSNR/test", "SSIM/test"), meters):
            writer.add_scalar(tag, meter.avg, epoch)
    return meters[0].avg, meters[1].avg, meters[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    '''Train `model` for one epoch, minimising criterion[0] (MSE).

    Args:
        train_loader: iterable of (gray, ab) batches.
        model: network mapping `gray` to predicted `ab`; switched to train mode here.
        criterion: three metric callables [MSE, PSNR, SSIM]; only MSE is backpropagated,
            while PSNR and SSIM are only reported.
        optimizer: optimizer over the model's parameters.
        epoch: epoch index used in log messages and as the TensorBoard step; a negative
            value skips TensorBoard logging.
    '''
    print(f'Starting training epoch {epoch}')
    meters = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.train()

    for gray, ab in train_loader:
        if use_gpu:
            gray, ab = gray.cuda(), ab.cuda()

        optimizer.zero_grad()

        output_ab = model(gray)
        mse = criterion[0](output_ab, ab)
        # PSNR/SSIM never feed the backward pass; evaluating them on a detached output with
        # autograd off avoids building and holding a computation graph for them every batch.
        with torch.no_grad():
            prediction = output_ab.detach()
            psnr = criterion[1](prediction, ab)
            ssim = criterion[2](prediction, ab)

        mse.backward()
        optimizer.step()

        n_samples = gray.size(0)
        for meter, value in zip(meters, (mse, psnr, ssim)):
            meter.update(value.item(), n_samples)

    # `.val` is the last batch's value, `.avg` the running mean over the epoch.
    print(f'Epoch: {epoch}, MSE {meters[0].val:.8f} ({meters[0].avg:.8f}), PSNR {meters[1].val:.8f} ({meters[1].avg:.8f}), SSIM {meters[2].val:.8f} ({meters[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        for tag, meter in zip(("MSE/train", "PSNR/train", "SSIM/train"), meters):
            writer.add_scalar(tag, meter.avg, epoch)

In [19]:
# Train model
# PSNR and SSIM are higher-is-better (MSE is lower-is-better), so their running bests
# start at -inf here and improve with '>'. The MSE best and best_epoch carry over unchanged.
best_losses[1] = float('-inf')
best_losses[2] = float('-inf')
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save a checkpoint whenever a metric reaches a new best value.
    # MSE: lower is better; it also drives early stopping through best_epoch.
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    # PSNR: higher is better (the file name keeps the historical "PSNRLoss" label).
    if losses[1] > best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    # SSIM: higher is better (the file name keeps the historical "SSIMLoss" label).
    if losses[2] > best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')

    # Early stop: no MSE improvement for `patience` epochs, but never before epoch 100.
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break

    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00275440 (0.01035134), PSNR 25.59972191 (24.33031112), SSIM 0.74054354 (0.71809805)
Finished training epoch 0
Validate: MSE 0.00312409 (0.00276836), PSNR 25.05276871 (25.65877745), SSIM 0.70640135 (0.77491635)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00262224 (0.00275835), PSNR 25.81326675 (25.61635816), SSIM 0.78478986 (0.76902870)
Finished training epoch 1
Validate: MSE 0.00336505 (0.00280434), PSNR 24.73008919 (25.59472744), SSIM 0.70571125 (0.77598963)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00345118 (0.00275448), PSNR 24.62032700 (25.62154764), SSIM 0.73684347 (0.76923741)
Finished training epoch 2
Validate: MSE 0.00317125 (0.00275612), PSNR 24.98769760 (25.67471057), SSIM 0.70474565 (0.77375406)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00291546 (0.00276316), PSNR 25.35293579 (25.61190337), SSIM 0.76530474 (0.76895246)
Finished training epoch 3
Validate: MSE 0.00319668 (0.0

Epoch: 30, MSE 0.00251411 (0.00259489), PSNR 25.99615288 (25.88563211), SSIM 0.77301586 (0.76105055)
Finished training epoch 30
Validate: MSE 0.00296504 (0.00257298), PSNR 25.27969933 (25.96163705), SSIM 0.69369513 (0.76271602)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00264458 (0.00258991), PSNR 25.77643776 (25.88849471), SSIM 0.73080635 (0.76073460)
Finished training epoch 31
Validate: MSE 0.00299515 (0.00256120), PSNR 25.23580742 (25.98319536), SSIM 0.70114613 (0.76908816)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00243794 (0.00258028), PSNR 26.12976646 (25.90443763), SSIM 0.77313149 (0.76082935)
Finished training epoch 32
Validate: MSE 0.00293109 (0.00255268), PSNR 25.32970047 (25.99655147), SSIM 0.69934225 (0.76620150)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00252816 (0.00257238), PSNR 25.97195816 (25.91900755), SSIM 0.76952839 (0.76106461)
Finished training epoch 33
Validate: MSE 0.00296172 (0.00256323), PSNR 

Epoch: 60, MSE 0.00247877 (0.00250058), PSNR 26.05764198 (26.04142398), SSIM 0.75156569 (0.76649267)
Finished training epoch 60
Validate: MSE 0.00297794 (0.00248840), PSNR 25.26083565 (26.10671004), SSIM 0.70794088 (0.77222864)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00257135 (0.00249839), PSNR 25.89838600 (26.04962130), SSIM 0.77389413 (0.76678761)
Finished training epoch 61
Validate: MSE 0.00293927 (0.00246244), PSNR 25.31759834 (26.14923147), SSIM 0.70795870 (0.77400538)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00275591 (0.00249341), PSNR 25.59735489 (26.05542146), SSIM 0.74002081 (0.76670928)
Finished training epoch 62
Validate: MSE 0.00292141 (0.00246730), PSNR 25.34406662 (26.14020285), SSIM 0.70753974 (0.77213911)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00216505 (0.00248946), PSNR 26.64532852 (26.06447421), SSIM 0.74936843 (0.76703904)
Finished training epoch 63
Validate: MSE 0.00288601 (0.00245102), PSNR 

Epoch: 90, MSE 0.00281569 (0.00245994), PSNR 25.50415802 (26.11239053), SSIM 0.77916187 (0.76941890)
Finished training epoch 90
Validate: MSE 0.00290468 (0.00243963), PSNR 25.36901665 (26.19075820), SSIM 0.71115947 (0.77648558)
Finished validation.
Starting training epoch 91
Epoch: 91, MSE 0.00270411 (0.00245590), PSNR 25.67976189 (26.11709267), SSIM 0.75850546 (0.76954720)
Finished training epoch 91
Validate: MSE 0.00284246 (0.00242745), PSNR 25.46305847 (26.20960449), SSIM 0.71127319 (0.77458825)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.00303261 (0.00245523), PSNR 25.18183327 (26.12172284), SSIM 0.75962603 (0.76943968)
Finished training epoch 92
Validate: MSE 0.00297725 (0.00245891), PSNR 25.26185226 (26.15804457), SSIM 0.71192408 (0.77474099)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00273277 (0.00245810), PSNR 25.63396835 (26.11831312), SSIM 0.77956378 (0.76960699)
Finished training epoch 93
Validate: MSE 0.00298599 (0.00245613), PSNR 

Epoch: 120, MSE 0.00260208 (0.00242921), PSNR 25.84679031 (26.16959542), SSIM 0.77088511 (0.77040175)
Finished training epoch 120
Validate: MSE 0.00295202 (0.00241572), PSNR 25.29880524 (26.23338646), SSIM 0.71251589 (0.77845382)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 0.00237271 (0.00243396), PSNR 26.24754715 (26.16027393), SSIM 0.80134982 (0.77002338)
Finished training epoch 121
Validate: MSE 0.00286982 (0.00240221), PSNR 25.42145920 (26.25604960), SSIM 0.71142673 (0.77781731)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 0.00314168 (0.00242910), PSNR 25.02837944 (26.16906267), SSIM 0.75275654 (0.77029811)
Finished training epoch 122
Validate: MSE 0.00289909 (0.00239615), PSNR 25.37738609 (26.26904369), SSIM 0.71178544 (0.77708331)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00260053 (0.00242776), PSNR 25.84937668 (26.17122047), SSIM 0.76228726 (0.77037025)
Finished training epoch 123
Validate: MSE 0.00284258 (0.00238

Validate: MSE 0.00280102 (0.00236476), PSNR 25.52683640 (26.32185410), SSIM 0.71056688 (0.77699758)
Finished validation.
Starting training epoch 150
Epoch: 150, MSE 0.00203706 (0.00239579), PSNR 26.90997124 (26.23212643), SSIM 0.79697168 (0.77107326)
Finished training epoch 150
Validate: MSE 0.00279885 (0.00236293), PSNR 25.53020859 (26.32818906), SSIM 0.71572793 (0.77919368)
Finished validation.
Starting training epoch 151
Epoch: 151, MSE 0.00231057 (0.00239898), PSNR 26.36281586 (26.22264719), SSIM 0.76535094 (0.77086303)
Finished training epoch 151
Validate: MSE 0.00279761 (0.00237219), PSNR 25.53212547 (26.30887533), SSIM 0.70963526 (0.77756897)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00219778 (0.00239840), PSNR 26.58015251 (26.22674203), SSIM 0.76975423 (0.77107627)
Finished training epoch 152
Validate: MSE 0.00287661 (0.00236993), PSNR 25.41118813 (26.31706415), SSIM 0.71469045 (0.77928361)
Finished validation.
Starting training epoch 153
Epoch: 153, MS

Epoch: 179, MSE 0.00196185 (0.00238878), PSNR 27.07333565 (26.24245798), SSIM 0.79648703 (0.77177401)
Finished training epoch 179
Validate: MSE 0.00295317 (0.00246287), PSNR 25.29712105 (26.15498714), SSIM 0.71365499 (0.77959303)
Finished validation.
Starting training epoch 180
Epoch: 180, MSE 0.00193023 (0.00238798), PSNR 27.14390755 (26.24397744), SSIM 0.79425269 (0.77174931)
Finished training epoch 180
Validate: MSE 0.00283744 (0.00236323), PSNR 25.47072983 (26.32694167), SSIM 0.71474379 (0.77769881)
Finished validation.
Starting training epoch 181
Epoch: 181, MSE 0.00265009 (0.00238135), PSNR 25.76740074 (26.25476388), SSIM 0.76958740 (0.77212279)
Finished training epoch 181
Validate: MSE 0.00276389 (0.00246102), PSNR 25.58478737 (26.14465104), SSIM 0.70570904 (0.77064220)
Finished validation.
Starting training epoch 182
Epoch: 182, MSE 0.00208570 (0.00238421), PSNR 26.80747795 (26.25151053), SSIM 0.77552450 (0.77189403)
Finished training epoch 182
Validate: MSE 0.00277079 (0.00234

Validate: MSE 0.00275814 (0.00233232), PSNR 25.59382820 (26.38103843), SSIM 0.71451813 (0.77684628)
Finished validation.
Starting training epoch 209
Epoch: 209, MSE 0.00247532 (0.00237263), PSNR 26.06368828 (26.27124740), SSIM 0.78531569 (0.77262102)
Finished training epoch 209
Validate: MSE 0.00304765 (0.00255412), PSNR 25.16034698 (26.00203663), SSIM 0.71293998 (0.77955727)
Finished validation.
Starting training epoch 210
Epoch: 210, MSE 0.00177914 (0.00237026), PSNR 27.49789047 (26.27845255), SSIM 0.79878694 (0.77262296)
Finished training epoch 210
Validate: MSE 0.00274480 (0.00234405), PSNR 25.61488533 (26.36181182), SSIM 0.71630645 (0.78026278)
Finished validation.
Starting training epoch 211
Epoch: 211, MSE 0.00248536 (0.00236720), PSNR 26.04611206 (26.27988220), SSIM 0.75684142 (0.77296877)
Finished training epoch 211
Validate: MSE 0.00273845 (0.00235707), PSNR 25.62495232 (26.33548388), SSIM 0.71252412 (0.77538892)
Finished validation.
Starting training epoch 212
Epoch: 212, MS

Epoch: 238, MSE 0.00259989 (0.00235731), PSNR 25.85044479 (26.29936547), SSIM 0.77979434 (0.77351538)
Finished training epoch 238
Validate: MSE 0.00269220 (0.00232321), PSNR 25.69892502 (26.39672751), SSIM 0.71511006 (0.77942504)
Finished validation.
Starting training epoch 239
Epoch: 239, MSE 0.00214281 (0.00235902), PSNR 26.69016266 (26.29504933), SSIM 0.78461570 (0.77351872)
Finished training epoch 239
Validate: MSE 0.00274893 (0.00231255), PSNR 25.60836220 (26.41856383), SSIM 0.71424353 (0.77884263)
Finished validation.
Starting training epoch 240
Epoch: 240, MSE 0.00211457 (0.00236225), PSNR 26.74777031 (26.28990931), SSIM 0.78389460 (0.77369060)
Finished training epoch 240
Validate: MSE 0.00268573 (0.00231833), PSNR 25.70937347 (26.40608436), SSIM 0.71547031 (0.77945035)
Finished validation.
Starting training epoch 241
Epoch: 241, MSE 0.00233837 (0.00234629), PSNR 26.31086922 (26.31725577), SSIM 0.76769942 (0.77372869)
Finished training epoch 241
Validate: MSE 0.00266021 (0.00268

Validate: MSE 0.00273480 (0.00231959), PSNR 25.63074303 (26.40527149), SSIM 0.71537292 (0.78208769)
Finished validation.
Starting training epoch 268
Epoch: 268, MSE 0.00246082 (0.00235234), PSNR 26.08920670 (26.30651785), SSIM 0.77056026 (0.77400883)
Finished training epoch 268
Validate: MSE 0.00267538 (0.00232934), PSNR 25.72614670 (26.38603144), SSIM 0.71404684 (0.77970110)
Finished validation.
Starting training epoch 269
Epoch: 269, MSE 0.00213768 (0.00234800), PSNR 26.70057106 (26.31480417), SSIM 0.75847405 (0.77390045)
Finished training epoch 269
Validate: MSE 0.00277123 (0.00230978), PSNR 25.57327080 (26.42449094), SSIM 0.71702266 (0.78193962)
Finished validation.
Starting training epoch 270
Epoch: 270, MSE 0.00204822 (0.00235143), PSNR 26.88624001 (26.31088354), SSIM 0.78773743 (0.77424774)
Finished training epoch 270
Validate: MSE 0.00279550 (0.00231348), PSNR 25.53540993 (26.41672582), SSIM 0.71664464 (0.78111953)
Finished validation.
Starting training epoch 271
Epoch: 271, MS

Epoch: 297, MSE 0.00257841 (0.00233908), PSNR 25.88648415 (26.33230906), SSIM 0.76083809 (0.77461868)
Finished training epoch 297
Validate: MSE 0.00267436 (0.00232337), PSNR 25.72779846 (26.39731492), SSIM 0.71672571 (0.78153083)
Finished validation.
Starting training epoch 298
Epoch: 298, MSE 0.00252362 (0.00233883), PSNR 25.97975731 (26.33471727), SSIM 0.77717721 (0.77448332)
Finished training epoch 298
Validate: MSE 0.00270590 (0.00234555), PSNR 25.67688751 (26.35680411), SSIM 0.71614701 (0.78164316)
Finished validation.
Starting training epoch 299
Epoch: 299, MSE 0.00245101 (0.00233666), PSNR 26.10655212 (26.33446944), SSIM 0.77350843 (0.77475529)
Finished training epoch 299
Validate: MSE 0.00265675 (0.00232741), PSNR 25.75649261 (26.38729459), SSIM 0.71668816 (0.78215797)
Finished validation.
Starting training epoch 300
Epoch: 300, MSE 0.00275839 (0.00233877), PSNR 25.59344292 (26.33456286), SSIM 0.77094650 (0.77458261)
Finished training epoch 300
Validate: MSE 0.00275409 (0.00231

Validate: MSE 0.00287480 (0.00243196), PSNR 25.41392326 (26.21091240), SSIM 0.71644235 (0.78183874)
Finished validation.
Starting training epoch 327
Epoch: 327, MSE 0.00222768 (0.00232974), PSNR 26.52147865 (26.35095915), SSIM 0.79833084 (0.77495092)
Finished training epoch 327
Validate: MSE 0.00284799 (0.00235991), PSNR 25.45462036 (26.33591992), SSIM 0.71727782 (0.78244120)
Finished validation.
Starting training epoch 328
Epoch: 328, MSE 0.00217568 (0.00232618), PSNR 26.62405014 (26.35519278), SSIM 0.76128972 (0.77528629)
Finished training epoch 328
Validate: MSE 0.00261737 (0.00228736), PSNR 25.82135010 (26.46213240), SSIM 0.71777612 (0.77942050)
Finished validation.
Starting training epoch 329
Epoch: 329, MSE 0.00226856 (0.00232403), PSNR 26.44248581 (26.35707235), SSIM 0.78470123 (0.77499501)
Finished training epoch 329
Validate: MSE 0.00279224 (0.00231382), PSNR 25.54047394 (26.41792106), SSIM 0.71480584 (0.78079860)
Finished validation.
Starting training epoch 330
Epoch: 330, MS

Epoch: 356, MSE 0.00186344 (0.00231846), PSNR 27.29683876 (26.36797190), SSIM 0.78949374 (0.77539962)
Finished training epoch 356
Validate: MSE 0.00270948 (0.00231556), PSNR 25.67114449 (26.41516702), SSIM 0.71337539 (0.78134290)
Finished validation.
Starting training epoch 357
Epoch: 357, MSE 0.00261788 (0.00232541), PSNR 25.82050514 (26.35918398), SSIM 0.76435834 (0.77525523)
Finished training epoch 357
Validate: MSE 0.00287246 (0.00239536), PSNR 25.41746140 (26.27566996), SSIM 0.71743667 (0.78288187)
Finished validation.
Starting training epoch 358
Epoch: 358, MSE 0.00293969 (0.00231635), PSNR 25.31698799 (26.37358156), SSIM 0.77808815 (0.77560664)
Finished training epoch 358
Validate: MSE 0.00266250 (0.00234527), PSNR 25.74710083 (26.35662808), SSIM 0.71702874 (0.78254452)
Finished validation.
Starting training epoch 359
Epoch: 359, MSE 0.00213228 (0.00232254), PSNR 26.71155548 (26.36307474), SSIM 0.78423470 (0.77539572)
Finished training epoch 359
Validate: MSE 0.00276519 (0.00236

Validate: MSE 0.00272961 (0.00228576), PSNR 25.63899422 (26.47395301), SSIM 0.71827322 (0.78310487)
Finished validation.
Starting training epoch 386
Epoch: 386, MSE 0.00269133 (0.00230920), PSNR 25.70032883 (26.38682605), SSIM 0.76719654 (0.77580731)
Finished training epoch 386
Validate: MSE 0.00280734 (0.00230097), PSNR 25.51705360 (26.44581022), SSIM 0.71731383 (0.78558876)
Finished validation.
Starting training epoch 387
Epoch: 387, MSE 0.00219828 (0.00231446), PSNR 26.57916832 (26.37687164), SSIM 0.78606236 (0.77579230)
Finished training epoch 387
Validate: MSE 0.00292675 (0.00251430), PSNR 25.33614540 (26.06999545), SSIM 0.71698374 (0.78241991)
Finished validation.
Starting training epoch 388
Epoch: 388, MSE 0.00227233 (0.00230956), PSNR 26.43528748 (26.38944634), SSIM 0.78317130 (0.77577991)
Finished training epoch 388
Validate: MSE 0.00271856 (0.00227761), PSNR 25.65661621 (26.48156709), SSIM 0.71435946 (0.78146040)
Finished validation.
Starting training epoch 389
Epoch: 389, MS

<Figure size 432x288 with 0 Axes>

In [20]:
# Save the weights as they stand after the training cell; `losses` holds the
# (MSE, PSNR, SSIM) tuple returned by the last validate() call in that loop.
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Final validation pass that also writes sample images; epoch=-1 skips TensorBoard logging.
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images=save_images, epoch=-1)

Validate: MSE 0.00272881 (0.00232502), PSNR 25.64026451 (26.39654459), SSIM 0.72002816 (0.78196699)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()