In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

Reusing TensorBoard on port 6006 (pid 10786), started 0:18:20 ago. (Use '!kill 10786' to kill it.)

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [4]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [5]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [6]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [7]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [8]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab
    
    def __len__(self):
        return len(self.paths)

In [9]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [10]:
kernel_size=3
stride_en=1
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 256
p1 = .25

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)

        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base)
        self.batchnorm2 = nn.BatchNorm2d(channels_base * 2)
        self.batchnorm3 = nn.BatchNorm2d(channels_base)
        self.batchnorm4 = nn.BatchNorm2d(channels_base // 2)
        self.batchnorm5 = nn.BatchNorm2d(2)
        
    def forward(self, input):
        # encoder
        x = F.leaky_relu(self.batchnorm1(self.conv1(input)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = F.leaky_relu(self.batchnorm2(self.conv2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        
        # decoder
        x = F.leaky_relu(self.batchnorm3(self.convtrans1(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm4(self.convtrans2(x)), negative_slope=0.1)
        x = F.dropout(x, p=p1)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.leaky_relu(self.batchnorm5(self.convtrans3(x)), negative_slope=0.1)

        return x

In [11]:
model = Autoencoder()

In [12]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [14]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [15]:
if use_gpu: 
    from torchsummary import summary
    summary(model, (1, SIZE, SIZE))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]           2,560
       BatchNorm2d-2          [-1, 256, 32, 32]             512
            Conv2d-3          [-1, 512, 16, 16]       1,180,160
       BatchNorm2d-4          [-1, 512, 16, 16]           1,024
   ConvTranspose2d-5            [-1, 256, 8, 8]       1,179,904
       BatchNorm2d-6            [-1, 256, 8, 8]             512
   ConvTranspose2d-7          [-1, 128, 16, 16]         295,040
       BatchNorm2d-8          [-1, 128, 16, 16]             256
   ConvTranspose2d-9            [-1, 2, 32, 32]           2,306
      BatchNorm2d-10            [-1, 2, 32, 32]               4
Total params: 2,662,278
Trainable params: 2,662,278
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 6.78
Params size (MB): 10.16
Estima

In [16]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [17]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab in val_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()

        # Run model and record loss
        output_ab = model(gray) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [18]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab in train_loader:
        if use_gpu: 
            gray, ab = gray.cuda(), ab.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [19]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience and epoch >= 100:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00281495 (0.02120731), PSNR 25.50529861 (20.41187861), SSIM 0.76312405 (0.60692621)
Finished training epoch 0
Validate: MSE 0.00313312 (0.00271486), PSNR 25.04022789 (25.73842564), SSIM 0.70213848 (0.77096545)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00299842 (0.00265485), PSNR 25.23106956 (25.78115687), SSIM 0.74857455 (0.76541622)
Finished training epoch 1
Validate: MSE 0.00296990 (0.00256114), PSNR 25.27257919 (25.98286840), SSIM 0.69806635 (0.76667390)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00268930 (0.00254314), PSNR 25.70360374 (25.96930976), SSIM 0.76108265 (0.76587218)
Finished training epoch 2
Validate: MSE 0.00295500 (0.00249831), PSNR 25.29443169 (26.08939319), SSIM 0.70388460 (0.76980982)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00354716 (0.00249024), PSNR 24.50119019 (26.06000463), SSIM 0.75358641 (0.76895690)
Finished training epoch 3
Validate: MSE 0.00277556 (0.0

Epoch: 30, MSE 0.00238573 (0.00223987), PSNR 26.22378731 (26.52057457), SSIM 0.78453040 (0.78068555)
Finished training epoch 30
Validate: MSE 0.00258091 (0.00226620), PSNR 25.88226700 (26.51290045), SSIM 0.72599196 (0.78065407)
Finished validation.
Starting training epoch 31
Epoch: 31, MSE 0.00224266 (0.00223441), PSNR 26.49236298 (26.53120698), SSIM 0.78851289 (0.78072155)
Finished training epoch 31
Validate: MSE 0.00264716 (0.00229725), PSNR 25.77219963 (26.44795500), SSIM 0.72164369 (0.78212034)
Finished validation.
Starting training epoch 32
Epoch: 32, MSE 0.00236585 (0.00223002), PSNR 26.26012230 (26.54084882), SSIM 0.78452069 (0.78084720)
Finished training epoch 32
Validate: MSE 0.00243135 (0.00226992), PSNR 26.14151573 (26.49932452), SSIM 0.72196937 (0.77609201)
Finished validation.
Starting training epoch 33
Epoch: 33, MSE 0.00217497 (0.00222840), PSNR 26.62546539 (26.54265384), SSIM 0.77585459 (0.78082588)
Finished training epoch 33
Validate: MSE 0.00280070 (0.00230730), PSNR 

Epoch: 60, MSE 0.00192639 (0.00210933), PSNR 27.15256500 (26.78206626), SSIM 0.78841794 (0.78094140)
Finished training epoch 60
Validate: MSE 0.00258893 (0.00226288), PSNR 25.86879158 (26.51138633), SSIM 0.71019483 (0.76935522)
Finished validation.
Starting training epoch 61
Epoch: 61, MSE 0.00212406 (0.00210124), PSNR 26.72834015 (26.79713438), SSIM 0.80023861 (0.78108434)
Finished training epoch 61
Validate: MSE 0.00277223 (0.00238658), PSNR 25.57170486 (26.29621771), SSIM 0.71348161 (0.77553734)
Finished validation.
Starting training epoch 62
Epoch: 62, MSE 0.00235260 (0.00210736), PSNR 26.28452682 (26.78508651), SSIM 0.78752178 (0.78095830)
Finished training epoch 62
Validate: MSE 0.00270139 (0.00239325), PSNR 25.68412971 (26.28410801), SSIM 0.71455944 (0.77498967)
Finished validation.
Starting training epoch 63
Epoch: 63, MSE 0.00194254 (0.00209454), PSNR 27.11630630 (26.81218725), SSIM 0.79187340 (0.78099925)
Finished training epoch 63
Validate: MSE 0.00246874 (0.00226643), PSNR 

  return func(*args, **kwargs)


Validate: MSE 0.00276980 (0.00229321), PSNR 25.57551575 (26.44879089), SSIM 0.70057380 (0.76407223)
Finished validation.
Starting training epoch 77
Epoch: 77, MSE 0.00253696 (0.00205304), PSNR 25.95685577 (26.89603092), SSIM 0.77795041 (0.78015292)
Finished training epoch 77
Validate: MSE 0.00286596 (0.00235789), PSNR 25.42729759 (26.34586178), SSIM 0.70938075 (0.77856904)
Finished validation.
Starting training epoch 78
Epoch: 78, MSE 0.00189363 (0.00204957), PSNR 27.22705841 (26.90405635), SSIM 0.78693521 (0.78021696)
Finished training epoch 78
Validate: MSE 0.00263072 (0.00230765), PSNR 25.79925537 (26.43466021), SSIM 0.71413040 (0.76810512)
Finished validation.
Starting training epoch 79
Epoch: 79, MSE 0.00204074 (0.00205200), PSNR 26.90211105 (26.89920115), SSIM 0.76462865 (0.78005734)
Finished training epoch 79
Validate: MSE 0.00242932 (0.00225248), PSNR 26.14514923 (26.53499159), SSIM 0.71627212 (0.77152697)
Finished validation.
Starting training epoch 80
Epoch: 80, MSE 0.0022730

  return func(*args, **kwargs)


Validate: MSE 0.00266009 (0.00228554), PSNR 25.75103378 (26.46409050), SSIM 0.70597458 (0.76912521)
Finished validation.
Starting training epoch 86
Epoch: 86, MSE 0.00216789 (0.00202422), PSNR 26.63963318 (26.95939938), SSIM 0.76782084 (0.77962638)
Finished training epoch 86
Validate: MSE 0.00266283 (0.00230931), PSNR 25.74656487 (26.43144902), SSIM 0.71263015 (0.76905341)
Finished validation.
Starting training epoch 87
Epoch: 87, MSE 0.00157706 (0.00202289), PSNR 28.02152252 (26.95859178), SSIM 0.79419523 (0.77978102)
Finished training epoch 87
Validate: MSE 0.00285853 (0.00228699), PSNR 25.43857956 (26.47453373), SSIM 0.70667064 (0.77564686)
Finished validation.
Starting training epoch 88
Epoch: 88, MSE 0.00210624 (0.00202343), PSNR 26.76492500 (26.96106691), SSIM 0.79652429 (0.77954423)
Finished training epoch 88
Validate: MSE 0.00262431 (0.00226424), PSNR 25.80984879 (26.51098868), SSIM 0.71888274 (0.77576651)
Finished validation.
Starting training epoch 89
Epoch: 89, MSE 0.0021646

Epoch: 115, MSE 0.00221098 (0.00196351), PSNR 26.55414200 (27.09125021), SSIM 0.76743156 (0.77819803)
Finished training epoch 115
Validate: MSE 0.00265430 (0.00234226), PSNR 25.76050186 (26.36404153), SSIM 0.70160985 (0.76051287)
Finished validation.
Starting training epoch 116
Epoch: 116, MSE 0.00196560 (0.00195795), PSNR 27.06504631 (27.10311832), SSIM 0.80093896 (0.77827887)
Finished training epoch 116
Validate: MSE 0.00256116 (0.00226628), PSNR 25.91562653 (26.50230692), SSIM 0.71130157 (0.77283930)
Finished validation.
Starting training epoch 117
Epoch: 117, MSE 0.00222661 (0.00196465), PSNR 26.52355385 (27.08820393), SSIM 0.77118576 (0.77815255)
Finished training epoch 117
Validate: MSE 0.00272531 (0.00225396), PSNR 25.64583778 (26.53461857), SSIM 0.70834649 (0.76881552)
Finished validation.
Starting training epoch 118
Epoch: 118, MSE 0.00220281 (0.00195763), PSNR 26.57022476 (27.10099592), SSIM 0.76238197 (0.77798952)
Finished training epoch 118
Validate: MSE 0.00248034 (0.00224

<Figure size 432x288 with 0 Axes>

In [20]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [21]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00255264 (0.00228250), PSNR 25.93009758 (26.47263344), SSIM 0.71170187 (0.76470737)
Finished validation.


<Figure size 432x288 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()