In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '11_1'

In [3]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = '/content/drive/MyDrive/MGU/outputs/color/'
    gray_imgs = '/content/drive/MyDrive/MGU/outputs/gray/'
    checkpoints = '/content/drive/MyDrive/MGU/checkpoints'
elif kaggle:
    os.makedirs(test_number, exist_ok=True)
    results = "results"
    os.makedirs(results, exist_ok=True)
    dataset = '/kaggle/input/cifar10/'
else:
    dataset = '../../datasets/cifar10/'

In [4]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [5]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [6]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [7]:
SIZE = 32
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(), 
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE),
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [8]:
# Training
batch_size = 128
train_imagefolder = LabImageFolder(dataset+'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset+'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [9]:
kernel_size=3
stride_en=2
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'


class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        
        self.convtrans1 = nn.ConvTranspose2d(32, 16, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(16, 8, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(8, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm8 = nn.BatchNorm2d(8)
        self.batchnorm16 = nn.BatchNorm2d(16)
        self.batchnorm32 = nn.BatchNorm2d(32)
        
        
    def forward(self, input):
        # encoder
        x = F.relu(self.batchnorm16(self.conv1(input)))
        x = F.relu(self.batchnorm32(self.conv2(x)))
        
        # decoder
        x = F.relu(self.batchnorm16(self.convtrans1(x)))
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.relu(self.batchnorm8(self.convtrans2(x)))
        x = F.interpolate(self.convtrans3(x), scale_factor=scale_factor)

        return x

In [10]:
model = Autoencoder()

In [11]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [13]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [14]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             160
       BatchNorm2d-2           [-1, 16, 16, 16]              32
            Conv2d-3             [-1, 32, 8, 8]           4,640
       BatchNorm2d-4             [-1, 32, 8, 8]              64
   ConvTranspose2d-5             [-1, 16, 8, 8]           4,624
       BatchNorm2d-6             [-1, 16, 8, 8]              32
   ConvTranspose2d-7            [-1, 8, 16, 16]           1,160
       BatchNorm2d-8            [-1, 8, 16, 16]              16
   ConvTranspose2d-9            [-1, 2, 16, 16]             146
Total params: 10,874
Trainable params: 10,874
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.14
Params size (MB): 0.04
Estimated Total Size (MB): 0.19
---------------------------------------------

In [15]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [16]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:        
#         writer.add_scalars(f'loss/test', {
#             'MSE': _loss[0].avg,
#             'PSNR': _loss[1].avg,
#             'SSIM': _loss[2].avg,
#         }, epoch)
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [17]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
#         writer.add_scalars(f'loss/train', {
#             'MSE': _loss[0].avg,
#             'PSNR': _loss[1].avg,
#             'SSIM': _loss[2].avg,
#         }, epoch)
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [18]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00312806 (0.00858896), PSNR 25.04724693 (23.38803656), SSIM 0.69198024 (0.58827725)
Finished training epoch 0
Validate: MSE 0.00480921 (0.00512265), PSNR 23.17926216 (23.05454900), SSIM 0.69618142 (0.76550161)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00346041 (0.00317059), PSNR 24.60872459 (25.02291199), SSIM 0.70561165 (0.71814212)
Finished training epoch 1
Validate: MSE 0.00348187 (0.00328993), PSNR 24.58187675 (24.93850446), SSIM 0.69689804 (0.76738502)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00254427 (0.00280736), PSNR 25.94437027 (25.54497692), SSIM 0.73960483 (0.74441751)
Finished training epoch 2
Validate: MSE 0.00396337 (0.00415997), PSNR 24.01935577 (23.95610585), SSIM 0.70020902 (0.76959373)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00271964 (0.00268596), PSNR 25.65488052 (25.73589025), SSIM 0.76681274 (0.75614283)
Finished training epoch 3
Validate: MSE 0.00359241 (0.0

  return func(*args, **kwargs)
  return func(*args, **kwargs)


Validate: MSE 0.00397019 (0.00438236), PSNR 24.01188087 (23.70382924), SSIM 0.68554765 (0.75315993)
Finished validation.
Starting training epoch 5
Epoch: 5, MSE 0.00211480 (0.00256419), PSNR 26.74730301 (25.93132100), SSIM 0.78349912 (0.76564812)
Finished training epoch 5
Validate: MSE 0.00340772 (0.00353153), PSNR 24.67536354 (24.62381059), SSIM 0.68933308 (0.76018742)
Finished validation.
Starting training epoch 6
Epoch: 6, MSE 0.00266824 (0.00253790), PSNR 25.73775673 (25.97738868), SSIM 0.74135435 (0.76771196)
Finished training epoch 6


  return func(*args, **kwargs)


Validate: MSE 0.00417070 (0.00473688), PSNR 23.79790688 (23.37478766), SSIM 0.69441438 (0.76172907)
Finished validation.
Starting training epoch 7
Epoch: 7, MSE 0.00216831 (0.00252037), PSNR 26.63878250 (26.01156649), SSIM 0.78698373 (0.76952832)
Finished training epoch 7
Validate: MSE 0.00363724 (0.00394907), PSNR 24.39227676 (24.17170408), SSIM 0.69877410 (0.76641296)
Finished validation.
Starting training epoch 8
Epoch: 8, MSE 0.00226306 (0.00250344), PSNR 26.45303917 (26.03941793), SSIM 0.77281207 (0.77067100)
Finished training epoch 8
Validate: MSE 0.00372768 (0.00407305), PSNR 24.28560829 (24.04099582), SSIM 0.69825518 (0.76623179)
Finished validation.
Starting training epoch 9
Epoch: 9, MSE 0.00227840 (0.00249946), PSNR 26.42370605 (26.04245342), SSIM 0.75865287 (0.77131183)
Finished training epoch 9
Validate: MSE 0.00334314 (0.00346920), PSNR 24.75845146 (24.72955500), SSIM 0.70211917 (0.77134857)
Finished validation.
Starting training epoch 10
Epoch: 10, MSE 0.00235066 (0.0024

Validate: MSE 0.00315537 (0.00288464), PSNR 25.00949287 (25.49001817), SSIM 0.70031416 (0.76889595)
Finished validation.
Starting training epoch 37
Epoch: 37, MSE 0.00206858 (0.00237919), PSNR 26.84326553 (26.25823540), SSIM 0.78281963 (0.77689383)
Finished training epoch 37
Validate: MSE 0.00316349 (0.00289154), PSNR 24.99833107 (25.47567043), SSIM 0.69809687 (0.76601166)
Finished validation.
Starting training epoch 38
Epoch: 38, MSE 0.00243790 (0.00237082), PSNR 26.12983322 (26.27184617), SSIM 0.76914501 (0.77677258)
Finished training epoch 38
Validate: MSE 0.00308529 (0.00282502), PSNR 25.10704231 (25.57668540), SSIM 0.70317960 (0.77001884)
Finished validation.
Starting training epoch 39
Epoch: 39, MSE 0.00229413 (0.00236712), PSNR 26.39382172 (26.28566716), SSIM 0.78141797 (0.77688692)
Finished training epoch 39
Validate: MSE 0.00319592 (0.00276847), PSNR 24.95403862 (25.64915012), SSIM 0.69947922 (0.76917920)
Finished validation.
Starting training epoch 40
Epoch: 40, MSE 0.0025632

Validate: MSE 0.00326136 (0.00288388), PSNR 24.86601639 (25.48408148), SSIM 0.69931400 (0.76738777)
Finished validation.
Starting training epoch 67
Epoch: 67, MSE 0.00232842 (0.00234158), PSNR 26.32938576 (26.32725476), SSIM 0.77618533 (0.77613669)
Finished training epoch 67
Validate: MSE 0.00324815 (0.00310207), PSNR 24.88364029 (25.19162122), SSIM 0.69707251 (0.76319306)
Finished validation.
Starting training epoch 68
Epoch: 68, MSE 0.00239571 (0.00233444), PSNR 26.20565605 (26.34299208), SSIM 0.77642983 (0.77615955)
Finished training epoch 68
Validate: MSE 0.00312517 (0.00284236), PSNR 25.05125999 (25.55053670), SSIM 0.69970727 (0.76713234)
Finished validation.
Starting training epoch 69
Epoch: 69, MSE 0.00226945 (0.00233834), PSNR 26.44078636 (26.33238081), SSIM 0.77910709 (0.77609868)
Finished training epoch 69
Validate: MSE 0.00322705 (0.00275026), PSNR 24.91194916 (25.67848710), SSIM 0.70274729 (0.77195834)
Finished validation.
Starting training epoch 70
Epoch: 70, MSE 0.0021530

Validate: MSE 0.00327794 (0.00291539), PSNR 24.84398842 (25.43485817), SSIM 0.69657135 (0.76579828)
Finished validation.
Starting training epoch 97
Epoch: 97, MSE 0.00215796 (0.00232468), PSNR 26.65955734 (26.36086479), SSIM 0.78118432 (0.77589478)
Finished training epoch 97
Validate: MSE 0.00340659 (0.00303325), PSNR 24.67680740 (25.26013253), SSIM 0.69388080 (0.76306809)
Finished validation.
Starting training epoch 98
Epoch: 98, MSE 0.00210869 (0.00232584), PSNR 26.75987244 (26.35597406), SSIM 0.76939654 (0.77585330)
Finished training epoch 98
Validate: MSE 0.00331220 (0.00281934), PSNR 24.79883194 (25.56875089), SSIM 0.69857109 (0.76806881)
Finished validation.
Starting training epoch 99
Epoch: 99, MSE 0.00198256 (0.00231926), PSNR 27.02773094 (26.37068317), SSIM 0.78258985 (0.77585905)
Finished training epoch 99
Validate: MSE 0.00328293 (0.00302732), PSNR 24.83738518 (25.27739350), SSIM 0.69585502 (0.76518395)
Finished validation.
Starting training epoch 100
Epoch: 100, MSE 0.00246

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00349254 (0.00306739), PSNR 24.56858063 (25.19594125), SSIM 0.69315994 (0.76116881)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()