In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics import MeanSquaredError, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from PIL import Image
# For utilities
import os, shutil, time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext tensorboard
%tensorboard --logdir=runs

In [3]:
# colab i kaggle jeszcze nie testowane
colab = False
kaggle = False
test_number = '12_2'

In [5]:
color_imgs = 'outputs/color/'
gray_imgs = 'outputs/gray/'
checkpoints = 'checkpoints'
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
    dataset = '/content/drive/MyDrive/MGU/cifar10/'
    
    color_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{color_imgs}'
    gray_imgs = f'/content/drive/MyDrive/MGU/{test_number}/{gray_imgs}'
    checkpoints = f'/content/drive/MyDrive/MGU/{test_number}/{checkpoints}'
elif kaggle:
    dataset = '/kaggle/input/cifar10/'
    
    color_imgs = f'{test_number}/{color_imgs}'
    gray_imgs = f'{test_number}/{gray_imgs}'
    checkpoints = f'{test_number}/{checkpoints}'
else:
    dataset = '../../datasets/cifar10/'

In [6]:
# Make folders and set parameters
os.makedirs(color_imgs, exist_ok=True)
os.makedirs(gray_imgs, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)
save_images = True
best_losses = [1e10, 1e10, 1e10]
best_epoch = -1
patience = 50
epochs = 500
batch_size = 128
SIZE = 32

In [7]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [8]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [9]:
import csv

embeddings = {}
with open("../inception_resnet_v2_files/embeddings-cifar10-normalization.csv") as f:
    reader = csv.reader(f)
    for row in reader:
        filename = row[0]
        values = [float(val) for val in row[1:]]
        embeddings[filename] = np.array(values)

In [10]:
images_rgb = {}
imagespath = '../../datasets/cifar10/train'
imagespathval = '../../datasets/cifar10/val'
all_files = [os.path.join(imagespath, file) for file in os.listdir(imagespath) if os.path.isfile(
            os.path.join(imagespath, file))]
all_files2 = [os.path.join(imagespathval, file) for file in os.listdir(imagespathval) if os.path.isfile(
            os.path.join(imagespathval, file))]
all_files = all_files + all_files2
for image in all_files:
    rgbfilename = os.path.basename(image)
    img = Image.open(image).convert("RGB")
    images_rgb[rgbfilename] = img

In [11]:
print(len(embeddings))
print(embeddings["abandoned_ship_s_000004.png"][:20])
print(len(images_rgb))
print(images_rgb["abandoned_ship_s_000004.png"])

60000
[0.21317568 0.16253434 0.26581229 0.29878058 0.33441289 0.24510603
 0.51277979 0.18380144 0.1670414  0.15279668 0.10542652 0.20540888
 0.20522403 0.19407258 0.17829575 0.11919297 0.2119326  0.24213714
 0.23753528 0.18780443]
60000
<PIL.Image.Image image mode=RGB size=32x32 at 0x1EDC23FA3A0>


In [12]:
class LabImageFolder(torch.utils.data.Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC),
                transforms.RandomCrop(SIZE),
                transforms.RandomHorizontalFlip(),  
            ])
        elif split == 'val':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC), 
                transforms.RandomCrop(SIZE), 
            ])
            
        self.split = split
        self.size = SIZE
        self.paths = [os.path.join(paths, file) for file in os.listdir(paths) if os.path.isfile(
            os.path.join(paths, file))]
        
        
    def __getitem__(self, index):
        itemfilename = os.path.basename(self.paths[index])
        img = images_rgb[itemfilename]
        img_original = self.transforms(img)
        img_original = np.asarray(img_original)
        img_lab = rgb2lab(img_original)
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        img_gray = rgb2gray(img_original)
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        embd = embeddings[itemfilename]
        embd = torch.from_numpy(embd).unsqueeze(0).float()
        return img_gray, img_ab, embd
    
    def __len__(self):
        return len(self.paths)

