In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
                transforms.RandomVerticalFlip(),  
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .25

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        self.batchnorm5 = nn.BatchNorm2d(2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = y = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x + y)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm5(self.convtrans3(x + input)), negative_slope=0.1)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-5)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 32, 32]           2,306
      BatchNorm2d-10            [-1, 2, 32, 32]               4
Total params: 2,662,278
Trainable params: 2,662,278
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.78
Params size (MB): 10.16
Estima

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [None]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00254552 (0.04366420), PSNR 25.94223404 (18.91462367), SSIM 0.78552955 (0.46850461)
Finished training epoch 0
Validate: MSE 0.00319207 (0.00275314), PSNR 24.95927620 (25.67761057), SSIM 0.70460057 (0.77466265)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00291384 (0.00274353), PSNR 25.35534668 (25.63784571), SSIM 0.76775277 (0.76886805)
Finished training epoch 1
Validate: MSE 0.00321687 (0.00275669), PSNR 24.92565918 (25.67070935), SSIM 0.70384294 (0.77446084)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00273689 (0.00275175), PSNR 25.62742424 (25.62521241), SSIM 0.75058353 (0.76887820)
Finished training epoch 2
Validate: MSE 0.00317916 (0.00275237), PSNR 24.97687340 (25.67922664), SSIM 0.70412993 (0.77444272)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00240915 (0.00275450), PSNR 26.18136215 (25.62277028), SSIM 0.78932685 (0.76849084)
Finished training epoch 3
Validate: MSE 0.00322224 (0.0

Epoch: 30, MSE 0.00232596 (0.00273517), PSNR 26.33396912 (25.65342195), SSIM 0.77255279 (0.76821394)
Finished training epoch 30
Validate: MSE 0.00316906 (0.00273501), PSNR 24.99069023 (25.70598969), SSIM 0.70357949 (0.77500509)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00250596 (0.00273811), PSNR 26.01025772 (25.64927463), SSIM 0.77789319 (0.76821278)
Finished training epoch 31
Validate: MSE 0.00315135 (0.00273325), PSNR 25.01503563 (25.70938372), SSIM 0.70393384 (0.77457994)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00296845 (0.00273159), PSNR 25.27470589 (25.66031618), SSIM 0.78289354 (0.76839941)
Finished training epoch 32
Validate: MSE 0.00313891 (0.00273133), PSNR 25.03220367 (25.71308536), SSIM 0.70293272 (0.77430256)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00307220 (0.00272982), PSNR 25.12549973 (25.65985842), SSIM 0.76594192 (0.76842229)
Finished training epoch 33
Validate: MSE 0.00311887 (0.00272970), PSNR 

Epoch: 60, MSE 0.00203395 (0.00269952), PSNR 26.91658974 (25.70982668), SSIM 0.79973352 (0.76667648)
Finished training epoch 60
Validate: MSE 0.00309020 (0.00269324), PSNR 25.10013962 (25.77363981), SSIM 0.69947982 (0.77301683)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00227584 (0.00269235), PSNR 26.42858124 (25.72049097), SSIM 0.77728498 (0.76648800)
Finished training epoch 61
Validate: MSE 0.00307892 (0.00269037), PSNR 25.11601448 (25.77894646), SSIM 0.69939977 (0.77250473)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00195015 (0.00269769), PSNR 27.09931374 (25.71755284), SSIM 0.77675807 (0.76621500)
Finished training epoch 62
Validate: MSE 0.00310593 (0.00269331), PSNR 25.07808876 (25.77181024), SSIM 0.70174289 (0.77314427)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00245834 (0.00268905), PSNR 26.09357071 (25.72761003), SSIM 0.78640336 (0.76634945)
Finished training epoch 63
Validate: MSE 0.00307353 (0.00268816), PSNR 

Epoch: 90, MSE 0.00258042 (0.00263047), PSNR 25.88309669 (25.82637076), SSIM 0.74952334 (0.76154685)
Finished training epoch 90
Validate: MSE 0.00297898 (0.00261390), PSNR 25.25932503 (25.89830391), SSIM 0.69641924 (0.76723042)
Finished validation.
Starting training epoch 91
Epoch: 91, MSE 0.00285359 (0.00262461), PSNR 25.44608498 (25.83062753), SSIM 0.75158221 (0.76150084)
Finished training epoch 91
Validate: MSE 0.00300965 (0.00261727), PSNR 25.21484184 (25.89276594), SSIM 0.69929779 (0.76948350)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.00239645 (0.00261810), PSNR 26.20431328 (25.84475394), SSIM 0.77296448 (0.76156290)
Finished training epoch 92
Validate: MSE 0.00298928 (0.00261204), PSNR 25.24432945 (25.90135770), SSIM 0.69748318 (0.76910329)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00260970 (0.00261867), PSNR 25.83408928 (25.84495314), SSIM 0.74619567 (0.76152398)
Finished training epoch 93
Validate: MSE 0.00298342 (0.00260817), PSNR 

Epoch: 120, MSE 0.00307103 (0.00257245), PSNR 25.12716293 (25.91995555), SSIM 0.74247932 (0.76025683)
Finished training epoch 120
Validate: MSE 0.00294263 (0.00255738), PSNR 25.31264305 (25.98947494), SSIM 0.70139301 (0.76783691)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 0.00271214 (0.00257253), PSNR 25.66687775 (25.91722702), SSIM 0.76107168 (0.76028338)
Finished training epoch 121
Validate: MSE 0.00295486 (0.00255560), PSNR 25.29462433 (25.99182457), SSIM 0.69866121 (0.76745541)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 0.00260640 (0.00256894), PSNR 25.83959579 (25.92695142), SSIM 0.74596161 (0.76057339)
Finished training epoch 122
Validate: MSE 0.00295444 (0.00255673), PSNR 25.29524994 (25.98982423), SSIM 0.69804561 (0.76675130)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00244788 (0.00256741), PSNR 26.11209488 (25.92518943), SSIM 0.75107962 (0.76060984)
Finished training epoch 123
Validate: MSE 0.00293520 (0.00255

Validate: MSE 0.00293355 (0.00252489), PSNR 25.32606316 (26.04266894), SSIM 0.70283937 (0.76821961)
Finished validation.
Starting training epoch 150
Epoch: 150, MSE 0.00227956 (0.00254371), PSNR 26.42148018 (25.96757582), SSIM 0.77820688 (0.76123010)
Finished training epoch 150
Validate: MSE 0.00293119 (0.00252646), PSNR 25.32955170 (26.03927274), SSIM 0.70078802 (0.76693448)
Finished validation.
Starting training epoch 151
Epoch: 151, MSE 0.00269966 (0.00254569), PSNR 25.68690491 (25.96458373), SSIM 0.73520964 (0.76094424)
Finished training epoch 151
Validate: MSE 0.00290624 (0.00252220), PSNR 25.36668205 (26.04658725), SSIM 0.70196819 (0.76686428)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00228986 (0.00254995), PSNR 26.40191269 (25.95887526), SSIM 0.75857460 (0.76063928)
Finished training epoch 152
Validate: MSE 0.00294212 (0.00252389), PSNR 25.31339264 (26.04381774), SSIM 0.70369154 (0.76831503)
Finished validation.
Starting training epoch 153
Epoch: 153, MS

Epoch: 179, MSE 0.00224326 (0.00252470), PSNR 26.49119949 (26.00070964), SSIM 0.77598059 (0.76175023)
Finished training epoch 179
Validate: MSE 0.00293784 (0.00250437), PSNR 25.31971741 (26.07743433), SSIM 0.70490450 (0.76992771)
Finished validation.
Starting training epoch 180
Epoch: 180, MSE 0.00240447 (0.00252067), PSNR 26.18981171 (26.00750341), SSIM 0.74909580 (0.76186627)
Finished training epoch 180
Validate: MSE 0.00290043 (0.00250136), PSNR 25.37537003 (26.08104949), SSIM 0.70123869 (0.76691961)
Finished validation.
Starting training epoch 181
Epoch: 181, MSE 0.00168967 (0.00251304), PSNR 27.72197723 (26.01865941), SSIM 0.77333701 (0.76190985)
Finished training epoch 181
Validate: MSE 0.00289743 (0.00249829), PSNR 25.37987328 (26.08709381), SSIM 0.70358896 (0.76885261)
Finished validation.
Starting training epoch 182
Epoch: 182, MSE 0.00261840 (0.00252776), PSNR 25.81964302 (25.99752406), SSIM 0.74449319 (0.76178760)
Finished training epoch 182
Validate: MSE 0.00290287 (0.00249

Validate: MSE 0.00290564 (0.00248535), PSNR 25.36758804 (26.10897220), SSIM 0.70558733 (0.76912209)
Finished validation.
Starting training epoch 209
Epoch: 209, MSE 0.00314849 (0.00250787), PSNR 25.01898193 (26.03063097), SSIM 0.75086081 (0.76260689)
Finished training epoch 209
Validate: MSE 0.00290000 (0.00248376), PSNR 25.37601662 (26.11156769), SSIM 0.70310748 (0.76790039)
Finished validation.
Starting training epoch 210
Epoch: 210, MSE 0.00247800 (0.00250599), PSNR 26.05898857 (26.03335720), SSIM 0.74413502 (0.76261777)
Finished training epoch 210
Validate: MSE 0.00289687 (0.00248274), PSNR 25.38070679 (26.11361628), SSIM 0.70412272 (0.76926378)
Finished validation.
Starting training epoch 211
Epoch: 211, MSE 0.00216696 (0.00250774), PSNR 26.64148521 (26.02983153), SSIM 0.77014160 (0.76236582)
Finished training epoch 211
Validate: MSE 0.00292707 (0.00248630), PSNR 25.33566856 (26.10741871), SSIM 0.70395076 (0.76849608)
Finished validation.
Starting training epoch 212
Epoch: 212, MS

Epoch: 238, MSE 0.00239359 (0.00248892), PSNR 26.20949364 (26.06241645), SSIM 0.75059485 (0.76334059)
Finished training epoch 238
Validate: MSE 0.00289981 (0.00247127), PSNR 25.37630272 (26.13357723), SSIM 0.70552433 (0.77016705)
Finished validation.
Starting training epoch 239
Epoch: 239, MSE 0.00250471 (0.00250064), PSNR 26.01243210 (26.04176650), SSIM 0.76577604 (0.76320364)
Finished training epoch 239
Validate: MSE 0.00288600 (0.00247071), PSNR 25.39703178 (26.13451042), SSIM 0.70561337 (0.76947010)
Finished validation.
Starting training epoch 240
Epoch: 240, MSE 0.00209389 (0.00248411), PSNR 26.79046822 (26.06927204), SSIM 0.78078789 (0.76343854)
Finished training epoch 240
Validate: MSE 0.00288532 (0.00246912), PSNR 25.39805794 (26.13608018), SSIM 0.70499218 (0.76904513)
Finished validation.
Starting training epoch 241
Epoch: 241, MSE 0.00292710 (0.00249559), PSNR 25.33562660 (26.05213339), SSIM 0.74175131 (0.76316921)
Finished training epoch 241
Validate: MSE 0.00288739 (0.00246

Validate: MSE 0.00289117 (0.00245910), PSNR 25.38925743 (26.15458541), SSIM 0.70624709 (0.77001078)
Finished validation.
Starting training epoch 268
Epoch: 268, MSE 0.00277150 (0.00249082), PSNR 25.57284355 (26.06165249), SSIM 0.76427066 (0.76388443)
Finished training epoch 268
Validate: MSE 0.00286295 (0.00245748), PSNR 25.43186760 (26.15718820), SSIM 0.70626736 (0.76983993)
Finished validation.
Starting training epoch 269
Epoch: 269, MSE 0.00225048 (0.00248858), PSNR 26.47724915 (26.06387150), SSIM 0.77145660 (0.76375535)
Finished training epoch 269
Validate: MSE 0.00290045 (0.00245996), PSNR 25.37534714 (26.15412917), SSIM 0.70668453 (0.77285398)
Finished validation.
Starting training epoch 270
Epoch: 270, MSE 0.00219575 (0.00248450), PSNR 26.58417130 (26.07062549), SSIM 0.78579611 (0.76412161)
Finished training epoch 270
Validate: MSE 0.00289017 (0.00245526), PSNR 25.39076996 (26.16137472), SSIM 0.70693660 (0.77063797)
Finished validation.
Starting training epoch 271
Epoch: 271, MS

Epoch: 297, MSE 0.00252581 (0.00247403), PSNR 25.97598839 (26.09007082), SSIM 0.75672156 (0.76494134)
Finished training epoch 297
Validate: MSE 0.00286795 (0.00245347), PSNR 25.42428207 (26.16439536), SSIM 0.70564663 (0.77023634)
Finished validation.
Starting training epoch 298
Epoch: 298, MSE 0.00228543 (0.00247437), PSNR 26.41032982 (26.08591477), SSIM 0.76530421 (0.76461310)
Finished training epoch 298
Validate: MSE 0.00287646 (0.00245105), PSNR 25.41141891 (26.16882997), SSIM 0.70676994 (0.77224801)
Finished validation.
Starting training epoch 299
Epoch: 299, MSE 0.00263599 (0.00248689), PSNR 25.79056358 (26.06823887), SSIM 0.73591292 (0.76451866)
Finished training epoch 299
Validate: MSE 0.00285993 (0.00245079), PSNR 25.43643761 (26.16883299), SSIM 0.70709682 (0.77096675)
Finished validation.
Starting training epoch 300
Epoch: 300, MSE 0.00287806 (0.00247936), PSNR 25.40900040 (26.07724472), SSIM 0.76300895 (0.76450635)
Finished training epoch 300
Validate: MSE 0.00287801 (0.00245

Validate: MSE 0.00286913 (0.00243947), PSNR 25.42250252 (26.18922840), SSIM 0.70813078 (0.77167055)
Finished validation.
Starting training epoch 327
Epoch: 327, MSE 0.00240725 (0.00246189), PSNR 26.18478394 (26.11224693), SSIM 0.77298033 (0.76527864)
Finished training epoch 327
Validate: MSE 0.00283404 (0.00243727), PSNR 25.47594643 (26.19262316), SSIM 0.70950848 (0.77212022)
Finished validation.
Starting training epoch 328
Epoch: 328, MSE 0.00252476 (0.00247469), PSNR 25.97778893 (26.08994542), SSIM 0.74957085 (0.76485805)
Finished training epoch 328
Validate: MSE 0.00287016 (0.00244297), PSNR 25.42093849 (26.18313095), SSIM 0.70654476 (0.77162945)
Finished validation.
Starting training epoch 329
Epoch: 329, MSE 0.00254389 (0.00247517), PSNR 25.94500732 (26.08831097), SSIM 0.77260065 (0.76494762)
Finished training epoch 329
Validate: MSE 0.00286892 (0.00243800), PSNR 25.42280960 (26.19263970), SSIM 0.70814735 (0.77320516)
Finished validation.
Starting training epoch 330
Epoch: 330, MS

In [None]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [None]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

In [None]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()