In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
# Load the TensorBoard notebook extension and embed a dashboard that reads logs from ./runs
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# Runtime environment switches. NOTE: the Colab and Kaggle code paths have not been tested yet.
kaggle = False
colab = False
# Run identifier; on Colab/Kaggle it becomes the sub-folder that holds this run's outputs.
test_number = '12_2'

In [4]:
# Output locations relative to the working directory. On Colab/Kaggle they are
# re-rooted under a per-run folder named after `test_number`.
color_imgs, gray_imgs, checkpoints = 'outputs/color/', 'outputs/gray/', 'checkpoints'

if colab:
    # Mount Google Drive so both the dataset and the run outputs live there.
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'/content/drive/MyDrive/MGU/{test_number}/{sub}' for sub in (color_imgs, gray_imgs, checkpoints)
    )
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    color_imgs, gray_imgs, checkpoints = (
        f'{test_number}/{sub}' for sub in (color_imgs, gray_imgs, checkpoints)
    )
else:
    # Local layout: the dataset sits two directories up; outputs stay relative.
    dataset = '../../datasets/cifar10/'

In [5]:
# Make output folders (no-op when they already exist)
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)

# Seed torch's global RNG before the data loaders and the model are built, so
# that training runs are repeatable under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

save_images = True                # write sample colorizations during validation
best_losses = [1e10, 1e10, 1e10]  # initial sentinels for best-so-far [MSE, PSNR, SSIM]
best_epoch = -1                   # epoch of the best validation MSE so far (-1 = none yet)
patience = 50                     # early-stopping patience, in epochs
epochs = 500                      # maximum number of training epochs
batch_size = 128
SIZE = 32                         # images are resized to SIZE x SIZE

In [6]:
# TensorBoard writer used by train()/validate() to log per-epoch metrics.
# No log_dir is given, so SummaryWriter falls back to its default location
# (a sub-folder of ./runs per the PyTorch docs), which %tensorboard --logdir=runs watches.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir=None)

