In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
# Load the TensorBoard notebook extension and embed a viewer over runs/,
# which is where a SummaryWriter() created without a log_dir writes its event files.
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# Runtime environment switches (the Colab and Kaggle branches have not been tested yet).
colab, kaggle = False, False
# Identifier of this experiment; the Colab/Kaggle branches nest their outputs under it.
test_number = '12_2'

In [4]:
# Output locations, relative to the working directory unless an environment branch
# below nests them under a per-run folder named after test_number.
color_imgs, gray_imgs, checkpoints = 'outputs/color/', 'outputs/gray/', 'checkpoints'

if colab:
    # Google Drive holds both the dataset and the per-run outputs.
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    color_imgs, gray_imgs, checkpoints = [
        f'/content/drive/MyDrive/MGU/{test_number}/{sub}'
        for sub in (color_imgs, gray_imgs, checkpoints)
    ]
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    color_imgs, gray_imgs, checkpoints = [
        f'{test_number}/{sub}' for sub in (color_imgs, gray_imgs, checkpoints)
    ]
else:
    # NOTE(review): locally the outputs are not nested under test_number, so saved
    # sample images from different runs overwrite each other.
    dataset = '../../datasets/cifar10/'

In [5]:
# Create the output folders (exist_ok makes re-running this cell safe).
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)

# Training schedule and input geometry.
epochs = 500
batch_size = 128
SIZE = 32       # images are resized to SIZE x SIZE by the dataset transforms
patience = 50   # epochs without a new best validation MSE before early stopping

# Checkpoint bookkeeping: best validation (MSE, PSNR, SSIM) seen so far.
# NOTE(review): PSNR and SSIM are higher-is-better; a 1e10 starting value only makes
# sense for lower-is-better comparisons.
save_images = True
best_losses = [1e10] * 3
best_epoch = -1

In [6]:
# TensorBoard logger; with log_dir=None PyTorch picks a fresh time-stamped
# sub-directory of runs/ for this session.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir=None)

In [7]:
# Decide once whether the model, metrics and batches are moved to CUDA.
use_gpu = torch.cuda.is_available()
print(f'{use_gpu}')

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    '''Flat image-folder dataset yielding ``(gray, ab)`` tensor pairs for colorization.

    Every regular file directly inside ``paths`` is loaded as RGB, resized to
    ``SIZE x SIZE`` and split into a grayscale input and Lab a/b targets.

    Args:
        paths: directory containing the image files (sub-directories are ignored).
        split: ``'train'`` (adds random horizontal flips) or ``'val'`` (deterministic).

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    '''

    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                # Identity after resizing to exactly SIZE; kept as a hook for padded crops.
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                # Deterministic crop for evaluation (identity at this size).
                transforms.CenterCrop(SIZE),
            ])
        else:
            # Fail here instead of with an AttributeError on the first __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        # os.listdir order is filesystem-dependent; sort so index -> file is reproducible.
        self.paths = sorted(
            os.path.join(paths, name)
            for name in os.listdir(paths)
            if os.path.isfile(os.path.join(paths, name))
        )

    def __getitem__(self, index):
        '''Return ``(gray, ab)`` for the image at ``index``.

        gray: 1 x SIZE x SIZE float tensor from skimage ``rgb2gray``.
        ab:   2 x SIZE x SIZE float tensor of Lab a/b rescaled as (ab + 128) / 255.
        '''
        with Image.open(self.paths[index]) as img:
            img_rgb = np.asarray(self.transforms(img.convert("RGB")))

        # Only the a/b channels are targets; the Lab L channel is not used.
        img_ab = (rgb2lab(img_rgb)[:, :, 1:3] + 128) / 255
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()

        # NOTE(review): the network input is rgb2gray luminance, while the colorized
        # reconstruction multiplies it by 100 and uses it as Lab L — the two differ
        # slightly; confirm this approximation is intended.
        img_gray = torch.from_numpy(rgb2gray(img_rgb)).unsqueeze(0).float()
        return img_gray, img_ab

    def __len__(self):
        return len(self.paths)

In [9]:
# Training split: random horizontal flips, reshuffled every epoch.
train_imagefolder = LabImageFolder(f'{dataset}train', split='train')
train_loader = torch.utils.data.DataLoader(
    train_imagefolder, batch_size=batch_size, shuffle=True
)

# Validation split: fixed order, so the first batch (used for saved samples) is the
# same every epoch.
val_imagefolder = LabImageFolder(f'{dataset}val', split='val')
val_loader = torch.utils.data.DataLoader(
    val_imagefolder, batch_size=batch_size, shuffle=False
)

