#  Project group 30 
## Image colorization by combining semantic segmentation and autoencoding 

Import dependencies (install kornia first with `pip install kornia` if it is missing)

In [None]:
import torch
import torch.utils.data as data
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms
from torchvision.transforms import ToTensor
from torch import nn, optim
from torchvision.transforms import Compose, RandomGrayscale, ColorJitter, RandomHorizontalFlip, Resize, Normalize

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors

from PIL import Image, ImageOps 

#!pip install kornia
from matplotlib.colors import LinearSegmentedColormap
import kornia

import glob
import os

Path 

In [None]:
# Dataset split sub-folders, consumed by DataLoaderAutoEncoder together with
# "<cwd>/Dataset". Raw strings keep the leading backslash but avoid the
# invalid "\T" / "\V" escape sequences (DeprecationWarning, and
# SyntaxWarning on Python 3.12+).
train_path = r"\Train"

val_path = r"\Validation"

test_path = r"\Test"

Training loop for the autoencoder

In [None]:
def training_loop_autoencoder(model, optimizer, loss_fn, train_loader, val_loader, num_epochs, print_every,
                              device=None):
    """Train the colorization autoencoder for ``num_epochs`` epochs.

    After every epoch the model is validated. The mean train loss and the
    validation loss are printed and appended to ``loss_autoencoder.txt``.
    The weights are saved to ``autoencoder_model.pth``.

    Args:
        model: network to train (moved to ``device``).
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss comparing model outputs with the batch targets.
        train_loader: loader of ``(x, y)`` training batches.
        val_loader: loader of ``(x, y)`` validation batches.
        num_epochs: number of passes over ``train_loader``.
        print_every: forwarded to ``train_epoch_autoencoder``; number of
            batches between intermediate validations (``None`` disables them).
        device: device to train on. Defaults to the CPU, which this function
            always used before. NOTE: ``validate_autoencoder`` plots tensors
            with matplotlib, so check it handles non-CPU tensors before
            passing a CUDA device.

    Returns:
        ``(model, train_losses, val_losses)``. ``train_losses`` holds every
        per-batch training loss; ``val_losses`` holds one value per epoch.
    """
    print("Starting training")
    device = torch.device("cpu") if device is None else torch.device(device)
    model.to(device)
    train_losses, val_losses = [], []

    for epoch in range(1, num_epochs + 1):
        model, train_loss = train_epoch_autoencoder(model,
                                                    optimizer,
                                                    loss_fn,
                                                    train_loader,
                                                    val_loader,
                                                    device,
                                                    print_every)
        val_loss = validate_autoencoder(model, loss_fn, val_loader, device)

        # An empty train loader yields no batch losses; report NaN instead of
        # dividing by zero.
        mean_train_loss = sum(train_loss) / len(train_loss) if train_loss else float("nan")

        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train loss: {mean_train_loss:.3f}, "
              f"Val. loss: {val_loss:.3f}, ")
        train_losses.extend(train_loss)
        val_losses.append(val_loss)

        # Log the mean train loss of the whole epoch next to the validation loss.
        with open('loss_autoencoder.txt', "a") as file:
            file.write(f"Train loss epoch: {mean_train_loss:.3f}, Val. loss epoch: {val_loss:.3f}\n")

        # Checkpoint after each epoch so training can be resumed.
        torch.save(model.state_dict(), 'autoencoder_model.pth')

    return model, train_losses, val_losses