In [7]:
# Flag consulted by later cells to decide whether data, model and metrics go to CUDA.
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    """Flat image folder yielding (grayscale, ab) pairs for colorization.

    Every regular file directly inside ``paths`` is treated as an image. Each
    sample is converted to RGB, resized to SIZE x SIZE, and split into:

    * ``img_gray``: a (1, SIZE, SIZE) float32 tensor from ``rgb2gray``,
      used as the model input;
    * ``img_ab``: a (2, SIZE, SIZE) float32 tensor with the Lab a/b channels
      rescaled by ``(x + 128) / 255`` (roughly [0, 1]), used as the target.

    NOTE(review): the grayscale input is ``rgb2gray`` luminance, not Lab L*;
    code that reuses it as L* only approximately reconstructs the colours.

    Args:
        paths: directory containing the images (sub-directories are ignored).
        split: ``'train'`` (adds a random horizontal flip) or ``'val'``.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                # No-op crop: the image was just resized to exactly SIZE x SIZE.
                transforms.RandomCrop(SIZE),
            ])
        else:
            # Fail fast: otherwise self.transforms stays unset and the error only
            # surfaces later, as an AttributeError inside __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        # Sorted so the sample order (and hence which validation images are saved
        # as examples) does not depend on the filesystem's listing order.
        self.paths = [
            os.path.join(paths, name)
            for name in sorted(os.listdir(paths))
            if os.path.isfile(os.path.join(paths, name))
        ]

    def __getitem__(self, index):
        # Close the file handle promptly; convert() returns an independent
        # in-memory image, so it stays valid after the file is closed.
        with Image.open(self.paths[index]) as raw:
            img = raw.convert("RGB")
        img_original = np.asarray(self.transforms(img))  # H x W x 3 RGB array
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab

    def __len__(self):
        return len(self.paths)

In [9]:
# Data pipelines: shuffled batches for training, fixed order for validation.
train_imagefolder = LabImageFolder(f'{dataset}train')
train_loader = torch.utils.data.DataLoader(
    train_imagefolder,
    batch_size=batch_size,
    shuffle=True,
)

val_imagefolder = LabImageFolder(f'{dataset}val', split='val')
val_loader = torch.utils.data.DataLoader(
    val_imagefolder,
    batch_size=batch_size,
    shuffle=False,
)

In [10]:
# Convolution hyper-parameters shared by the autoencoder's layers.
kernel_size = 3
stride_en = 1        # encoder conv stride (downsampling is done by AvgPool2d)
stride_de = 1        # decoder transposed-conv stride (upsampling is done by Upsample)
padding = 1
scale_factor = 2     # Upsample factor used in the decoder
padding_mode = 'zeros'


class Autoencoder(nn.Module):
    """Small convolutional encoder/decoder mapping 1 input channel to 2 output channels.

    The encoder halves the spatial size twice (AvgPool2d with stride 2) while going
    1 -> 16 -> 32 channels. The decoder doubles it twice (Upsample) while going
    32 -> 16 -> 8 -> 2 channels. An (N, 1, H, W) input with H and W divisible by 4
    therefore comes back as (N, 2, H, W). The final layer has no activation, so
    outputs are unbounded.

    Args:
        input_size: not used by the network; kept for interface compatibility.
    """

    def __init__(self, input_size=128):
        super().__init__()

        def conv_bn_act(conv_cls, c_in, c_out, stride):
            # (Transposed) convolution -> BatchNorm -> LeakyReLU(0.1). Returned as a
            # list and spliced flat into the Sequential, so layer indices (and thus
            # state_dict keys) match a hand-written flat layer list.
            return [
                conv_cls(c_in, c_out, kernel_size=kernel_size, stride=stride,
                         padding=padding, padding_mode=padding_mode),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.1),
            ]

        self.encoder = nn.Sequential(
            *conv_bn_act(nn.Conv2d, 1, 16, stride_en),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
            *conv_bn_act(nn.Conv2d, 16, 32, stride_en),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.decoder = nn.Sequential(
            *conv_bn_act(nn.ConvTranspose2d, 32, 16, stride_de),
            nn.Upsample(scale_factor=scale_factor),
            *conv_bn_act(nn.ConvTranspose2d, 16, 8, stride_de),
            # Final projection to the 2 output channels: no BatchNorm, no activation.
            nn.ConvTranspose2d(8, 2, kernel_size=kernel_size, stride=stride_de,
                               padding=padding, padding_mode=padding_mode),
            nn.Upsample(scale_factor=scale_factor),
        )

    def forward(self, input):
        # Encode to a (N, 32, H/4, W/4) code, then decode and upsample back to 2 channels.
        return self.decoder(self.encoder(input))

In [11]:
model = Autoencoder(input_size=128)  # input_size is accepted but not used by the class

In [12]:
# Metrics on the ab channels: index 0 = MSE (the training loss), 1 = PSNR, 2 = SSIM.
criterion = [
    MeanSquaredError(),
    PeakSignalNoiseRatio(data_range=1.0),
    StructuralSimilarityIndexMeasure(data_range=1.0),
]

In [13]:
# Adam over all autoencoder parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [14]:
# Move the metric modules and the model onto the GPU when one is available.
if use_gpu:
    criterion = [metric.to("cuda") for metric in criterion]
    model = model.cuda()

In [15]:
# Per-layer summary for a single 1 x SIZE x SIZE input.
# NOTE(review): presumably guarded by use_gpu because torchsummary targets CUDA by default; confirm.
if use_gpu:
    from torchsummary import summary
    summary(model, input_size=(1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             160
       BatchNorm2d-2           [-1, 16, 32, 32]              32
         LeakyReLU-3           [-1, 16, 32, 32]               0
         AvgPool2d-4           [-1, 16, 16, 16]               0
            Conv2d-5           [-1, 32, 16, 16]           4,640
       BatchNorm2d-6           [-1, 32, 16, 16]              64
         LeakyReLU-7           [-1, 32, 16, 16]               0
         AvgPool2d-8             [-1, 32, 8, 8]               0
   ConvTranspose2d-9             [-1, 16, 8, 8]           4,624
      BatchNorm2d-10             [-1, 16, 8, 8]              32
        LeakyReLU-11             [-1, 16, 8, 8]               0
         Upsample-12           [-1, 16, 16, 16]               0
  ConvTranspose2d-13            [-1, 8, 16, 16]           1,160
      BatchNorm2d-14            [-1, 8,

In [16]:
class AverageMeter:
    """Weighted running average of a scalar (from the PyTorch ImageNet example).

    Attributes (all start at 0):
        val: most recent value passed to update().
        sum: sum of value * n over all updates.
        count: total weight n seen so far.
        avg: sum / count.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean and its batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Rebuild an RGB image from a grayscale channel and (scaled) ab channels.

    Args:
        grayscale_input: CPU tensor of shape (1, H, W), expected in [0, 1]; it is
            used, multiplied by 100, as the Lab L channel.
        ab_input: CPU tensor of shape (2, H, W) holding ab scaled as (ab + 128) / 255;
            this scaling is undone before the Lab -> RGB conversion.
        save_path: optional dict ``{'grayscale': dir, 'colorized': dir}``.
        save_name: optional file name, used in both directories.

    The grayscale and colorized images are written only when both ``save_path``
    and ``save_name`` are given.

    Returns:
        The colorized image as an (H, W, 3) array produced by ``lab2rgb``.
    '''
    # No plt.clf() here: plt.imsave does not use the current figure, and clearing
    # it only left an empty figure behind in the calling notebook cell.
    # Stack to (3, H, W), then move channels last: (H, W, [L, a, b]).
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy()
    color_image = color_image.transpose((1, 2, 0))
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100        # L: [0, 1] -> [0, 100]
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128  # ab: undo (ab + 128) / 255
    color_image = lab2rgb(color_image.astype(np.float64))
    if save_path is not None and save_name is not None:
        grayscale = grayscale_input.squeeze().numpy()
        # os.path.join works whether or not the directories end with a slash.
        plt.imsave(arr=grayscale, fname=os.path.join(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname=os.path.join(save_path['colorized'], save_name))
    return color_image

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    '''Evaluate ``model`` on ``val_loader`` and report MSE / PSNR / SSIM.

    Args:
        val_loader: iterable of (gray, ab) batches.
        model: network mapping gray -> ab; switched to eval mode here.
        criterion: [MSE, PSNR, SSIM] metric objects, in that order.
        save_images: if True, up to 10 colorized samples from the first batch are
            written to ``gray_imgs`` / ``color_imgs`` via ``to_rgb``.
        epoch: epoch index used in saved file names and as the TensorBoard step.
            A negative value skips TensorBoard logging.

    Returns:
        ``(mse, psnr, ssim)``, each averaged over all samples (weighted by batch size).

    Callers in this notebook run it inside ``torch.no_grad()``.
    '''
    meters = [AverageMeter(), AverageMeter(), AverageMeter()]
    # Only each call's per-batch return value is used. Clear whatever running state
    # the metric objects accumulated in earlier calls so it does not pile up over the run.
    for metric in criterion:
        if hasattr(metric, 'reset'):
            metric.reset()

    model.eval()
    images_saved = False
    for gray, ab in val_loader:
        if use_gpu:
            gray, ab = gray.cuda(), ab.cuda()

        # Run the model and record the batch metrics.
        output_ab = model(gray)
        batch = gray.size(0)
        for meter, metric in zip(meters, criterion):
            meter.update(metric(output_ab, ab).item(), batch)

        # Save sample colorizations, from the first batch only.
        if save_images and not images_saved:
            images_saved = True
            save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
            for j in range(min(len(output_ab), 10)):  # save at most 10 images
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    mse, psnr, ssim = meters
    print(f'Validate: MSE {mse.val:.8f} ({mse.avg:.8f}), PSNR {psnr.val:.8f} ({psnr.avg:.8f}), SSIM {ssim.val:.8f} ({ssim.avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", mse.avg, epoch)
        writer.add_scalar("PSNR/test", psnr.avg, epoch)
        writer.add_scalar("SSIM/test", ssim.avg, epoch)
    return mse.avg, psnr.avg, ssim.avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    '''Run one training epoch: optimize MSE and monitor PSNR / SSIM.

    Args:
        train_loader: iterable of (gray, ab) batches.
        model: network mapping gray -> ab; switched to train mode here.
        criterion: [MSE, PSNR, SSIM] metric objects. Only index 0 (MSE) is
            back-propagated.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index for logging. A negative value skips TensorBoard logging.
    '''
    print(f'Starting training epoch {epoch}')
    meters = [AverageMeter(), AverageMeter(), AverageMeter()]
    # Only each call's per-batch return value is used. Clear running state left
    # in the metric objects by earlier calls.
    for metric in criterion:
        if hasattr(metric, 'reset'):
            metric.reset()

    model.train()

    for gray, ab in train_loader:
        if use_gpu:
            gray, ab = gray.cuda(), ab.cuda()

        optimizer.zero_grad()

        output_ab = model(gray)
        mse = criterion[0](output_ab, ab)
        # PSNR/SSIM are only monitored, never back-propagated. Feeding them detached
        # predictions keeps autograd from recording their computation.
        preds = output_ab.detach()
        psnr = criterion[1](preds, ab)
        ssim = criterion[2](preds, ab)

        mse.backward()
        optimizer.step()

        batch = gray.size(0)
        meters[0].update(mse.item(), batch)
        meters[1].update(psnr.item(), batch)
        meters[2].update(ssim.item(), batch)

    mse_m, psnr_m, ssim_m = meters
    print(f'Epoch: {epoch}, MSE {mse_m.val:.8f} ({mse_m.avg:.8f}), PSNR {psnr_m.val:.8f} ({psnr_m.avg:.8f}), SSIM {ssim_m.val:.8f} ({ssim_m.avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", mse_m.avg, epoch)
        writer.add_scalar("PSNR/train", psnr_m.avg, epoch)
        writer.add_scalar("SSIM/train", ssim_m.avg, epoch)

In [19]:
# Train model: one training epoch, then validation; checkpoint on improvement.
# MSE improves when it goes DOWN, but PSNR and SSIM improve when they go UP.
# Their best-so-far trackers therefore start at -inf; a large positive start
# value only makes sense for MSE.
best_losses = [best_losses[0], float('-inf'), float('-inf')]
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)

    # Lower MSE is better. best_epoch follows MSE and drives early stopping.
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    # Higher PSNR / SSIM is better.
    if losses[1] > best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] > best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')

    # Early stop after `patience` epochs without an MSE improvement (never before epoch 100).
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break

    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00275531 (0.01428257), PSNR 25.59829712 (22.37121229), SSIM 0.69594574 (0.54040935)
Finished training epoch 0
Validate: MSE 0.00351311 (0.00308461), PSNR 24.54307747 (25.16439119), SSIM 0.60646254 (0.67288071)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00312446 (0.00308439), PSNR 25.05224800 (25.13583221), SSIM 0.71817482 (0.70214228)
Finished training epoch 1
Validate: MSE 0.00318243 (0.00287413), PSNR 24.97241211 (25.49370896), SSIM 0.65479648 (0.72248936)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00290605 (0.00288649), PSNR 25.36697006 (25.41846593), SSIM 0.75462329 (0.73352872)
Finished training epoch 2
Validate: MSE 0.00328390 (0.00276241), PSNR 24.83609390 (25.65244671), SSIM 0.67816436 (0.74270110)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00212826 (0.00279085), PSNR 26.71974373 (25.56373383), SSIM 0.78012627 (0.74768138)
Finished training epoch 3
Validate: MSE 0.00311074 (0.0

Epoch: 30, MSE 0.00264941 (0.00237731), PSNR 25.76850700 (26.26228870), SSIM 0.77750564 (0.77727714)
Finished training epoch 30
Validate: MSE 0.00278147 (0.00242621), PSNR 25.55726242 (26.21737292), SSIM 0.71381801 (0.77518859)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00239643 (0.00238306), PSNR 26.20434380 (26.25170577), SSIM 0.76977074 (0.77735051)
Finished training epoch 31
Validate: MSE 0.00276459 (0.00247291), PSNR 25.58369637 (26.13786753), SSIM 0.71592069 (0.77480525)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00216302 (0.00236378), PSNR 26.64939690 (26.28702376), SSIM 0.79141504 (0.77756324)
Finished training epoch 32
Validate: MSE 0.00268935 (0.00255389), PSNR 25.70353127 (26.00566684), SSIM 0.71148860 (0.77371371)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00177633 (0.00236283), PSNR 27.50475502 (26.28848740), SSIM 0.82513797 (0.77762321)
Finished training epoch 33
Validate: MSE 0.00300237 (0.00248659), PSNR 

Epoch: 60, MSE 0.00218183 (0.00232454), PSNR 26.61178207 (26.36045342), SSIM 0.76604384 (0.77805854)
Finished training epoch 60
Validate: MSE 0.00256013 (0.00234995), PSNR 25.91738319 (26.35014611), SSIM 0.72237551 (0.77766985)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00253571 (0.00231752), PSNR 25.95899582 (26.37484795), SSIM 0.79070348 (0.77806196)
Finished training epoch 61
Validate: MSE 0.00268573 (0.00233902), PSNR 25.70937920 (26.37013261), SSIM 0.71839392 (0.77795417)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00295456 (0.00232916), PSNR 25.29506493 (26.35545602), SSIM 0.76136851 (0.77792484)
Finished training epoch 62
Validate: MSE 0.00270151 (0.00251746), PSNR 25.68392563 (26.05290200), SSIM 0.71984708 (0.77356328)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00278532 (0.00232416), PSNR 25.55124664 (26.36008893), SSIM 0.75294715 (0.77786088)
Finished training epoch 63
Validate: MSE 0.00272051 (0.00265163), PSNR 

Epoch: 90, MSE 0.00212640 (0.00229453), PSNR 26.72354126 (26.41547134), SSIM 0.77401680 (0.77805931)
Finished training epoch 90
Validate: MSE 0.00263600 (0.00275397), PSNR 25.79054070 (25.71588247), SSIM 0.71618813 (0.77222972)
Finished validation.
Starting training epoch 91
Epoch: 91, MSE 0.00199332 (0.00230083), PSNR 27.00422287 (26.40349048), SSIM 0.80901796 (0.77807990)
Finished training epoch 91
Validate: MSE 0.00282215 (0.00238239), PSNR 25.49419785 (26.29954123), SSIM 0.71380514 (0.77517268)
Finished validation.
Starting training epoch 92
Epoch: 92, MSE 0.00225153 (0.00229870), PSNR 26.47522926 (26.41091398), SSIM 0.78429496 (0.77824277)
Finished training epoch 92
Validate: MSE 0.00278278 (0.00237856), PSNR 25.55520630 (26.30435222), SSIM 0.71764469 (0.77743809)
Finished validation.
Starting training epoch 93
Epoch: 93, MSE 0.00248996 (0.00230158), PSNR 26.03806686 (26.40313168), SSIM 0.77881849 (0.77802400)
Finished training epoch 93
Validate: MSE 0.00291461 (0.00271610), PSNR 

Epoch: 120, MSE 0.00265448 (0.00228152), PSNR 25.76020050 (26.44267100), SSIM 0.76314890 (0.77829878)
Finished training epoch 120
Validate: MSE 0.00276342 (0.00287293), PSNR 25.58553505 (25.49825871), SSIM 0.71391785 (0.76529599)
Finished validation.
Starting training epoch 121
Epoch: 121, MSE 0.00189115 (0.00229410), PSNR 27.23273277 (26.41864620), SSIM 0.78723425 (0.77810560)
Finished training epoch 121
Validate: MSE 0.00263303 (0.00228733), PSNR 25.79544449 (26.46628900), SSIM 0.71797013 (0.77774479)
Finished validation.
Starting training epoch 122
Epoch: 122, MSE 0.00201793 (0.00228853), PSNR 26.95093155 (26.42872373), SSIM 0.80165005 (0.77821343)
Finished training epoch 122
Validate: MSE 0.00270254 (0.00255842), PSNR 25.68228149 (26.00425457), SSIM 0.71220112 (0.77318449)
Finished validation.
Starting training epoch 123
Epoch: 123, MSE 0.00203919 (0.00228419), PSNR 26.90542984 (26.43501401), SSIM 0.77943587 (0.77820106)
Finished training epoch 123
Validate: MSE 0.00267325 (0.00231

Validate: MSE 0.00266671 (0.00229012), PSNR 25.74024773 (26.46317942), SSIM 0.71882051 (0.77960042)
Finished validation.
Starting training epoch 150
Epoch: 150, MSE 0.00218347 (0.00229016), PSNR 26.60852432 (26.42503107), SSIM 0.78637570 (0.77832947)
Finished training epoch 150
Validate: MSE 0.00260264 (0.00231827), PSNR 25.84585571 (26.41220139), SSIM 0.71380830 (0.77480909)
Finished validation.
Starting training epoch 151
Epoch: 151, MSE 0.00215557 (0.00227291), PSNR 26.66438293 (26.45849032), SSIM 0.78160381 (0.77842036)
Finished training epoch 151
Validate: MSE 0.00257532 (0.00236882), PSNR 25.89169312 (26.32354333), SSIM 0.71589643 (0.77712780)
Finished validation.
Starting training epoch 152
Epoch: 152, MSE 0.00257274 (0.00227585), PSNR 25.89604568 (26.45298545), SSIM 0.78741992 (0.77831222)
Finished training epoch 152
Validate: MSE 0.00268112 (0.00226414), PSNR 25.71683311 (26.51280970), SSIM 0.71403021 (0.78174572)
Finished validation.
Starting training epoch 153
Epoch: 153, MS

Epoch: 179, MSE 0.00235812 (0.00226820), PSNR 26.27433968 (26.46635241), SSIM 0.79146117 (0.77833637)
Finished training epoch 179
Validate: MSE 0.00272220 (0.00238588), PSNR 25.65079880 (26.29727607), SSIM 0.71752083 (0.77764395)
Finished validation.
Starting training epoch 180
Epoch: 180, MSE 0.00204486 (0.00227596), PSNR 26.89335442 (26.45395973), SSIM 0.78803331 (0.77833049)
Finished training epoch 180
Validate: MSE 0.00271491 (0.00226918), PSNR 25.66244888 (26.50541561), SSIM 0.71093667 (0.77841737)
Finished validation.
Starting training epoch 181
Epoch: 181, MSE 0.00248726 (0.00227220), PSNR 26.04278183 (26.46172541), SSIM 0.77085179 (0.77846440)
Finished training epoch 181
Validate: MSE 0.00261044 (0.00272858), PSNR 25.83286476 (25.74062907), SSIM 0.71290344 (0.77127018)
Finished validation.
Starting training epoch 182
Epoch: 182, MSE 0.00264240 (0.00226593), PSNR 25.78001022 (26.46915170), SSIM 0.77494466 (0.77844059)
Finished training epoch 182
Validate: MSE 0.00294378 (0.00242

<Figure size 432x288 with 0 Axes>

In [20]:
# Save the final weights; the file name records the last validation MSE, PSNR and SSIM.
torch.save(
    model.state_dict(),
    '{}/last-{:.8f}-{:.8f}-{:.8f}.pth'.format(checkpoints, *losses),
)

In [21]:
# Final validation pass. epoch=-1 skips TensorBoard logging, and the sample
# images are saved with "epoch--1" in their file names.
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, epoch=-1)

Validate: MSE 0.00275259 (0.00238058), PSNR 25.60258675 (26.30027454), SSIM 0.71490979 (0.77529448)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()