# Color grey-scale images using UNET 🎨
A few images in the dataset are greyscale. This adds complexity to the task, so we may want to colorize them to make it easier.

In this notebook, I'll try to fit a UNet to colorize these images.

**This notebook uses Pytorch Lightning ⚡**<br>

In [None]:
import numpy as np
import pandas as pd
import os
import cv2
import albumentations
import warnings
import matplotlib.pyplot as plt
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from tqdm import tqdm
from random import random
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import Dataset
from torchvision import transforms
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
warnings.filterwarnings("ignore")

# Configuration ⚙️


In [None]:
class CFG:
    """Notebook-wide constants for the dataset, augmentations and training."""

    # Not referenced by the code visible in this notebook.
    SEED = 69

    # --- Dataset / loading ---
    BATCH_SIZE = 32
    IMAGE_SIZE = 224  # previously 380; not referenced by the visible code
    NUM_WORKERS = 2

    # --- Augmentations applied to the greyscale input ---
    NOISE_STD = 0.15  # scale of the additive Gaussian noise
    BLUR_SIGMA = (0.1, 2.0)  # (low, high) range the blur sigma is drawn from
    BLUR_KERNEL_SIZE = 5

    # --- Training ---
    EPOCHS = 4
    LR = 1e-3
    SCHEDULER_PATIENCE = 600  # ReduceLROnPlateau patience
    SCHEDULER_FACTOR = 0.1  # ReduceLROnPlateau multiplicative factor
    MODEL_PATH = "model.ckpt"

# Logging 📄
Logging using Weights and Biases 🪄🐝

In [None]:
# Kaggle-specific client used to read secrets attached to the notebook.
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()

# The W&B API token is stored as a Kaggle secret under the label "wandb_api".
# If your secret uses a different label, change the string passed below.
wandb_api = user_secrets.get_secret("wandb_api")

# Authenticate this session with Weights & Biases using the retrieved token.
wandb.login(key=wandb_api)

# Dataset 🖼️
Implementing the dataset as a Pytorch Dataset as required by Pytorch Lightning. It applies some augmentations:
* Noise: We add some Gaussian noise to the greyscale input so that the model also learns to denoise old images
* Gaussian blur: Can help when the quality was degraded with time

In [None]:
BASE_PATH = "../input/happywhale-enhanced-dataset-light"