def train_epoch_autoencoder(model, optimizer, loss_fn, train_loader, val_loader, device, print_every):
    """Run one training epoch over ``train_loader``.

    Each batch ``(x, y)`` is moved to ``device``. ``x`` gets an extra leading
    axis (``build_unet_autoencoder.forward`` permutes it back into a channel
    axis). ``y`` is cast to float and used as the regression target.

    Args:
        model: network being trained (updated in place).
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss comparing model outputs with ``y``.
        train_loader: loader of ``(x, y)`` training batches.
        val_loader: loader used for the intermediate validations.
        device: device the batches are moved to.
        print_every: if a positive int, every ``print_every`` batches the
            model is validated on the full ``val_loader``. The running train
            loss and the validation loss are then printed and appended to
            ``loss_autoencoder.txt``. ``None`` (or 0) disables this.

    Returns:
        ``(model, train_loss_batches)``: the model and the list of per-batch
        training losses.
    """
    model.train()
    train_loss_batches = []
    num_batches = len(train_loader)
    for batch_index, (x, y) in enumerate(train_loader, 1):
        inputs = x.to(device)[None, :]
        labels = y.to(device).float()

        optimizer.zero_grad()
        # No squeeze: the network's output is already batch-first like
        # ``labels``, and squeezing would drop the batch axis whenever a batch
        # holds a single sample.
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss_batches.append(loss.item())

        print(f"\tBatch {batch_index}/{num_batches}")
        # Each intermediate validation runs over the whole validation set, so a
        # small print_every slows training down considerably.
        if print_every and batch_index % print_every == 0:
            val_loss = validate_autoencoder(model, loss_fn, val_loader, device)
            model.train()  # validate_autoencoder switches the model to eval mode
            window_loss = sum(train_loss_batches[-print_every:]) / print_every
            print(f"\tTrain loss: {window_loss:.3f}, "
                  f"\tVal. loss: {val_loss:.3f}, ")

            with open('loss_autoencoder.txt', "a") as file:
                file.write(f"Train loss batch: {window_loss:.3f}, Val. loss batch: {val_loss:.3f}\n")

    return model, train_loss_batches

def _show_lab_prediction(lightness, a_plane, b_plane):
    """Plot L, a*-only, b*-only and the full Lab->RGB image side by side.

    All three arguments are 2-D tensors of the same shape on the same device.
    The single-channel views use a mid lightness of 50 with the other
    chrominance plane set to zero.
    """
    mid_l = 50 * torch.ones_like(lightness)
    zeros = torch.zeros_like(lightness)

    a_vis = kornia.color.lab_to_rgb(torch.stack((mid_l, a_plane, zeros), 0)).permute(1, 2, 0)
    b_vis = kornia.color.lab_to_rgb(torch.stack((mid_l, zeros, b_plane), 0)).permute(1, 2, 0)
    full_vis = kornia.color.lab_to_rgb(torch.stack((lightness, a_plane, b_plane), 0)).permute(1, 2, 0)

    _, axarr = plt.subplots(1, 4)
    # matplotlib needs host memory; .cpu() is a no-op for CPU tensors.
    axarr[0].imshow(lightness.cpu(), 'gray')
    axarr[1].imshow(a_vis.cpu())
    axarr[2].imshow(b_vis.cpu())
    axarr[3].imshow(full_vis.cpu())
    plt.show()


