In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [4]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [5]:
SIZE = 32
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(), 
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE),
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [6]:
# Training
batch_size = 128
train_imagefolder = LabImageFolder('../../datasets/cifar10/train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder('../../datasets/cifar10/val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [7]:
kernel_size=3
stride_en=2
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'


class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        
        self.convtrans1 = nn.ConvTranspose2d(64, 32, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(32, 16, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(16, 8, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans4 = nn.ConvTranspose2d(8, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm8 = nn.BatchNorm2d(8)
        self.batchnorm16 = nn.BatchNorm2d(16)
        self.batchnorm32 = nn.BatchNorm2d(32)
        self.batchnorm64 = nn.BatchNorm2d(64)
        
        
    def forward(self, input):
        # encoder
        x = F.relu(self.batchnorm16(self.conv1(input)))
        x = F.relu(self.batchnorm32(self.conv2(x)))
        x = F.relu(self.batchnorm64(self.conv3(x)))
        
        # decoder
        x = F.relu(self.batchnorm32(self.convtrans1(x)))
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.relu(self.batchnorm16(self.convtrans2(x)))
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.relu(self.batchnorm8(self.convtrans3(x)))
        x = F.interpolate(self.convtrans4(x), scale_factor=scale_factor)

        return x

In [8]:
model = Autoencoder()

In [9]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [11]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [12]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             160
       BatchNorm2d-2           [-1, 16, 16, 16]              32
            Conv2d-3             [-1, 32, 8, 8]           4,640
       BatchNorm2d-4             [-1, 32, 8, 8]              64
            Conv2d-5             [-1, 64, 4, 4]          18,496
       BatchNorm2d-6             [-1, 64, 4, 4]             128
   ConvTranspose2d-7             [-1, 32, 4, 4]          18,464
       BatchNorm2d-8             [-1, 32, 4, 4]              64
   ConvTranspose2d-9             [-1, 16, 8, 8]           4,624
      BatchNorm2d-10             [-1, 16, 8, 8]              32
  ConvTranspose2d-11            [-1, 8, 16, 16]           1,160
      BatchNorm2d-12            [-1, 8, 16, 16]              16
  ConvTranspose2d-13            [-1, 2, 16, 16]             146
Total params: 48,026
Trainable params: 

In [13]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [14]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'

In [15]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:        
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [16]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [17]:
# Make folders and set parameters
checkpoints = 'checkpoints'
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500

In [18]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00236279 (0.00680099), PSNR 26.26574516 (24.10485087), SSIM 0.72902524 (0.68562312)
Finished training epoch 0
Validate: MSE 0.00371635 (0.00331523), PSNR 24.29883194 (24.86133724), SSIM 0.64437044 (0.71164127)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00241660 (0.00290992), PSNR 26.16795731 (25.39315139), SSIM 0.77453399 (0.74808238)
Finished training epoch 1
Validate: MSE 0.00336713 (0.00357604), PSNR 24.72740746 (24.56307506), SSIM 0.67050004 (0.73275071)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00256545 (0.00266102), PSNR 25.90836906 (25.77143605), SSIM 0.76085919 (0.75887578)
Finished training epoch 2
Validate: MSE 0.00326473 (0.00340442), PSNR 24.86153030 (24.79282401), SSIM 0.68858463 (0.75702783)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00306578 (0.00258378), PSNR 25.13458633 (25.90071177), SSIM 0.79098105 (0.76462931)
Finished training epoch 3
Validate: MSE 0.00322338 (0.0

  return func(*args, **kwargs)


Validate: MSE 0.00539325 (0.00577996), PSNR 22.68149185 (22.53053926), SSIM 0.68694866 (0.75575990)
Finished validation.
Starting training epoch 22
Epoch: 22, MSE 0.00218623 (0.00228376), PSNR 26.60304832 (26.43644467), SSIM 0.78189260 (0.77798159)
Finished training epoch 22
Validate: MSE 0.00334331 (0.00288176), PSNR 24.75823212 (25.48557003), SSIM 0.70107937 (0.77347383)
Finished validation.
Starting training epoch 23
Epoch: 23, MSE 0.00220047 (0.00227435), PSNR 26.57483673 (26.45557460), SSIM 0.78832859 (0.77809521)
Finished training epoch 23
Validate: MSE 0.00326587 (0.00299029), PSNR 24.86001778 (25.33719903), SSIM 0.70093608 (0.77050699)
Finished validation.
Starting training epoch 24
Epoch: 24, MSE 0.00206753 (0.00226964), PSNR 26.84548569 (26.46150061), SSIM 0.77935553 (0.77818000)
Finished training epoch 24
Validate: MSE 0.00311289 (0.00290580), PSNR 25.06835747 (25.46080097), SSIM 0.70259941 (0.77287426)
Finished validation.
Starting training epoch 25
Epoch: 25, MSE 0.0023538

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00719636 (0.00754393), PSNR 21.42886925 (21.34859339), SSIM 0.68538350 (0.75898492)
Finished validation.
Starting training epoch 44
Epoch: 44, MSE 0.00229155 (0.00215708), PSNR 26.39870453 (26.68205088), SSIM 0.75372553 (0.77815377)
Finished training epoch 44
Validate: MSE 0.00549154 (0.00571917), PSNR 22.60305786 (22.56938293), SSIM 0.69529319 (0.76549729)
Finished validation.
Starting training epoch 45
Epoch: 45, MSE 0.00228984 (0.00215173), PSNR 26.40195274 (26.69353514), SSIM 0.75701702 (0.77834540)
Finished training epoch 45
Validate: MSE 0.00396648 (0.00418978), PSNR 24.01594543 (23.93545188), SSIM 0.69769418 (0.76765366)
Finished validation.
Starting training epoch 46
Epoch: 46, MSE 0.00227404 (0.00214388), PSNR 26.43201447 (26.70934502), SSIM 0.76225317 (0.77832042)
Finished training epoch 46
Validate: MSE 0.00356435 (0.00344686), PSNR 24.48019981 (24.75548914), SSIM 0.70114660 (0.77267688)
Finished validation.
Starting training epoch 47
Epoch: 47, MSE 0.0021360

  return func(*args, **kwargs)


Validate: MSE 0.00638643 (0.00629643), PSNR 21.94742012 (22.13429651), SSIM 0.68906081 (0.76205627)
Finished validation.
Starting training epoch 50
Epoch: 50, MSE 0.00198959 (0.00212647), PSNR 27.01236153 (26.74384209), SSIM 0.78302133 (0.77803451)
Finished training epoch 50
Validate: MSE 0.00362345 (0.00363648), PSNR 24.40877533 (24.52584274), SSIM 0.69654179 (0.76789108)
Finished validation.
Starting training epoch 51
Epoch: 51, MSE 0.00257611 (0.00212105), PSNR 25.89034843 (26.75484335), SSIM 0.78690833 (0.77782216)
Finished training epoch 51
Validate: MSE 0.00471937 (0.00491374), PSNR 23.26115417 (23.23805527), SSIM 0.69488215 (0.76532084)
Finished validation.
Starting training epoch 52
Epoch: 52, MSE 0.00178501 (0.00212121), PSNR 27.48359108 (26.75249069), SSIM 0.80232966 (0.77792970)
Finished training epoch 52
Validate: MSE 0.00380067 (0.00370226), PSNR 24.20140076 (24.45377776), SSIM 0.69763213 (0.77080658)
Finished validation.
Starting training epoch 53
Epoch: 53, MSE 0.0018567

  return func(*args, **kwargs)


Validate: MSE 0.00586661 (0.00600169), PSNR 22.31612396 (22.35085907), SSIM 0.68838179 (0.75987182)
Finished validation.
Starting training epoch 58
Epoch: 58, MSE 0.00160022 (0.00209206), PSNR 27.95819473 (26.81308200), SSIM 0.80034286 (0.77750997)
Finished training epoch 58
Validate: MSE 0.00415139 (0.00404721), PSNR 23.81806374 (24.06772216), SSIM 0.69572926 (0.76806526)
Finished validation.
Starting training epoch 59
Epoch: 59, MSE 0.00185211 (0.00208993), PSNR 27.32333755 (26.81930899), SSIM 0.77511519 (0.77748398)
Finished training epoch 59


  return func(*args, **kwargs)


Validate: MSE 0.00573436 (0.00595120), PSNR 22.41514969 (22.38647921), SSIM 0.68859506 (0.75835023)
Finished validation.
Starting training epoch 60
Epoch: 60, MSE 0.00203688 (0.00207804), PSNR 26.91035080 (26.84443858), SSIM 0.75245464 (0.77738906)
Finished training epoch 60
Validate: MSE 0.00351142 (0.00352714), PSNR 24.54516792 (24.65951621), SSIM 0.69752800 (0.76728061)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00213769 (0.00208664), PSNR 26.70055771 (26.82651224), SSIM 0.77377582 (0.77729960)
Finished training epoch 61


  return func(*args, **kwargs)


Validate: MSE 0.00574542 (0.00589853), PSNR 22.40678215 (22.42810151), SSIM 0.68531048 (0.75493803)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00205823 (0.00207798), PSNR 26.86505699 (26.84194979), SSIM 0.77591890 (0.77716575)
Finished training epoch 62
Validate: MSE 0.00422173 (0.00412286), PSNR 23.74509239 (23.98308243), SSIM 0.69419396 (0.76716912)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00162526 (0.00207930), PSNR 27.89076233 (26.84303368), SSIM 0.80434704 (0.77726755)
Finished training epoch 63
Validate: MSE 0.00438643 (0.00427268), PSNR 23.57889175 (23.82428365), SSIM 0.69106156 (0.76345552)
Finished validation.
Starting training epoch 64
Epoch: 64, MSE 0.00173384 (0.00206621), PSNR 27.60991478 (26.86729575), SSIM 0.78850764 (0.77728711)
Finished training epoch 64


  return func(*args, **kwargs)


Validate: MSE 0.00616273 (0.00637737), PSNR 22.10226822 (22.08579307), SSIM 0.68492025 (0.75547680)
Finished validation.
Starting training epoch 65
Epoch: 65, MSE 0.00218989 (0.00207374), PSNR 26.59577942 (26.85520898), SSIM 0.78094637 (0.77705600)
Finished training epoch 65
Validate: MSE 0.00505074 (0.00518416), PSNR 22.96644592 (22.99873259), SSIM 0.69078887 (0.76208186)
Finished validation.
Starting training epoch 66
Epoch: 66, MSE 0.00172718 (0.00206171), PSNR 27.62661743 (26.87718764), SSIM 0.78384858 (0.77707422)
Finished training epoch 66
Validate: MSE 0.00397149 (0.00387433), PSNR 24.01046371 (24.25104557), SSIM 0.69457114 (0.76645487)
Finished validation.
Starting training epoch 67
Epoch: 67, MSE 0.00232896 (0.00207010), PSNR 26.32838058 (26.86150698), SSIM 0.76332480 (0.77716718)
Finished training epoch 67
Validate: MSE 0.00498263 (0.00526148), PSNR 23.02541542 (22.93707592), SSIM 0.68847132 (0.75819437)
Finished validation.
Starting training epoch 68
Epoch: 68, MSE 0.0026853

<Figure size 432x288 with 0 Axes>

In [19]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [20]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00329225 (0.00311359), PSNR 24.82506561 (25.17827191), SSIM 0.69621390 (0.76937311)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [21]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()