In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
                transforms.RandomVerticalFlip(),  
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .25

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        self.batchnorm5 = nn.BatchNorm2d(2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = y = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x + y)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm5(self.convtrans3(x + input)), negative_slope=0.1)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 32, 32]           2,306
      BatchNorm2d-10            [-1, 2, 32, 32]               4
Total params: 2,662,278
Trainable params: 2,662,278
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.78
Params size (MB): 10.16
Estima

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[1].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 23.02385712 (8.89927898), PSNR -13.62177944 (-8.12911477), SSIM 0.00098736 (0.03303412)
Finished training epoch 0


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 32.01262283 (20.82948008), PSNR -15.05321217 (-12.85713765), SSIM -0.00284308 (0.00369839)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 73.78147888 (47.21137739), PSNR -18.67947388 (-16.51441091), SSIM -0.00047635 (0.00007291)
Finished training epoch 1


  return func(*args, **kwargs)


Validate: MSE 92.65051270 (72.53884810), PSNR -19.66847610 (-18.49876414), SSIM -0.00104874 (-0.00045369)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 137.39656067 (104.61431704), PSNR -21.37975883 (-20.12775042), SSIM -0.00019352 (-0.00038564)
Finished training epoch 2
Validate: MSE 156.00096130 (155.56190544), PSNR -21.93127251 (-21.87692229), SSIM -0.00038169 (-0.00024906)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 212.82283020 (174.14957977), PSNR -23.28018188 (-22.37492743), SSIM -0.00029718 (-0.00026355)
Finished training epoch 3


  return func(*args, **kwargs)


Validate: MSE 178.52282715 (213.59309883), PSNR -22.51693726 (-23.23048958), SSIM -0.00045007 (-0.00026167)
Finished validation.
Starting training epoch 4
Epoch: 4, MSE 300.72299194 (255.83690397), PSNR -24.78166580 (-24.05815869), SSIM -0.00019486 (-0.00021512)
Finished training epoch 4
Validate: MSE 231.16839600 (271.76889014), PSNR -23.63928223 (-24.27796458), SSIM -0.00031805 (-0.00021835)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 402.25018311 (350.38611107), PSNR -26.04496002 (-25.43031958), SSIM -0.00013904 (-0.00017777)
Finished training epoch 5
Validate: MSE 333.26501465 (413.32687051), PSNR -25.22789574 (-26.09697203), SSIM -0.00028607 (-0.00016432)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 517.80297852 (458.82123953), PSNR -27.14164352 (-26.60490621), SSIM -0.00015086 (-0.00014959)
Finished training epoch 6
Validate: MSE 524.05371094 (666.96765820), PSNR -27.19375801 (-28.17364707), SSIM -0.00017650 (-0.00010591)
Finished validation.
Star

  return func(*args, **kwargs)


Validate: MSE 534.92657471 (632.71387705), PSNR -27.28293991 (-27.98316064), SSIM -0.00017128 (-0.00011420)
Finished validation.
Starting training epoch 9
Epoch: 9, MSE 965.84375000 (880.54380844), PSNR -29.84906769 (-29.44094969), SSIM -0.00008200 (-0.00009230)
Finished training epoch 9
Validate: MSE 717.33825684 (844.76950918), PSNR -28.55723953 (-29.23741585), SSIM -0.00016288 (-0.00009324)
Finished validation.
Starting training epoch 10
Epoch: 10, MSE 1154.35083008 (1058.37622070), PSNR -30.62337685 (-30.24067707), SSIM -0.00007716 (-0.00008064)
Finished training epoch 10
Validate: MSE 1206.03833008 (1490.30111602), PSNR -30.81361008 (-31.68037580), SSIM -0.00010148 (-0.00005940)
Finished validation.
Starting training epoch 11
Epoch: 11, MSE 1364.62719727 (1257.76139789), PSNR -31.35013771 (-30.99090380), SSIM -0.00008247 (-0.00007008)
Finished training epoch 11
Validate: MSE 1071.83410645 (1400.41595645), PSNR -30.30127525 (-31.39222345), SSIM -0.00013030 (-0.00006888)
Finished va

  return func(*args, **kwargs)