class WandDID(Dataset):
    """Image dataset producing (colour, greyscale) pairs for recolorization.

    Each image is read from ``BASE_PATH/<folder>``, normalized with
    ``albumentations.Normalize`` and converted to a tensor. The greyscale
    version is the mean over the channels of that normalized tensor, kept as
    a single-channel image.

    Args:
        data: DataFrame with an ``image`` column of file names. If it has no
            ``inference_image`` column, one is added to ``data`` itself (as a
            copy of ``image``); other cells of this notebook read that column
            back from the frame they passed in.
        return_colored: When True, items are ``(colored, greyscale)`` pairs
            and the greyscale image is randomly blurred and/or noised. When
            False, items are the clean greyscale image only.
        folder: Sub-folder of ``BASE_PATH`` that holds the image files.
    """

    def __init__(self, data, return_colored=True, folder="train_images"):
        self.return_colored = return_colored
        self.base_path = os.path.join(BASE_PATH, folder)
        # Intentionally mutates the caller's frame: the inference and plotting
        # cells rely on "inference_image" being present afterwards.
        if "inference_image" not in data.columns:
            data["inference_image"] = data["image"]
        self.data = data
        # Augmentations
        transformations = albumentations.Compose([
            albumentations.Normalize(),
            ToTensorV2(p=1.0)
        ])
        self.transforms = transforms.Compose([
            transforms.Lambda(self._as_image_transform(transformations)),
        ])

    @staticmethod
    def _as_image_transform(transform=None):
        """Wrap an albumentations-style callable as ``image -> {"image": ...}``.

        With a falsy ``transform`` the wrapper passes the image through
        unchanged, still wrapped in a dict, instead of failing on an unbound
        local variable as the previous inline helper did.
        """
        def apply(image):
            # Copy into a fresh array: the input is a channel-reversed view.
            image_np = np.array(image)
            if not transform:
                return {"image": image_np}
            return transform(image=image_np)
        return apply

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it with the channel axis reversed.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file
                (``cv2.imread`` returns ``None`` in that case).
        """
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError("Could not read image: {}".format(path))
        # OpenCV loads BGR; reverse the last axis to get RGB.
        return image[:, :, ::-1]

    def __getitem__(self, idx):
        """Return the greyscale image, or ``(colored, greyscale)`` for training."""
        colored_image = self.preprocess(self.data["inference_image"].iloc[idx])
        # Channel mean of the normalized image, kept as a 1-channel tensor.
        greyscale_image = colored_image.mean(0).unsqueeze(0)

        if not self.return_colored:
            return greyscale_image

        # Each degradation is applied independently with probability 0.5.
        if random() < 0.5:
            sigma = torch.empty(1).uniform_(CFG.BLUR_SIGMA[0], CFG.BLUR_SIGMA[1]).item()
            greyscale_image = transforms.functional.gaussian_blur(
                greyscale_image, CFG.BLUR_KERNEL_SIZE, [sigma, sigma]
            )

        if random() < 0.5:
            greyscale_image += torch.randn_like(greyscale_image) * CFG.NOISE_STD

        return colored_image, greyscale_image

    def preprocess(self, image):
        """Load the file named ``image`` and apply the dataset transforms."""
        image = self._read_rgb(os.path.join(self.base_path, image))
        if self.transforms is not None:
            image = self.transforms(image)["image"]
        return image

    def plot_sample(self, idx):
        """Display the raw image at ``idx`` titled with its individual id and species."""
        image = self._read_rgb(
            os.path.join(self.base_path, self.data["image"].iloc[idx])
        )
        plt.title("{} ({})".format(
            self.data["individual_id"].iloc[idx],
            self.data["species"].iloc[idx]
        ))
        plt.imshow(image)
        plt.show()

    def __len__(self):
        """Return the number of rows in the underlying data."""
        return len(self.data)

In [None]:
def _split_by_grey_mask(frame, grey_mask):
    """Return ``(grey_rows, colour_rows)`` of ``frame`` selected by ``grey_mask``.

    ``grey_mask`` is assumed to be a boolean array aligned with the rows of
    ``frame`` -- confirm against the mask files.
    """
    return frame.loc[grey_mask], frame.loc[~grey_mask]


# Training metadata, split into greyscale rows and the remaining rows.
train_data = pd.read_csv(os.path.join(BASE_PATH, "train.csv"))
train_greyscale_map = np.load("../input/greyscale-images/train_grey_scale_mask.npy")
train_grey_data, train_data = _split_by_grey_mask(train_data, train_greyscale_map)

# Same split for the test metadata.
test_data = pd.read_csv(os.path.join(BASE_PATH, "sample_submission.csv"))
test_greyscale_map = np.load("../input/greyscale-images/test_grey_scale_mask.npy")
test_grey_data, test_data = _split_by_grey_mask(test_data, test_greyscale_map)

# The model is trained on the non-greyscale training rows only.
train_dataset = WandDID(train_data)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    pin_memory=True,
    batch_size=CFG.BATCH_SIZE,
    num_workers=CFG.NUM_WORKERS,
)

# Model 🤖
Implements the model as a Pytorch lightning module. We use a classical U-Net architecture
![](https://datascientest.com/wp-content/uploads/2021/05/u-net-architecture-1024x682.png)

In [None]:
"""
    https://github.com/milesial/Pytorch-UNet/tree/master/unet