In [10]:
# Shared convolution hyper-parameters for the encoder and decoder below.
kernel_size = 3
stride_en = 1        # encoder conv stride (down-sampling is done by AvgPool2d)
stride_de = 1        # decoder conv stride (up-sampling is done by nn.Upsample)
padding = 1          # with kernel_size=3 and stride 1 the spatial size is preserved
scale_factor = 2     # nn.Upsample factor in the decoder
padding_mode = 'zeros'


class Autoencoder(nn.Module):
    '''Convolutional autoencoder mapping a 1-channel grayscale image to 2 ab channels.

    The encoder runs three Conv -> BatchNorm -> LeakyReLU -> AvgPool stages
    (1 -> 16 -> 32 -> 64 channels), halving the spatial size each time.
    The decoder mirrors it with transposed convolutions and nearest-neighbour
    upsampling back to the input resolution.
    '''

    def __init__(self, input_size=128):
        super().__init__()
        # NOTE(review): ``input_size`` is accepted but never used.
        # Module creation order (and therefore the Sequential indices that appear in
        # state_dict keys) is fixed by the order of the entries below.
        self.encoder = nn.Sequential(
            *self._encoder_stage(1, 16),
            *self._encoder_stage(16, 32),
            *self._encoder_stage(32, 64),
        )
        self.decoder = nn.Sequential(
            *self._decoder_stage(64, 32),
            nn.Upsample(scale_factor=scale_factor),
            *self._decoder_stage(32, 16),
            nn.Upsample(scale_factor=scale_factor),
            *self._decoder_stage(16, 8),
            nn.ConvTranspose2d(8, 2, kernel_size=kernel_size, stride=stride_de,
                               padding=padding, padding_mode=padding_mode),
            # NOTE(review): this last nearest-neighbour upsample comes after the final
            # conv, so predicted ab values are constant over 2x2 pixel blocks.
            nn.Upsample(scale_factor=scale_factor),
        )

    @staticmethod
    def _encoder_stage(in_channels, out_channels):
        '''Conv -> BatchNorm -> LeakyReLU(0.1) -> AvgPool(3, stride 2, pad 1).'''
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride_en,
                      padding=padding, padding_mode=padding_mode),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
        ]

    @staticmethod
    def _decoder_stage(in_channels, out_channels):
        '''ConvTranspose -> BatchNorm -> LeakyReLU(0.1).'''
        return [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride_de,
                               padding=padding, padding_mode=padding_mode),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1),
        ]

    def forward(self, input):
        # Compress the grayscale image, then decode it back up as two ab channels.
        return self.decoder(self.encoder(input))

In [11]:
model: nn.Module = Autoencoder()  # fresh, randomly initialised colorization network

In [12]:
# Metrics in the fixed order (MSE, PSNR, SSIM); index 0 is the training loss.
# data_range=1.0 because the ab targets are rescaled to roughly [0, 1].
criterion = [
    MeanSquaredError(),
    PeakSignalNoiseRatio(data_range=1.0),
    StructuralSimilarityIndexMeasure(data_range=1.0),
]

In [13]:
# Adam over all model parameters.
# NOTE(review): lr=1e-2 is ten times Adam's default of 1e-3 — confirm this is intended.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [14]:
# Move the model and the metric modules to the GPU when one is available.
if use_gpu:
    model = model.cuda()
    criterion = [metric.to("cuda") for metric in criterion]

In [15]:
# Print a layer-by-layer summary for a single-channel SIZE x SIZE input.
from torchsummary import summary