In [13]:
# Training
train_imagefolder = LabImageFolder(dataset + 'train')
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=batch_size, shuffle=True)
# Validation 
val_imagefolder = LabImageFolder(dataset + 'val' , 'val')
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=batch_size, shuffle=False)

In [14]:
class FusionLAyer(nn.Module):
    def __init__(self):
        super().__init__()
        pass

    def forward(self, imgs, embds):
        embds = torch.reshape(embds, (embds.shape[0], embds.shape[2]))
        reshaped_shape = (imgs.shape[0], embds.shape[1], imgs.shape[2], imgs.shape[3])
        #shapetosave = (embds.shape[0], imgs.shape[2] * imgs.shape[3], embds.shape[1])
        embds = embds.repeat(1,imgs.shape[2] * imgs.shape[3]) #shape (128, 64000)
        #embds = torch.reshape(embds, shapetosave)
        #embds = torch.reshape(embds, reshaped_shape)
        embds = torch.reshape(embds, reshaped_shape)
        cos = torch.cat((imgs, embds), dim=1)
        return cos

In [27]:
kernel_size=3
stride_en=2
stride_de=1
padding=1
scale_factor=2
padding_mode='zeros'
channels_base = 64
p1 = .5

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.conv1 = nn.Conv2d(1, channels_base, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        self.conv2 = nn.Conv2d(channels_base, channels_base * 2, kernel_size=kernel_size, stride=stride_en, padding=padding, padding_mode=padding_mode)
        
        self.convtrans1 = nn.ConvTranspose2d(channels_base * 2, channels_base, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans2 = nn.ConvTranspose2d(channels_base, channels_base // 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)
        self.convtrans3 = nn.ConvTranspose2d(channels_base // 2, 2, kernel_size=kernel_size, stride=stride_de, padding=padding, padding_mode=padding_mode)

        self.batchnorm1 = nn.BatchNorm2d(channels_base // 2)
        self.batchnorm2 = nn.BatchNorm2d(channels_base)
        self.batchnorm3 = nn.BatchNorm2d(channels_base * 2)
        
        self.dropout1 = nn.Dropout(p=p1)
        
        self.fusion = FusionLAyer()
        self.after_fusion = nn.Sequential(
            nn.Conv2d(1128,128, kernel_size=1),
            nn.ReLU()
        )
        
    def forward(self, input, embds):
        # encoder
        x = y = F.relu(self.batchnorm2(self.conv1(input)))
        x = self.dropout1(x)
        x = F.relu(self.batchnorm3(self.conv2(x)))
        x = self.dropout1(x)

        # fusion
        x = self.fusion(x, embds)
        x = self.after_fusion(x)
        
        # decoder
        x = F.relu(self.batchnorm2(self.convtrans1(x)))
        x = self.dropout1(x)
        x = F.interpolate(x, scale_factor=scale_factor)
        x = F.relu(self.batchnorm1(self.convtrans2(x + y)))
        x = self.dropout1(x)
        x = self.convtrans3(F.interpolate(x, scale_factor=scale_factor) + input)

        return x

In [28]:
model = Autoencoder()

In [29]:
criterion = [MeanSquaredError(), PeakSignalNoiseRatio(data_range=1.0), StructuralSimilarityIndexMeasure(data_range=1.0)]

In [30]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [31]:
# # Move model and loss function to GPU
if use_gpu: 
    criterion = [criterion[0].to("cuda"), criterion[1].to("cuda"), criterion[2].to("cuda")]
    model = model.cuda()

In [32]:
if use_gpu: 
    from torchsummary import summary
    summary(model, [(1, SIZE, SIZE), (1, 1000)])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]             640
       BatchNorm2d-2           [-1, 64, 16, 16]             128
           Dropout-3           [-1, 64, 16, 16]               0
            Conv2d-4            [-1, 128, 8, 8]          73,856
       BatchNorm2d-5            [-1, 128, 8, 8]             256
           Dropout-6            [-1, 128, 8, 8]               0
       FusionLAyer-7           [-1, 1128, 8, 8]               0
            Conv2d-8            [-1, 128, 8, 8]         144,512
              ReLU-9            [-1, 128, 8, 8]               0
  ConvTranspose2d-10             [-1, 64, 8, 8]          73,792
      BatchNorm2d-11             [-1, 64, 8, 8]             128
          Dropout-12             [-1, 64, 8, 8]               0
  ConvTranspose2d-13           [-1, 32, 16, 16]          18,464
      BatchNorm2d-14           [-1, 32,

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.

In [33]:
class AverageMeter(object):
    '''A handy class from the PyTorch ImageNet tutorial''' 
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
    '''Show/save rgb image from grayscale and ab channels
       Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib 
    color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
    color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None: 
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

In [34]:
def validate(val_loader, model, criterion, save_images, epoch):
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]

    model.eval()
    already_saved_images = False
    for gray, ab, embd in val_loader:
        if use_gpu: 
            gray, ab, embd = gray.cuda(), ab.cuda(), embd.cuda()

        # Run model and record loss
        output_ab = model(gray, embd) # throw away class predictions
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))

        # Save images to file
        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)): # save at most 5 images
                save_path = {'grayscale': gray_imgs, 'colorized': color_imgs}
                save_name = 'img-{}-epoch-{}.jpg'.format(j, epoch)
                to_rgb(gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)

    print(f'Validate: MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print('Finished validation.')
    if epoch >= 0:
        writer.add_scalar("MSE/test", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/test", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/test", _loss[2].avg, epoch)
    return _loss[0].avg, _loss[1].avg, _loss[2].avg

In [35]:
def train(train_loader, model, criterion, optimizer, epoch):
    print(f'Starting training epoch {epoch}')
    _loss = [AverageMeter(), AverageMeter(), AverageMeter()]
    
    model.train()

    for gray, ab, embd in train_loader:
        if use_gpu: 
            gray, ab, embd = gray.cuda(), ab.cuda(), embd.cuda()
            
        optimizer.zero_grad()

        output_ab = model(gray, embd) 
        loss = [criterion[0](output_ab, ab), criterion[1](output_ab, ab), criterion[2](output_ab, ab)]
        
        loss[0].backward()
        optimizer.step()
        
        _loss[0].update(loss[0].item(), gray.size(0))
        _loss[1].update(loss[1].item(), gray.size(0))
        _loss[2].update(loss[2].item(), gray.size(0))
        
    print(f'Epoch: {epoch}, MSE {_loss[0].val:.8f} ({_loss[0].avg:.8f}), PSNR {_loss[1].val:.8f} ({_loss[1].avg:.8f}), SSIM {_loss[2].val:.8f} ({_loss[2].avg:.8f})')

    print(f'Finished training epoch {epoch}')
    if epoch >= 0:
        writer.add_scalar("MSE/train", _loss[0].avg, epoch)
        writer.add_scalar("PSNR/train", _loss[1].avg, epoch)
        writer.add_scalar("SSIM/train", _loss[2].avg, epoch)

In [36]:
# Train model
for epoch in range(epochs):
    # Train for one epoch, then validate
    train(train_loader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(val_loader, model, criterion, save_images, epoch)
    # Save checkpoint and replace old best model if current model is better
    if losses[0] < best_losses[0]:
        best_losses[0] = losses[0]
        best_epoch = epoch
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}.pth')
    if losses[1] < best_losses[1]:
        best_losses[1] = losses[1]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-PSNRLoss-{losses[1]:.8f}.pth')
    if losses[2] < best_losses[2]:
        best_losses[2] = losses[2]
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-SSIMLoss-{losses[2]:.8f}.pth')
    
    if epoch - best_epoch >= patience:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-MSELoss-{losses[0]:.8f}-early_stop.pth')
        break
    
    if epoch == epochs - 1:
        torch.save(model.state_dict(), f'{checkpoints}/epoch-{epoch}-last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')


Starting training epoch 0
Epoch: 0, MSE 0.00533428 (0.13532598), PSNR 22.72924232 (15.87930546), SSIM 0.42230463 (0.15027729)
Finished training epoch 0
Validate: MSE 0.00990952 (0.00776068), PSNR 20.03947449 (21.12791138), SSIM 0.41128796 (0.50265245)
Finished validation.
Starting training epoch 1
Epoch: 1, MSE 0.00309090 (0.00386322), PSNR 25.09914970 (24.22983105), SSIM 0.70676881 (0.56899927)
Finished training epoch 1
Validate: MSE 0.00407268 (0.00379682), PSNR 23.90119743 (24.33973156), SSIM 0.65067172 (0.73791778)
Finished validation.
Starting training epoch 2
Epoch: 2, MSE 0.00266942 (0.00322813), PSNR 25.73582840 (24.97675146), SSIM 0.75234491 (0.73039166)
Finished training epoch 2
Validate: MSE 0.00348662 (0.00294432), PSNR 24.57595634 (25.39234597), SSIM 0.69321811 (0.76386575)
Finished validation.
Starting training epoch 3
Epoch: 3, MSE 0.00235120 (0.00351089), PSNR 26.28711128 (24.74959963), SSIM 0.77855730 (0.74843146)
Finished training epoch 3
Validate: MSE 0.00415976 (0.0

  return func(*args, **kwargs)


Validate: MSE 0.00666143 (0.00754976), PSNR 21.76432228 (21.23579765), SSIM 0.65138167 (0.73196089)
Finished validation.
Starting training epoch 21
Epoch: 21, MSE 0.00500639 (0.00404859), PSNR 23.00474739 (24.35840167), SSIM 0.66798592 (0.70670583)
Finished training epoch 21
Validate: MSE 0.00482126 (0.00477999), PSNR 23.16839218 (23.23561401), SSIM 0.64175117 (0.70832571)
Finished validation.
Starting training epoch 22
Epoch: 22, MSE 0.00422284 (0.00407971), PSNR 23.74395370 (24.25672047), SSIM 0.68016785 (0.70233732)
Finished training epoch 22
Validate: MSE 0.00378978 (0.00369205), PSNR 24.21385956 (24.35422421), SSIM 0.65513903 (0.73760451)
Finished validation.
Starting training epoch 23
Epoch: 23, MSE 0.00619553 (0.00411882), PSNR 22.07921791 (24.38570777), SSIM 0.68703490 (0.70480346)
Finished training epoch 23
Validate: MSE 0.00352331 (0.00321092), PSNR 24.53048897 (24.97834566), SSIM 0.66895622 (0.73271902)
Finished validation.
Starting training epoch 24
Epoch: 24, MSE 0.0021221

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

In [37]:
torch.save(model.state_dict(), f'{checkpoints}/last-{losses[0]:.8f}-{losses[1]:.8f}-{losses[2]:.8f}.pth')

In [38]:
# Validate
save_images = True
with torch.no_grad():
    validate(val_loader, model, criterion, save_images, -1)

Validate: MSE 0.00640837 (0.00620033), PSNR 21.93252563 (22.12080662), SSIM 0.65652573 (0.70865668)
Finished validation.


<Figure size 640x480 with 0 Axes>

In [22]:
# # Show images 
# image_pairs = []

# for i in range(10):
#     image_pairs.append((f'{color_imgs}img-{i}-epoch-{best_epoch}.jpg', f'{gray_imgs}img-{i}-epoch-{best_epoch}.jpg'))
    
# for c, g in image_pairs:
#   color = mpimg.imread(c)
#   gray  = mpimg.imread(g)
#   f, axarr = plt.subplots(1, 2)
#   f.set_size_inches(15, 15)
#   axarr[0].imshow(gray, cmap='gray')
#   axarr[1].imshow(color)
#   axarr[0].axis('off'), axarr[1].axis('off')
#   plt.show()