def validate_autoencoder(model, loss_fn, val_loader, device):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is put in eval mode and evaluated without gradients. For the
    first batch, the first sample's input and predicted channels are plotted
    and the model weights are saved to ``autoencoder_model.pth``.

    Args:
        model: network to evaluate.
        loss_fn: loss comparing model outputs with the batch targets.
        val_loader: loader of ``(x, y)`` validation batches.
        device: device the batches are moved to.

    Returns:
        Sum of batch losses divided by the number of batches, or ``nan``
        when ``val_loader`` has no batches.
    """
    num_batches = len(val_loader)
    val_loss_cum = 0
    model.eval()
    with torch.no_grad():
        for batch_index, (x, y) in enumerate(val_loader, 1):
            inputs = x.to(device)[None, :]
            labels = y.to(device).float()

            # No squeeze: indexing outputs[0, ...] below must also work for a
            # batch holding a single sample.
            outputs = model(inputs)
            val_loss_cum += loss_fn(outputs, labels).item()

            if batch_index == 1:
                # inputs carries a leading singleton axis, so inputs[0, 0] is the
                # first sample; outputs[0, 0] / outputs[0, 1] are its two
                # predicted channels.
                _show_lab_prediction(inputs[0, 0], outputs[0, 0], outputs[0, 1])
                # TODO: originally marked "REMOVE LATER" - checkpoint on every plot;
                # training_loop_autoencoder already saves after each epoch.
                torch.save(model.state_dict(), 'autoencoder_model.pth')

    if num_batches == 0:
        return float("nan")
    return val_loss_cum / num_batches


Dataset construction for the autoencoder

In [None]:
class DataLoaderAutoEncoder(data.Dataset):
    """Dataset of ``(L, ab)`` CIE Lab pairs for image colorization.

    Despite its name this is a ``Dataset``; wrap it in a ``DataLoader`` for
    batching. Images are read from ``<root>/<split>/Images/*.png``, where
    ``root`` defaults to ``<cwd>/Dataset``. Files under ``<split>/Labels`` are
    listed in ``mask_files`` but are not returned by ``__getitem__``.
    """

    def __init__(self, folder_path, transform=None, root=None):
        """
        Args:
            folder_path: split sub-folder name, e.g. the module-level
                ``train_path``. Leading backslash or slash separators are
                stripped, so the Windows-style names also work on POSIX.
            transform: optional callable applied to the RGB PIL image. Its
                result must convert via ``np.array`` to an (H, W, 3) array
                (e.g. ``Resize``). ``None`` means no transform.
            root: dataset root directory; defaults to ``<cwd>/Dataset``.
        """
        super().__init__()
        self.transform = transform

        base = os.path.join(os.getcwd(), "Dataset") if root is None else root
        split_dir = os.path.join(base, folder_path.strip("\\/"))
        # Sorted: glob order is filesystem-dependent, and a fixed order keeps
        # indices (and image/label pairing by name) reproducible.
        self.img_files = sorted(glob.glob(os.path.join(split_dir, "Images", "*.png")))
        self.mask_files = sorted(glob.glob(os.path.join(split_dir, "Labels", "*.png")))

    def __getitem__(self, index):
        """Return ``(L, ab)`` for image ``index``.

        ``L`` is the lightness plane (H, W). ``ab`` stacks the two chrominance
        planes (2, H, W). Both come from ``kornia.color.rgb_to_lab`` applied to
        RGB values scaled to [0, 1].
        """
        # Close the file handle, and force 3 channels: palette, grayscale or
        # RGBA PNGs would otherwise not give an (H, W, 3) array.
        with Image.open(self.img_files[index]) as img:
            img_rgb = img.convert("RGB")

        if self.transform is not None:
            img_rgb = self.transform(img_rgb)

        rgb = torch.from_numpy(np.array(img_rgb)).float() / 255  # (H, W, 3)
        lab = kornia.color.rgb_to_lab(rgb.permute(2, 0, 1))
        return lab[0, :, :], lab[1:, :, :]

    def __len__(self):
        return len(self.img_files)

Autoencoder model

In [None]:
class conv_block(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU.

    Padding 1 keeps H and W unchanged; the channel count goes from ``in_c``
    to ``out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names form the state_dict keys of saved checkpoints, so
        # they (and their creation order) must stay as they are.
        same_size = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_c, out_c, **same_size)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, **same_size)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        out = inputs
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        return out

    
class encoder_block(nn.Module):
    """U-Net encoder stage: a ``conv_block`` followed by 2x2 max pooling.

    ``forward`` returns ``(features, pooled)``. ``features`` is the
    pre-pooling output, used as the skip connection by the matching decoder
    stage. ``pooled`` is the half-resolution input of the next stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        features = self.conv(inputs)
        return features, self.pool(features)

class decoder_block(nn.Module):
    """U-Net decoder stage: 2x upsampling, skip concatenation, ``conv_block``.

    The transposed convolution doubles H and W and maps ``in_c`` to ``out_c``
    channels. The ``skip`` tensor must have ``out_c`` channels and the
    upsampled spatial size. After channel-wise concatenation the
    ``conv_block`` reduces ``2 * out_c`` channels back to ``out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        upsampled = self.up(inputs)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)
        
class build_unet_autoencoder(nn.Module):
    """U-Net mapping a 1-channel lightness input to 2 output channels.

    Architecture:
    - Four encoder stages (1 -> 64 -> 128 -> 256 -> 512 channels), each
      halving H and W.
    - A 1024-channel bottleneck.
    - Four decoder stages that consume the encoder skips in reverse order.
    - A 1x1 convolution producing 2 channels.

    ``forward`` swaps the first two input axes before the first convolution.
    The training code feeds (1, B, H, W), i.e. a (B, H, W) batch with a
    leading axis added, which becomes (B, 1, H, W). The output is
    (B, 2, H, W). H and W must be divisible by 16 for the skip
    concatenations to line up after four poolings.
    """

    def __init__(self):
        super().__init__()
        # Encoder (attribute names are state_dict keys of saved checkpoints)
        self.e1 = encoder_block(1, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)
        # Bottleneck
        self.b = conv_block(512, 1024)
        # Decoder
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)
        # 1x1 output head: 64 features -> 2 channels
        self.outputs = nn.Conv2d(64, 2, kernel_size=1, padding=0)

    def forward(self, inputs):
        x = inputs.permute(1, 0, 2, 3)

        skips = []
        for stage in (self.e1, self.e2, self.e3, self.e4):
            skip, x = stage(x)
            skips.append(skip)

        x = self.b(x)

        # Deepest decoder stage pairs with the deepest encoder skip.
        for stage, skip in zip((self.d1, self.d2, self.d3, self.d4), reversed(skips)):
            x = stage(x, skip)

        return self.outputs(x)

DataLoaders for the autoencoder

In [None]:
# Side length (pixels) every image is resized to; it must be divisible by 16
# for the U-Net's four pooling/upsampling stages.
d = 256
train_transforms = Compose([Resize((d, d), antialias=True)])

train_set_autoencoder = DataLoaderAutoEncoder(train_path, train_transforms)
# The transform is only a resize (no augmentation), so validation reuses it.
val_set_autoencoder = DataLoaderAutoEncoder(val_path, train_transforms)

batch_size = 16
train_dataloader_autoencoder = DataLoader(train_set_autoencoder, batch_size=batch_size, shuffle=True)
# Validation is not shuffled so the reported validation loss and the sample
# plotted by validate_autoencoder are reproducible between evaluations.
val_dataloader_autoencoder = DataLoader(val_set_autoencoder, batch_size=batch_size, shuffle=False)

Visual check of one training sample (L input, a*/b* targets and their recombination)