# torchsummary runs a forward pass on random data; doing it in eval mode keeps that
# pass from updating the BatchNorm running statistics before training starts.
was_training = model.training
model.eval()
# Pass the device explicitly so the summary also works on CPU-only machines
# (torchsummary defaults to "cuda").
summary(model, (1, SIZE, SIZE), device="cuda" if use_gpu else "cpu")
model.train(was_training)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             160
       BatchNorm2d-2           [-1, 16, 32, 32]              32
         LeakyReLU-3           [-1, 16, 32, 32]               0
         AvgPool2d-4           [-1, 16, 16, 16]               0
            Conv2d-5           [-1, 32, 16, 16]           4,640
       BatchNorm2d-6           [-1, 32, 16, 16]              64
         LeakyReLU-7           [-1, 32, 16, 16]               0
         AvgPool2d-8             [-1, 32, 8, 8]               0
            Conv2d-9             [-1, 64, 8, 8]          18,496
      BatchNorm2d-10             [-1, 64, 8, 8]             128
        LeakyReLU-11             [-1, 64, 8, 8]               0
        AvgPool2d-12             [-1, 64, 4, 4]               0
  ConvTranspose2d-13             [-1, 32, 4, 4]          18,464
      BatchNorm2d-14             [-1, 3

In [16]:
class AverageMeter(object):
    """Track the most recent value and a running, count-weighted mean.

    Adapted from the PyTorch ImageNet example.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and mean."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Save a grayscale image and its colorized RGB version built from ab channels.

    Args:
        grayscale_input: 1 x H x W tensor whose values are treated as Lab L / 100.
        ab_input: 2 x H x W tensor of a/b channels normalised as (ab + 128) / 255.
        save_path: dict of the form {'grayscale': '/path/', 'colorized': '/path/'};
            the file name is appended by plain concatenation, so each path needs a
            trailing separator.
        save_name: file name (including extension) written in both directories.

    Nothing is computed or written unless both ``save_path`` and ``save_name`` are given.
    '''
    if save_path is None or save_name is None:
        return

    gray = grayscale_input.detach().cpu()
    # Network outputs are unbounded; clamp to the normalised range so the rescaled
    # a/b values stay within [-128, 127] before the Lab -> RGB conversion.
    ab = ab_input.detach().cpu().clamp(0.0, 1.0)

    lab = torch.cat((gray, ab), 0).numpy().transpose((1, 2, 0)).astype(np.float64)
    lab[:, :, 0:1] *= 100
    lab[:, :, 1:3] = lab[:, :, 1:3] * 255 - 128
    color_image = lab2rgb(lab)

    # Fix the colour limits so the grayscale file keeps absolute intensities instead of
    # being min/max-stretched per image by imsave's default autoscaling.
    plt.imsave(arr=gray.squeeze().numpy(), fname='{}{}'.format(save_path['grayscale'], save_name),
               cmap='gray', vmin=0.0, vmax=1.0)
    plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch, max_saved_images=10):
    '''Evaluate ``model`` on ``val_loader`` and return the mean (MSE, PSNR, SSIM).

    Args:
        val_loader: iterable of (gray, ab) batches.
        model: network mapping a gray batch to predicted ab channels.
        criterion: sequence of metric callables used in the order (MSE, PSNR, SSIM).
        save_images: when true, the first ``max_saved_images`` predictions of the first
            batch are written to ``gray_imgs`` / ``color_imgs`` via ``to_rgb``.
        epoch: epoch number used in saved file names; values < 0 skip TensorBoard logging.
        max_saved_images: upper bound on the number of saved image pairs (default 10).

    Returns:
        Tuple of batch-size-weighted averages (MSE, PSNR, SSIM) over the loader.
    '''
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}

    # Only the per-batch values returned by each metric call are used. The accumulated
    # state is never read, so clearing it stops it from growing across calls.
    for metric in criterion:
        if hasattr(metric, 'reset'):
            metric.reset()

    model.eval()
    already_saved_images = False
    # Disable autograd here too so the function is safe to call on its own; nesting
    # inside the caller's torch.no_grad() is harmless.
    with torch.no_grad():
        for gray, ab in val_loader:
            if use_gpu:
                gray, ab = gray.cuda(), ab.cuda()

            output_ab = model(gray)
            n = gray.size(0)
            for meter, metric in zip(_loss, criterion):
                meter.update(metric(output_ab, ab).item(), n)

            # Save a handful of colorized samples from the first batch only.
            if save_images and not already_saved_images:
                already_saved_images = True
                for j in range(min(len(output_ab), max_saved_images)):
                    save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                    to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(),
                           save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    '''Run one training epoch, optimising MSE and logging PSNR/SSIM alongside it.

    Args:
        train_loader: iterable of (gray, ab) batches.
        model: network mapping a gray batch to predicted ab channels.
        criterion: sequence of metric callables in the order (MSE, PSNR, SSIM);
            only the first one is back-propagated.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index for log messages; values < 0 skip TensorBoard logging.
    '''
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    mse_metric, psnr_metric, ssim_metric = criterion[0], criterion[1], criterion[2]

    # Only the per-batch values returned by each metric call are used. The accumulated
    # state is never read, so clearing it stops it from growing across epochs.
    for metric in (mse_metric, psnr_metric, ssim_metric):
        if hasattr(metric, 'reset'):
            metric.reset()

    model.train()
    for gray, ab in train_loader:
        if use_gpu:
            gray, ab = gray.cuda(), ab.cuda()

        optimizer.zero_grad()
        output_ab = model(gray)
        mse_loss = mse_metric(output_ab, ab)
        # PSNR/SSIM are for monitoring only. Feed them a detached prediction so no
        # autograd graph is built (or kept in metric state) for them.
        prediction = output_ab.detach()
        psnr_value = psnr_metric(prediction, ab)
        ssim_value = ssim_metric(prediction, ab)

        mse_loss.backward()
        optimizer.step()

        n = gray.size(0)
        _loss[0].update(mse_loss.item(), n)
        _loss[1].update(psnr_value.item(), n)
        _loss[2].update(ssim_value.item(), n)

    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train the model, validating after every epoch.
# MSE is lower-is-better while PSNR and SSIM are higher-is-better, so each tracker
# starts at the worst value for its own direction. They are initialised here so the
# comparisons below always agree with the starting values.
best_losses = [float('inf'), float('-inf'), float('-inf')]
best_epoch = -1    # epoch with the best validation MSE; drives early stopping
min_epochs = 100   # early stopping is only considered from this epoch on
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save a checkpoint whenever a metric reaches a new best value
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] > best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] > best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')

    # Early stop once validation MSE has not improved for `patience` epochs.
    if epoch - best_epoch >= patience and epoch >= min_epochs:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break

    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00296084 (0.00825625), PSNR 25.28585434 (23.81237645), SSIM 0.73405820 (0.67882486)
Finished training epoch 0
Validate: MSE 0.00332757 (0.00284765), PSNR 24.77873039 (25.52045190), SSIM 0.66391182 (0.74147415)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00271226 (0.00294773), PSNR 25.66669083 (25.33627408), SSIM 0.77718747 (0.75051177)
Finished training epoch 1
Validate: MSE 0.00317363 (0.00275268), PSNR 24.98444176 (25.67277425), SSIM 0.68635416 (0.76352518)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00266115 (0.00266243), PSNR 25.74930000 (25.76880881), SSIM 0.74878037 (0.76123161)
Finished training epoch 2
Validate: MSE 0.00360995 (0.00300064), PSNR 24.42499161 (25.30243275), SSIM 0.68744993 (0.75967747)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00298071 (0.00255864), PSNR 25.25680733 (25.94366337), SSIM 0.75364286 (0.76753235)
Finished training epoch 3
Validate: MSE 0.00276107 (0.0

Epoch: 30, MSE 0.00185586 (0.00225477), PSNR 27.31453896 (26.48948592), SSIM 0.78928095 (0.78011699)
Finished training epoch 30
Validate: MSE 0.00249906 (0.00230033), PSNR 26.02223206 (26.44225679), SSIM 0.71522522 (0.77784292)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00184984 (0.00224545), PSNR 27.32866669 (26.51050891), SSIM 0.79980958 (0.78018924)
Finished training epoch 31
Validate: MSE 0.00283671 (0.00258375), PSNR 25.47185135 (25.95369193), SSIM 0.71006012 (0.77790182)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00229982 (0.00224074), PSNR 26.38306427 (26.51813263), SSIM 0.78756160 (0.78029377)
Finished training epoch 32
Validate: MSE 0.00282430 (0.00235888), PSNR 25.49089622 (26.34471139), SSIM 0.71435285 (0.78001094)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00255532 (0.00225127), PSNR 25.92555428 (26.49981421), SSIM 0.78530824 (0.78026395)
Finished training epoch 33
Validate: MSE 0.00268328 (0.00235331), PSNR 

Epoch: 60, MSE 0.00199126 (0.00214242), PSNR 27.00871277 (26.71197721), SSIM 0.78835088 (0.78134325)
Finished training epoch 60
Validate: MSE 0.00249531 (0.00241990), PSNR 26.02875328 (26.23076910), SSIM 0.71699965 (0.77893768)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00275425 (0.00215318), PSNR 25.59997368 (26.68971414), SSIM 0.77074927 (0.78132859)
Finished training epoch 61
Validate: MSE 0.00309235 (0.00255375), PSNR 25.09710693 (26.00715815), SSIM 0.71646142 (0.77828448)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00168453 (0.00214152), PSNR 27.73520470 (26.71777319), SSIM 0.79779220 (0.78127520)
Finished training epoch 62
Validate: MSE 0.00244308 (0.00223423), PSNR 26.12062073 (26.57179214), SSIM 0.71805209 (0.78003994)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00222488 (0.00213657), PSNR 26.52693176 (26.72346375), SSIM 0.78846222 (0.78130396)
Finished training epoch 63
Validate: MSE 0.00254721 (0.00222441), PSNR 

Epoch: 90, MSE 0.00256561 (0.00208824), PSNR 25.90808678 (26.82482353), SSIM 0.75310552 (0.78134754)
Finished training epoch 90
Validate: MSE 0.00256608 (0.00235096), PSNR 25.90729523 (26.36312340), SSIM 0.71627796 (0.77940801)
Finished validation.
Starting training epoch 91
Epoch: 91, MSE 0.00202421 (0.00208924), PSNR 26.93743515 (26.82249799), SSIM 0.78944772 (0.78132635)
Finished training epoch 91
Validate: MSE 0.00248336 (0.00222999), PSNR 26.04959869 (26.58238231), SSIM 0.71609825 (0.77930284)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.00234721 (0.00208489), PSNR 26.29447746 (26.83334625), SSIM 0.75235814 (0.78130506)
Finished training epoch 92
Validate: MSE 0.00239823 (0.00223139), PSNR 26.20108986 (26.57851457), SSIM 0.71925664 (0.77970770)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00169652 (0.00208898), PSNR 27.70440292 (26.82132918), SSIM 0.78885907 (0.78122177)
Finished training epoch 93
Validate: MSE 0.00254305 (0.00224564), PSNR 

Epoch: 120, MSE 0.00206008 (0.00205104), PSNR 26.86115646 (26.90263816), SSIM 0.77954906 (0.78131828)
Finished training epoch 120
Validate: MSE 0.00237084 (0.00232007), PSNR 26.25098038 (26.40654547), SSIM 0.72141171 (0.77489118)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 0.00238784 (0.00204979), PSNR 26.21995544 (26.90751774), SSIM 0.77893841 (0.78126152)
Finished training epoch 121
Validate: MSE 0.00246398 (0.00218330), PSNR 26.08361816 (26.67220645), SSIM 0.71837270 (0.78012045)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 0.00210787 (0.00205029), PSNR 26.76156235 (26.90206372), SSIM 0.75561535 (0.78127256)
Finished training epoch 122
Validate: MSE 0.00260861 (0.00226637), PSNR 25.83590126 (26.51685197), SSIM 0.71670026 (0.77731699)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00241484 (0.00204870), PSNR 26.17110634 (26.90689256), SSIM 0.76342559 (0.78117846)
Finished training epoch 123
Validate: MSE 0.00253062 (0.00225

Validate: MSE 0.00254112 (0.00248163), PSNR 25.94974327 (26.13560946), SSIM 0.71260381 (0.77390507)
Finished validation.
Starting training epoch 150
Epoch: 150, MSE 0.00183036 (0.00202809), PSNR 27.37464142 (26.95074385), SSIM 0.79128039 (0.78114006)
Finished training epoch 150
Validate: MSE 0.00243028 (0.00227220), PSNR 26.14343071 (26.49778745), SSIM 0.71908981 (0.77810252)
Finished validation.
Starting training epoch 151
Epoch: 151, MSE 0.00179695 (0.00201568), PSNR 27.45462799 (26.97662622), SSIM 0.78619683 (0.78120885)
Finished training epoch 151
Validate: MSE 0.00262870 (0.00233270), PSNR 25.80259514 (26.40454067), SSIM 0.71559685 (0.77737791)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00194664 (0.00201970), PSNR 27.10713196 (26.97005366), SSIM 0.77695072 (0.78128304)
Finished training epoch 152
Validate: MSE 0.00268839 (0.00227423), PSNR 25.70507812 (26.50084155), SSIM 0.71394145 (0.77751978)
Finished validation.
Starting training epoch 153
Epoch: 153, MS

<Figure size 432x288 with 0 Axes>

In [20]:
# Persist the final weights; the file name records the last validation MSE / PSNR / SSIM.
torch.save(
    obj=model.state_dict(),
    f=f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth',
)

In [21]:
# Final validation pass with the current weights. epoch=-1 skips TensorBoard logging,
# and the saved samples are named with "epoch--1".
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images=save_images, epoch=-1)

Validate: MSE 0.00240289 (0.00227638), PSNR 26.19265556 (26.48524731), SSIM 0.71927053 (0.77822917)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# Disabled cell: when uncommented it displays side by side the grayscale and colorized
# samples img-0 .. img-9 saved during validation at epoch `best_epoch` (the epoch with
# the best validation MSE).
# NOTE(review): the commented loop bodies mix 4- and 2-space indentation.
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()