Validate: MSE 3444.14404297 (4504.80956797), PSNR -35.37081146 (-36.51287010), SSIM -0.00006079 (-0.00002723)
Finished validation.
Starting training epoch 19
Epoch: 19, MSE 4160.53613281 (3914.69595156), PSNR -36.19149017 (-35.92420878), SSIM -0.00001884 (-0.00003015)
Finished training epoch 19


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 4342.49218750 (5909.21376875), PSNR -36.37738800 (-37.69034122), SSIM -0.00007977 (-0.00002861)
Finished validation.
Starting training epoch 20
Epoch: 20, MSE 4691.87402344 (4422.93607594), PSNR -36.71346283 (-36.45448860), SSIM -0.00002231 (-0.00002827)
Finished training epoch 20
Validate: MSE 5018.75781250 (7759.86276875), PSNR -37.00596237 (-38.78620075), SSIM -0.00004975 (-0.00001696)
Finished validation.
Starting training epoch 21
Epoch: 21, MSE 5274.23144531 (4979.76887781), PSNR -37.22158813 (-36.96962376), SSIM -0.00001804 (-0.00002680)
Finished training epoch 21
Validate: MSE 4119.20361328 (7215.07778828), PSNR -36.14812851 (-38.37758974), SSIM -0.00006316 (-0.00002256)
Finished validation.
Starting training epoch 22
Epoch: 22, MSE 5906.47021484 (5586.96846359), PSNR -37.71327972 (-37.46944266), SSIM -0.00003463 (-0.00002550)
Finished training epoch 22


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 6170.04492188 (8531.31041562), PSNR -37.90288162 (-39.28180711), SSIM -0.00004462 (-0.00002680)
Finished validation.
Starting training epoch 23
Epoch: 23, MSE 6592.20068359 (6245.71564734), PSNR -38.19030380 (-37.95364832), SSIM -0.00002370 (-0.00002454)
Finished training epoch 23
Validate: MSE 4305.01269531 (5580.51118281), PSNR -36.33974457 (-37.41500643), SSIM -0.00005958 (-0.00003031)
Finished validation.
Starting training epoch 24
Epoch: 24, MSE 7328.95605469 (6956.65945719), PSNR -38.65042114 (-38.42197734), SSIM -0.00002475 (-0.00002394)
Finished training epoch 24
Validate: MSE 5466.69238281 (8338.88580781), PSNR -37.37724686 (-39.14953192), SSIM -0.00005598 (-0.00002646)
Finished validation.
Starting training epoch 25
Epoch: 25, MSE 8118.43066406 (7720.14445656), PSNR -39.09471893 (-38.87436150), SSIM -0.00002883 (-0.00002315)
Finished training epoch 25
Validate: MSE 6138.54248047 (8446.02163672), PSNR -37.88064957 (-39.21160928), SSIM -0.00006169 (-0.00002610)
Fi

  return func(*args, **kwargs)