"""

class DoubleConv(nn.Module):
    """Two 3x3 convolutions in a row, each followed by batch norm and ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when not given.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        def conv_bn_relu(n_in, n_out):
            # padding=1 with a 3x3 kernel keeps the spatial size; the
            # convolution has no bias since batch norm follows it.
            return [
                nn.Conv2d(n_in, n_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
            ]

        self.double_conv = nn.Sequential(
            *conv_bn_relu(in_channels, hidden),
            *conv_bn_relu(hidden, out_channels),
        )

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and run the convolutions on the result."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample, join with the skip connection, then ``DoubleConv``.

    Args:
        in_channels: Channels after concatenating skip and upsampled tensors.
        out_channels: Channels produced by the stage.
        bilinear: Use bilinear ``nn.Upsample`` when True, otherwise a learned
            ``nn.ConvTranspose2d`` that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Upsampling keeps the channel count, so the channel reduction is
            # done by the convolutions instead (mid_channels = half).
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse the two.

        Both inputs are indexed as (batch, channel, height, width).
        """
        upsampled = self.up(x1)

        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        # For padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to ``out_channels`` channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to the output channels."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net with four encoder and four decoder stages joined by skip connections.

    Args:
        n_channels: Channels of the input image.
        n_classes: Channels of the output.
        bilinear: Passed to every ``Up`` stage; also halves the width of the
            bottleneck and of the first three decoder outputs.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Feature widths of the five resolution levels.
        c1, c2, c3, c4, c5 = 64, 128, 256, 512, 1024
        shrink = 2 if bilinear else 1

        self.inc = DoubleConv(n_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        self.down4 = Down(c4, c5 // shrink)
        self.up1 = Up(c5, c4 // shrink, bilinear)
        self.up2 = Up(c4, c3 // shrink, bilinear)
        self.up3 = Up(c3, c2 // shrink, bilinear)
        self.up4 = Up(c2, c1, bilinear)
        self.outc = OutConv(c1, n_classes)

    def forward(self, x):
        """Run the encoder, then the decoder with skips, and return the raw output."""
        # Encoder: keep every intermediate feature map for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Decoder: start from the deepest map and consume skips deepest-first.
        x = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        logits = self.outc(x)
        return logits

In [None]:
class WandDRecolorize(LightningModule):
    """Lightning module that colorizes single-channel images with a U-Net."""

    def __init__(self, learning_rate: float = CFG.LR) -> None:
        """
            learning_rate: Learning rate of the Adam optimizer
        """
        super().__init__()
        self.save_hyperparameters()

        # One input channel (greyscale), three output channels.
        self.net = UNet(n_channels=1, n_classes=3)

        # Both losses are logged; the Huber loss is the one being optimized.
        self.mse = nn.MSELoss()
        self.huber = nn.HuberLoss()
        self.criterion = self.huber

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler driven by ``train/loss``."""
        optimizer = torch.optim.Adam(
            self.net.parameters(), lr=self.hparams.learning_rate
        )
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=CFG.SCHEDULER_PATIENCE,
            factor=CFG.SCHEDULER_FACTOR,
            verbose=True,
        )
        scheduler_config = {
            "scheduler": plateau,
            "monitor": "train/loss",
            # Stepped after every training step, so patience counts steps.
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

    def forward(self, x):
        """Predict an RGB image from the greyscale batch ``x``.

        The three network outputs are treated as (Y, Cb, Cr) and combined into
        R, G and B with fixed linear coefficients.
        """
        y, cb, cr = self.net(x).unbind(dim=1)
        red = y + 1.402 * cr
        green = y - 0.34414 * cb - 0.71414 * cr
        blue = y + 1.772 * cb
        return torch.stack([red, green, blue], dim=1)

    def training_step(self, batch, batch_idx):
        """Compute and log the losses for one ``(colored, greyscale)`` batch."""
        target, grey_input = batch

        prediction = self(grey_input)
        loss = self.criterion(prediction, target)

        self.log("train/loss", loss)
        self.log("train/mse", self.mse(prediction, target))
        self.log("train/huber", self.huber(prediction, target))

        return loss

# Training 🏃
Create a Pytorch lightning trainer with our configuration, and run our model on our dataset:

In [None]:
"""
    Callbacks
"""

class WandbImageCallback(pl.Callback):
    """Periodically log the model's output on a fixed greyscale batch to W&B.

    Args:
        data: DataFrame of images used to build the fixed preview batch.
        display_frequency: Log whenever ``trainer.global_step`` is a multiple
            of this value.
    """

    def __init__(self, data, display_frequency=300):
        super().__init__()
        self.display_frequency = display_frequency
        self.dataset = WandDID(data, return_colored=False)
        self.loader = torch.utils.data.DataLoader(
            self.dataset,
            shuffle=False,
            pin_memory=True,
            batch_size=CFG.BATCH_SIZE,
            num_workers=CFG.NUM_WORKERS,
        )
        # Per-channel statistics shaped (3, 1, 1) to broadcast over (C, H, W).
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        # Only the first batch is kept and reused for every preview.
        self.batch = next(iter(self.loader))

    def unnormalize(self, x):
        """Undo the normalization of one (C, H, W) image and return HWC uint8."""
        restored = x.detach().cpu() * self.std + self.mean
        pixels = (restored * 255).permute(1, 2, 0).clamp(0, 255)
        return pixels.numpy().astype(np.uint8)

    # NOTE(review): `on_batch_end` is a hook of older pytorch_lightning
    # releases -- confirm it is still invoked by the installed version.
    def on_batch_end(self, trainer: pl.Trainer, pl_module: LightningModule):
        """Log recolored previews every ``display_frequency`` global steps."""
        if trainer.global_step % self.display_frequency > 0:
            return

        # Predict in eval mode without gradients, then resume training mode.
        with torch.no_grad():
            pl_module.eval()
            predictions = pl_module(self.batch.to(pl_module.device))
            pl_module.train()

        previews = [
            wandb.Image(self.unnormalize(prediction))
            for prediction in predictions
        ]

        payload = {
            "val/examples": previews,
            "global_step": trainer.global_step
        }
        trainer.logger.experiment.log(payload)

In [None]:
# Model and experiment logger.
model = WandDRecolorize()
wandb_logger = WandbLogger(project="W&D - recolorization")

# Trainer settings, gathered in one place before building the Trainer.
_trainer_kwargs = dict(
    profiler="simple",  # Profiling
    gpus=1,  # Use the one GPU we have
    max_epochs=CFG.EPOCHS,
    logger=wandb_logger,
    log_every_n_steps=10,
    # Preview images are generated from the greyscale training rows.
    callbacks=[WandbImageCallback(train_grey_data)],
)
trainer = Trainer(**_trainer_kwargs)

# Let's go ⚡
trainer.fit(model, train_loader)
trainer.save_checkpoint(CFG.MODEL_PATH)

# Inference 🔮


In [None]:
# Per-channel mean and std used to undo the normalization of the predictions,
# shaped (3, 1, 1) so they broadcast over a (C, H, W) image.
mean=torch.tensor([0.485, 0.456, 0.406]).unsqueeze(-1).unsqueeze(-1)
std=torch.tensor([0.229, 0.224, 0.225]).unsqueeze(-1).unsqueeze(-1)
# Output folders (in the working directory) for the recolored images.
!mkdir train_images
!mkdir test_images

In [None]:
# Recolor every greyscale training image and write the result to ./train_images.
train_grey_dataset = WandDID(train_grey_data, return_colored=False)
train_grey_loader = torch.utils.data.DataLoader(
    train_grey_dataset,
    batch_size=CFG.BATCH_SIZE,
    num_workers=CFG.NUM_WORKERS,
    pin_memory=True,
    shuffle=False
)
# predict() returns one tensor per batch; join them into a single (N, C, H, W) tensor.
preds = trainer.predict(model, dataloaders=train_grey_loader)
preds = torch.cat(preds, dim=0)

# shuffle=False above keeps the predictions in the same order as the rows.
image_names = train_grey_data["inference_image"]
for image_name, image in tqdm(zip(image_names, preds), total=len(preds)):
    # Undo the normalization and convert to an HWC uint8 image.
    image = image.detach().cpu() * std + mean
    image = (image * 255).permute(1, 2, 0).clamp(0, 255).numpy().astype(np.uint8)
    img_path = os.path.join("train_images", image_name)
    # OpenCV expects BGR, hence the channel reversal. imwrite reports failure
    # through its return value, so check it instead of silently losing files.
    if not cv2.imwrite(img_path, image[:, :, ::-1]):
        raise OSError("Could not write image: {}".format(img_path))

In [None]:
# Recolor every greyscale test image and write the result to ./test_images.
test_grey_dataset = WandDID(test_grey_data, return_colored=False, folder="test_images")
test_grey_loader = torch.utils.data.DataLoader(
    test_grey_dataset,
    batch_size=CFG.BATCH_SIZE,
    num_workers=CFG.NUM_WORKERS,
    pin_memory=True,
    shuffle=False
)
# predict() returns one tensor per batch; join them into a single (N, C, H, W) tensor.
preds = trainer.predict(model, dataloaders=test_grey_loader)
preds = torch.cat(preds, dim=0)

# shuffle=False above keeps the predictions in the same order as the rows.
image_names = test_grey_data["inference_image"]
for image_name, image in tqdm(zip(image_names, preds), total=len(preds)):
    # Undo the normalization and convert to an HWC uint8 image.
    image = image.detach().cpu() * std + mean
    image = (image * 255).permute(1, 2, 0).clamp(0, 255).numpy().astype(np.uint8)
    img_path = os.path.join("test_images", image_name)
    # OpenCV expects BGR, hence the channel reversal. imwrite reports failure
    # through its return value, so check it instead of silently losing files.
    if not cv2.imwrite(img_path, image[:, :, ::-1]):
        raise OSError("Could not write image: {}".format(img_path))

In [None]:
# Log our model in wandb and finish the run.
# This is best-effort: any failure is printed rather than raised, so the
# cells that follow still run.
try:
    # Upload the checkpoint saved at CFG.MODEL_PATH as a "model" artifact.
    artifact = wandb.log_artifact(CFG.MODEL_PATH, name='w_and_d-colorize', type='model')
    wandb_logger.finalize("success")
    wandb.finish()
except Exception as e:
    print(e)

# Inspect the samples 🕵️

In [None]:
def plot_images(batch, folder="train_images", row=4, col=4):
    """
        Copied and adapted from https://www.kaggle.com/awsaf49/happywhale-data-distribution

        Show original / recolored image pairs side by side in a row x col grid.

        batch: DataFrame with an "inference_image" column of file names
        folder: sub-folder holding the originals (under BASE_PATH) and the
            recolored images (under the working directory)
        row, col: grid shape; at most row * col pairs are shown, fewer when
            batch has fewer rows

        Raises FileNotFoundError when either image of a pair cannot be read.
    """
    # Never index past the end of a short batch.
    n_shown = min(row * col, len(batch))
    plt.figure(figsize=(col*3, row*3))
    for i in range(n_shown):
        name = batch["inference_image"].iloc[i]
        grey_path = os.path.join(BASE_PATH, folder, name)
        color_path = os.path.join(".", folder, name)
        img_grey = cv2.imread(grey_path)
        img_color = cv2.imread(color_path)
        # cv2.imread returns None instead of raising on an unreadable file.
        for path, img in ((grey_path, img_grey), (color_path, img_color)):
            if img is None:
                raise FileNotFoundError("Could not read image: {}".format(path))
        plt.subplot(row, col, i+1)
        # Original on the left, recolored on the right; BGR -> RGB for matplotlib.
        img = np.concatenate([img_grey, img_color], axis=1)[:, :, ::-1]
        plt.imshow(img)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Show original / recolored pairs for the greyscale training images.
print("Train images")
plot_images(train_grey_data)

In [None]:
# Show original / recolored pairs for the greyscale test images.
print("Test images")
plot_images(test_grey_data, folder="test_images")

We see that it's quite good but not perfect. There are sometimes inconsistent colors in the water, some shades of blue on the individuals, and some pictures that stay grey-ish. We could probably improve it by cleaning our data, because there are still some images that are not proper greyscale but almost...<br>

## Some ideas for improvement
* Adding a GAN-loss to have even more realistic results
* Use a StyleGAN architecture
* Condition the model on an embedding of the individual
* Adding down-sample augmentations

In [None]:
# Delete the local W&B folders, then zip both folders of recolored images.
!rm -rf ./"W&D - recolorization"
!rm -rf wandb
!zip -r test_images.zip test_images
!zip -r train_images.zip train_images

# Conclusion 🤷
We now have a way to colorize and denoise greyscale images. We can still improve the method by removing the "almost greyscale" samples...

👍 If you found this notebook helpful or learned something, please consider giving an upvote, and if you disagree with the content, I'll be pleased to discuss it with you in the comments.

😊 Happy Kaggling everyone !

![](https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Thats_all_folks.svg/2560px-Thats_all_folks.svg.png)