In [None]:
example_transform = Compose([Resize((d, d), antialias=True)])
example_dataset = DataLoaderAutoEncoder(train_path, example_transform)

# One training sample: img is the L plane, label stacks the a*/b* planes.
img, label = example_dataset[30]
L = img
A, B = label[0, :, :], label[1, :, :]

# Each chrominance plane is shown alone at mid lightness (L = 50) with the
# other plane zeroed, followed by the full L/a*/b* reconstruction.
Lfill = 50 * torch.ones(L.shape)
Zfill = torch.zeros(L.shape)

A_vis, B_vis, full_vis = (
    kornia.color.lab_to_rgb(torch.stack(planes, 0)).permute(1, 2, 0)
    for planes in ((Lfill, A, Zfill), (Lfill, Zfill, B), (L, A, B))
)

for fig_id, (picture, cmap) in enumerate(((img, 'grey'), (A_vis, None), (B_vis, None), (full_vis, None))):
    plt.figure(fig_id)
    plt.imshow(picture, cmap)

Training of the autoencoder

In [None]:
# Set to False to start a new model and a fresh loss log, or True to continue
# training from the last checkpoint saved to disk.
train_from_disk_autoencoder = False
file_name = 'loss_autoencoder.txt'

if train_from_disk_autoencoder:
    # Read model from disk. map_location="cpu" lets a checkpoint saved on a
    # GPU load on a CPU-only machine; the file is read only once.
    print("Reading model from disk")
    pretrained = torch.load('autoencoder_model.pth', map_location="cpu")
    autoencoder_model = build_unet_autoencoder()
    autoencoder_model.load_state_dict(pretrained)
    # NOTE: only model weights are checkpointed; the Adam state below starts fresh.
    print("Continue writing to loss textfile")
    with open(file_name, "a") as file:
        file.write('Continue:\n')
else:
    # Create new model
    print("Creating new model")
    autoencoder_model = build_unet_autoencoder()
    print("Starting over writing to loss textfile")
    with open(file_name, "w") as file:
        file.write('Losses:\n')

loss_fn = nn.MSELoss()
optimizer = optim.Adam(autoencoder_model.parameters(), lr=0.001)
train_loader = train_dataloader_autoencoder
val_loader = val_dataloader_autoencoder
num_epochs = 10
# With print_every = 1, train_epoch_autoencoder runs a full validation pass
# after every training batch, which is slow.
print_every = 1

autoencoder_model, autoencoder_train_losses, autoencoder_val_losses = training_loop_autoencoder(
    autoencoder_model, optimizer, loss_fn, train_loader, val_loader, num_epochs, print_every)