Validate: MSE 25816.51171875 (42989.52081875), PSNR -44.11897278 (-46.24842898), SSIM -0.00004432 (-0.00002176)
Finished validation.
Starting training epoch 45
Epoch: 45, MSE 35211.96875000 (34269.11409000), PSNR -45.46689987 (-45.34848369), SSIM -0.00001018 (-0.00001425)
Finished training epoch 45
Validate: MSE 14322.63964844 (25418.39866094), PSNR -41.56023026 (-43.79505306), SSIM -0.00003328 (-0.00001698)
Finished validation.
Starting training epoch 46
Epoch: 46, MSE 37148.61328125 (36176.60624125), PSNR -45.69942474 (-45.58375894), SSIM -0.00001492 (-0.00001475)
Finished training epoch 46
Validate: MSE 19461.38085938 (28614.49847187), PSNR -42.89173508 (-44.37972097), SSIM -0.00005911 (-0.00001543)
Finished validation.
Starting training epoch 47
Epoch: 47, MSE 39139.80468750 (38140.65603750), PSNR -45.92618179 (-45.81338669), SSIM -0.00001336 (-0.00001530)
Finished training epoch 47


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 29943.91796875 (59437.45729375), PSNR -44.76308441 (-47.60453383), SSIM -0.00002410 (-0.00001481)
Finished validation.
Starting training epoch 48
Epoch: 48, MSE 41188.75781250 (40161.31467250), PSNR -46.14778519 (-46.03760712), SSIM -0.00001622 (-0.00001426)
Finished training epoch 48
Validate: MSE 38207.16406250 (74231.94301250), PSNR -45.82144547 (-48.47536246), SSIM -0.00002727 (-0.00000804)
Finished validation.
Starting training epoch 49
Epoch: 49, MSE 43294.26171875 (42238.86595875), PSNR -46.36430359 (-46.25667154), SSIM -0.00000971 (-0.00001237)
Finished training epoch 49
Validate: MSE 40541.98437500 (63734.93152500), PSNR -46.07904816 (-47.93179406), SSIM -0.00002425 (-0.00001222)
Finished validation.
Starting training epoch 50
Epoch: 50, MSE 45457.43750000 (44373.31335000), PSNR -46.57604980 (-46.47078732), SSIM -0.00001458 (-0.00001301)
Finished training epoch 50
Validate: MSE 37400.67968750 (65479.85448750), PSNR -45.72879028 (-47.98238247), SSIM -0.00002989 (-

  return func(*args, **kwargs)


Validate: MSE 38075.66796875 (69829.28811875), PSNR -45.80647278 (-48.08181130), SSIM -0.00003188 (-0.00001105)
Finished validation.
Starting training epoch 60
Epoch: 60, MSE 70240.49218750 (68869.66236750), PSNR -48.46587372 (-48.37999189), SSIM -0.00001111 (-0.00000954)
Finished training epoch 60
Validate: MSE 49823.57031250 (74411.05021250), PSNR -46.97434616 (-48.60300798), SSIM -0.00001973 (-0.00000717)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 73034.96093750 (71636.47247750), PSNR -48.63530350 (-48.55106480), SSIM -0.00001001 (-0.00000959)
Finished training epoch 61
Validate: MSE 41361.32421875 (63198.60886875), PSNR -46.16593933 (-47.87153313), SSIM -0.00001841 (-0.00000956)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 75887.58593750 (74460.93365750), PSNR -48.80170441 (-48.71901743), SSIM -0.00000628 (-0.00000890)
Finished training epoch 62
Validate: MSE 52978.81640625 (95274.59885625), PSNR -47.24102020 (-49.35880878), SSIM -0.00000338 (-

Epoch: 87, MSE 166214.10937500 (164053.20941500), PSNR -52.20667648 (-52.14971967), SSIM -0.00000500 (-0.00000459)
Finished training epoch 87
Validate: MSE 114246.01562500 (191235.81002500), PSNR -50.57840729 (-52.70042692), SSIM -0.00001147 (-0.00000400)
Finished validation.
Starting training epoch 88
Epoch: 88, MSE 170595.51562500 (168404.16410500), PSNR -52.31967163 (-52.26340379), SSIM -0.00000572 (-0.00000453)
Finished training epoch 88
Validate: MSE 99223.23437500 (147662.57247500), PSNR -49.96613312 (-51.52836100), SSIM -0.00001750 (-0.00000365)
Finished validation.
Starting training epoch 89
Epoch: 89, MSE 175030.93750000 (172812.26030000), PSNR -52.43114471 (-52.37562423), SSIM -0.00000315 (-0.00000422)
Finished training epoch 89
Validate: MSE 139464.71875000 (255275.49155000), PSNR -51.44464111 (-53.93522886), SSIM -0.00001126 (-0.00000387)
Finished validation.
Starting training epoch 90
Epoch: 90, MSE 179526.35937500 (177277.37781500), PSNR -52.54128265 (-52.48641489), SSIM 

Epoch: 114, MSE 305199.06250000 (302234.58778000), PSNR -54.84582901 (-54.80336971), SSIM -0.00000177 (-0.00000254)
Finished training epoch 114
Validate: MSE 169461.84375000 (291129.72795000), PSNR -52.29071808 (-54.44609093), SSIM -0.00000860 (-0.00000220)
Finished validation.
Starting training epoch 115
Epoch: 115, MSE 311185.40625000 (308187.81721000), PSNR -54.93018723 (-54.88808431), SSIM -0.00000214 (-0.00000233)
Finished training epoch 115
Validate: MSE 185902.53125000 (340602.47925000), PSNR -52.69284821 (-55.05494796), SSIM -0.00000473 (-0.00000200)
Finished validation.
Starting training epoch 116
Epoch: 116, MSE 317220.40625000 (314200.17561000), PSNR -55.01360703 (-54.97199508), SSIM -0.00000184 (-0.00000245)
Finished training epoch 116
Validate: MSE 192898.26562500 (425151.73522500), PSNR -52.85327911 (-55.94597560), SSIM -0.00000189 (-0.00000119)
Finished validation.
Starting training epoch 117
Epoch: 117, MSE 323321.00000000 (320270.60128000), PSNR -55.09634018 (-55.05510

  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 248490.15625000 (412609.10525000), PSNR -53.95309067 (-55.98725346), SSIM -0.00000681 (-0.00000243)
Finished validation.
Starting training epoch 120
Epoch: 120, MSE 341966.93750000 (338830.71622000), PSNR -55.33983994 (-55.29976345), SSIM -0.00000224 (-0.00000228)
Finished training epoch 120
Validate: MSE 248274.15625000 (336138.91545000), PSNR -53.94931412 (-55.16320850), SSIM -0.00000675 (-0.00000192)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 348314.68750000 (345138.60230000), PSNR -55.41971588 (-55.37987198), SSIM -0.00000130 (-0.00000229)
Finished training epoch 121
Validate: MSE 186793.29687500 (310004.45397500), PSNR -52.71361160 (-54.59548397), SSIM -0.00000352 (-0.00000186)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 354706.03125000 (351509.93757000), PSNR -55.49868393 (-55.45931384), SSIM -0.00000331 (-0.00000219)
Finished training epoch 122
Validate: MSE 229636.90625000 (329263.37045000), PSNR -53.61041260 (-54.9323840

KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>

In [None]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [None]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

In [None]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()