# Grayscale

> Grayscale Image Colorization with GANs

## Setup

Set up the environment by running the cells below: they download the dataset, split it into `train/` and `test/` folders, and define the model.

In [6]:
# Download and format the dataset.
# Produces grayscale-dataset-data/train/ and grayscale-dataset-data/test/,
# the directories Config.TRAIN_PATH / Config.TEST_PATH point to.

# Remove any previous copy so the cell can be re-run from scratch.
# NOTE(review): because of this, the `-d` check below is always true and the
# archive is downloaded again on every run.
%rm -rf grayscale-dataset-data

! if [ ! -d grayscale-dataset-data ] ; \
  then wget https://github.com/anthonytedja/grayscale-dataset/archive/refs/tags/data.zip; \
    unzip data.zip; \
    rm data.zip; \
fi

# Hold out the last 1000 files (in `ls` order) as the test split.
!mkdir grayscale-dataset-data/test
%cd grayscale-dataset-data/dataset
!mv $(ls | tail -n 1000) ../test
%cd ..

# Everything left in dataset/ becomes the training split.
!mv dataset train
%cd ..

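(Optional) Run the cell below to train and test on a smaller subset of the dataset. It moves the last 5,500 training images and the last 990 test images into `train_2/` and `test_2/`, which the rest of the notebook never reads, so training and testing use only the images that remain. Run it at most once — every run moves more files.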

In [None]:
# Optional: shrink the dataset. The last 5500 training images and the last
# 990 test images (in `ls` order) are moved into train_2/ and test_2/, which
# the rest of the notebook does not read, so train/ and test/ keep the rest.
# Not idempotent: every re-run moves another batch of files out.
!ls grayscale-dataset-data/ | wc -l

# -p: do not fail if the folders already exist from a previous run.
!mkdir -p grayscale-dataset-data/train_2
!mkdir -p grayscale-dataset-data/test_2

%cd grayscale-dataset-data/train/
!mv $(ls | tail -n 5500) ../train_2
%cd ../test
!mv $(ls | tail -n 990) ../test_2
%cd ../..

# Resulting split sizes.
!ls grayscale-dataset-data/test | wc -l
!ls grayscale-dataset-data/test_2 | wc -l
!ls grayscale-dataset-data/train | wc -l
!ls grayscale-dataset-data/train_2 | wc -l

In [None]:
import glob
import os

from functools import partial
import torch.nn as nn

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from tqdm import tqdm
from torch.utils.data import DataLoader

In [None]:
class Config:
    """Hyper-parameters and filesystem locations shared by the whole notebook.

    Dataset and output paths keep a trailing slash because other code appends
    file names to them with plain string concatenation.
    """

    def __init__(self):
        # Compute device: CUDA when available, otherwise CPU.
        self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ---- Directories ----
        self.ROOT_DIR = os.path.abspath('./')
        self.DATASET = "grayscale-dataset-data/"
        dataset_root = (self.ROOT_DIR, self.DATASET)
        self.TRAIN_PATH = os.path.join(*dataset_root, "train/")
        self.TEST_PATH = os.path.join(*dataset_root, "test/")

        self.MODEL_DIR = 'models'
        self.MODEL_C = f"{self.MODEL_DIR}/colorization.pt"
        self.MODEL_D = f"{self.MODEL_DIR}/discriminator.pt"
        self.OUTPUT_PATH = os.path.join(self.ROOT_DIR, f"results/{self.DATASET}")

        # ---- Data ----
        self.IMAGE_SIZE = 224
        self.BATCH_SIZE = 10

        # ---- Training ----
        self.EPOCHS = 5
        self.GRADIENT_PENALTY_WEIGHT = 10
        self.CHECK_PER = 100  # log every N batches; -1 disables logging
        self.LR = 2e-5

config = Config()

In [None]:
def deprocess(imgs):
    """Map a [0, 1]-scaled array to uint8 pixels.

    Values are multiplied by 255, clamped to [0, 255] and truncated to
    ``np.uint8``. The input array is not modified.
    """
    scaled = np.clip(imgs * 255, 0, 255)
    return scaled.astype(np.uint8)


def reconstruct(batchX, predictedY):
    """Merge an L channel with ab channels and convert OpenCV Lab -> BGR.

    Args:
        batchX: L channel, channel-first (presumably (1, H, W) uint8 as
            produced by ``deprocess`` — confirm at call sites).
        predictedY: ab channels, channel-first (presumably (2, H, W) uint8).

    Returns:
        The image in height x width x channel layout, converted with
        ``cv2.COLOR_Lab2BGR``.
    """
    lab_chw = np.concatenate((batchX, predictedY))
    lab_hwc = lab_chw.transpose(1, 2, 0)
    return cv2.cvtColor(lab_hwc, cv2.COLOR_Lab2BGR)


def wasserstein_loss(inputs, real_or_fake):
    """Signed mean of critic scores.

    Returns ``-mean(inputs)`` when ``real_or_fake`` is truthy and
    ``+mean(inputs)`` otherwise.
    """
    mean_score = torch.mean(inputs)
    if not real_or_fake:
        return mean_score
    return -mean_score


def random_weighted_average(inputs):
    """Randomly interpolate between two batches, with one weight per sample.

    Args:
        inputs: Pair ``(a, b)`` of tensors with identical shape whose first
            dimension is the batch.

    Returns:
        ``w * a + (1 - w) * b`` where ``w ~ U[0, 1)`` is drawn per sample and
        broadcast over all remaining dimensions.
    """
    first, second = inputs[0], inputs[1]
    # Size the weights from the actual batch (not config.BATCH_SIZE) so a short
    # final batch still broadcasts, and create them directly on the inputs'
    # device so no global device setting is needed.
    weight_shape = (first.size(0),) + (1,) * (first.dim() - 1)
    weight = torch.rand(weight_shape, device=first.device)
    return (weight * first) + ((1 - weight) * second)


def gradient_penalty_loss(y_pred, averaged_samples, gradient_penalty_weight):
    """WGAN-GP gradient penalty: weight * mean((||dD/dx||_2 - 1) ** 2).

    Args:
        y_pred: Critic output computed from ``averaged_samples``.
        averaged_samples: Tensor (requiring grad) the critic was evaluated on;
            its first dimension is the batch.
        gradient_penalty_weight: Multiplier applied to the penalty.

    Returns:
        Scalar tensor. It stays connected to the autograd graph
        (``create_graph=True``), so backpropagating it reaches the critic's
        parameters.
    """
    gradients = torch.autograd.grad(
        y_pred, averaged_samples,
        # ones_like matches y_pred's device and dtype without relying on a
        # global device setting.
        grad_outputs=torch.ones_like(y_pred),
        create_graph=True, retain_graph=True)[0]
    # reshape (not view) also copes with non-contiguous gradient tensors.
    gradients = gradients.reshape(gradients.size(0), -1)
    slopes = (gradients + 1e-16).norm(2, dim=1)
    return torch.mean((slopes - 1) ** 2) * gradient_penalty_weight


def partial_gp_loss(y_pred, averaged_samples, gradient_penalty_weight):
    """WGAN-GP gradient penalty, differentiable with respect to the critic.

    Args:
        y_pred: Critic output computed from ``averaged_samples``.
        averaged_samples: Tensor (requiring grad) the critic was evaluated on;
            its first dimension is the batch.
        gradient_penalty_weight: Multiplier applied to the penalty.

    Returns:
        Scalar tensor ``weight * mean((||dD/dx||_2 - 1) ** 2)``.
    """
    gradients = torch.autograd.grad(
        y_pred, averaged_samples,
        grad_outputs=torch.ones_like(y_pred),
        # Without create_graph the gradients are plain constants, so the
        # penalty would add nothing to the critic's parameter gradients.
        create_graph=True, retain_graph=True,
        )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    slopes = (gradients + 1e-16).norm(2, dim=1)
    return ((slopes - 1) ** 2).mean() * gradient_penalty_weight


def initialize_weights(model):
    """Redraw the weights of every Conv2d / ConvTranspose2d / BatchNorm2d from N(0, 0.02).

    Biases and all other layer types are left untouched. When passed to
    ``Module.apply`` this runs once per module, so nested layers are redrawn
    once per ancestor (harmless, but redundant).

    NOTE(review): BatchNorm scales are drawn around 0, not 1 — confirm intended.
    """
    targets = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for layer in filter(lambda module: isinstance(module, targets), model.modules()):
        nn.init.normal_(layer.weight.data, 0.0, 0.02)

In [None]:
class ColorizeDataLoader(Dataset):
    """Image-folder dataset yielding (L, ab, original_shape) in OpenCV Lab space.

    Every file matched in ``color_path`` is read with ``cv2.imread``, resized
    to ``img_size`` x ``img_size`` and converted BGR -> Lab. ``__getitem__``
    returns the L channel and the ab channels as tensors produced by
    ``transforms.ToTensor`` (channel-first), plus the shape of the image as
    read from disk.

    Args:
        color_path: Directory containing the images (trailing slash optional).
        img_size: Side length images are resized to; only 224 is supported.

    Raises:
        Exception: if ``color_path`` is not a directory, ``img_size`` is not
            224, or the directory contains no files.
    """

    def __init__(self, color_path, img_size=224):

        if not os.path.isdir(color_path):
            raise Exception("The path is not a directory")
        if img_size != 224:
            raise Exception("The image size must be 224")

        self.color_path = color_path
        # Use the validated argument instead of re-reading the global config.
        self.img_size = img_size
        self.color_channels = 3
        self.gray_channels = 1
        self.filelist = os.listdir(self.color_path)
        self.size = len(self.filelist)
        # os.path.join works with or without a trailing slash on color_path;
        # sorting makes the index -> file mapping reproducible across runs.
        self.data_color = sorted(glob.glob(os.path.join(self.color_path, '*')))

        if not self.data_color:
            print(self.color_path)
            raise Exception("No images found in directory:", self.color_path)

    def __len__(self):
        return len(self.data_color)

    def __getitem__(self, idx):
        grey_img, color_img, original_images_shape = self.read_img(idx)
        return self.transform(grey_img), self.transform(color_img), original_images_shape

    def transform(self, img):
        """Convert an HWC array to a CHW tensor via ``transforms.ToTensor``."""
        return transforms.ToTensor()(img)

    def read_img(self, idx):
        """Load image ``idx`` and split it into L and ab channel arrays.

        Returns:
            ``(L, ab, original_shape)`` where ``L`` is (img_size, img_size, 1),
            ``ab`` is (img_size, img_size, 2) and ``original_shape`` is the
            shape of the image before resizing.

        Raises:
            ValueError: if OpenCV cannot decode the file.
        """
        img_color_path = self.data_color[idx]
        img_color = cv2.imread(img_color_path)
        if img_color is None:
            # cv2.imread returns None (it does not raise) for unreadable or
            # non-image files; fail with the offending path.
            raise ValueError(f"Could not read image: {img_color_path}")

        lab_img = cv2.cvtColor(
            cv2.resize(img_color, (self.img_size, self.img_size)),
            cv2.COLOR_BGR2Lab)
        original_shape = img_color.shape

        return (
            np.reshape(lab_img[:, :, 0], (self.img_size, self.img_size, 1)),
            np.reshape(lab_img[:, :, 1:], (self.img_size, self.img_size, 2)),
            original_shape,
            )

def testing_colorize_dataloader():
    """Print the tensor shapes of the first training sample."""
    dataset = ColorizeDataLoader(config.TRAIN_PATH)
    grey_img, color_img, original_images_shape = dataset[0]

    report = (
        ('grey_img: ', grey_img.shape),
        ('color_img: ', color_img.shape),
        ('original_images_shape: ', original_images_shape),
    )
    for label, value in report:
        print(label, value)

def checking_data_loader():
    """Iterate once over the full training set to check every image loads."""
    loader = DataLoader(
        ColorizeDataLoader(config.TRAIN_PATH),
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    )
    # Unpacking each batch also checks that it is an (L, ab, shape) triple.
    for _grey, _color, _shape in tqdm(loader):
        pass
    print('Validation passed')

In [None]:
import warnings
import torchvision.models as models

from torchvision.utils import save_image

In [None]:
class ConvBlockDiscriminator(nn.Module):
    """Critic stage: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

    Constructor arguments are forwarded to ``nn.Conv2d``.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        # Submodule name `conv` keeps state_dict keys (conv.0.*, conv.1.*) stable.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.conv(x)
        return features

class Discriminator(nn.Module):
    """Convolutional critic that scores Lab images.

    Three stride-2 stages followed by two stride-1 stages reduce the input to
    a single-channel score map (28x28 for a 224x224 input), squashed with a
    sigmoid.

    Args:
        input_size: Stored on the module; not used in the computation.
        in_channels: Channels of the input image (L + ab = 3 by default).
    """

    def __init__(self, input_size, in_channels=3):
        super().__init__()
        self.input_size = input_size
        self.in_channels = in_channels
        # The first stage has no batch norm.
        self.conv_1 = nn.Sequential(
            nn.Conv2d(self.in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2)
        )
        self.conv_2 = ConvBlockDiscriminator(64, 128, 4, 2, 1)
        self.conv_3 = ConvBlockDiscriminator(128, 256, 4, 2, 1)
        self.conv_4 = ConvBlockDiscriminator(256, 512, 3, 1, 1)
        self.conv_5 = ConvBlockDiscriminator(512, 1, 3, 1, 1)

    def forward(self, x):
        stages = (self.conv_1, self.conv_2, self.conv_3, self.conv_4, self.conv_5)
        for stage in stages:
            x = stage(x)
        return torch.sigmoid(x)

# Silences every Python warning process-wide from this point on.
# NOTE(review): this also hides deprecation warnings; consider narrowing the
# filter to specific categories/modules.
warnings.filterwarnings("ignore")


class ConvBlock(nn.Module):
    """Generator stage: Conv2d (reflect padding) -> ReLU -> BatchNorm2d.

    Constructor arguments are forwarded to ``nn.Conv2d``.
    """

    def __init__(
            self, in_channels, out_channels,
            kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size,
                stride, padding, padding_mode='reflect'),
            # nn.ReLU's only argument is `inplace` (it has no slope); a bare
            # 0.2 here would also mean inplace=True, so state it explicitly.
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        return self.conv(x)

class SimpleConvBlock(nn.Module):
    """Decoder stage: Conv2d -> ReLU (no normalization).

    Constructor arguments are forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size,
                stride, padding),
            # nn.ReLU's only argument is `inplace`; a bare 0.2 would also mean
            # inplace=True (not a leaky slope), so state it explicitly.
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class Colorization(nn.Module):
    """Generator: predicts ab colour channels from a 3x-replicated L channel.

    A pretrained VGG16 feature extractor (all but the last 8 layers of
    ``vgg16.features``) feeds three paths:

    * a global path (strided ConvBlocks -> fully connected) whose 256-d vector
      is tiled over the mid-level feature map and concatenated with it,
    * a class path producing a 1000-way softmax distribution, and
    * a mid-level path (two ConvBlocks).

    The fused features are decoded by convolutions and three 2x upsamplings.

    Args:
        input_size: Must be 224; ``fully_connected_*`` assume a 7x7 map.

    Raises:
        ValueError: if ``input_size`` is not 224.

    ``forward`` returns ``(ab, class_probs)``: ``ab`` has 2 channels squashed
    by a sigmoid, ``class_probs`` has shape (batch, 1000) and sums to 1 per row.
    """

    def __init__(self, input_size=224):
        super().__init__()
        if input_size != 224:
            raise ValueError("Input size must be 224")

        vgg = models.vgg16(pretrained=True)
        self.vgg = nn.Sequential(*list(vgg.features.children())[:-8])

        self.global_features_1 = ConvBlock(512, 512, 3, 2, 1)
        self.global_features_2 = ConvBlock(512, 512, 3, 1, 1)
        self.global_features_3 = ConvBlock(512, 512, 3, 2, 1)
        self.global_features_4 = ConvBlock(512, 512, 3, 1, 1)

        self.flatten = nn.Flatten()

        self.fully_connected_1 = nn.Sequential(
            nn.Linear(512 * 7 * 7, 1024),
            nn.Linear(1024, 512),
            nn.Linear(512, 256),
        )

        self.fully_connected_2 = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.Linear(4096, 4096),
            nn.Linear(4096, 1000),
            # Explicit class dimension (the implicit-dim form is deprecated).
            nn.Softmax(dim=1)
        )

        self.mid_level_features_1 = ConvBlock(512, 512, 3, 1, 1)
        self.mid_level_features_2 = ConvBlock(512, 256, 3, 1, 1)

        self.output_1 = SimpleConvBlock(512, 256, 1, 1, 0)
        self.output_2 = SimpleConvBlock(256, 128, 3, 1, 1)
        self.output_3 = SimpleConvBlock(128, 64, 3, 1, 1)
        self.output_4 = SimpleConvBlock(64, 64, 3, 1, 1)
        self.output_5 = SimpleConvBlock(64, 32, 3, 1, 1)
        self.output_6 = nn.Sequential(
            nn.Conv2d(32, 2, 3, 1, 1),
            nn.Sigmoid()
        )

        self.up_sample = nn.Upsample(scale_factor=2)

    def forward(self, x):
        x_model = self.vgg(x)
        global_features = self.global_features_1(x_model)
        global_features = self.global_features_2(global_features)
        global_features = self.global_features_3(global_features)
        global_features = self.global_features_4(global_features)

        global_features2 = self.flatten(global_features)
        global_features2 = self.fully_connected_1(global_features2)

        global_features_class = self.flatten(global_features)
        global_features_class = self.fully_connected_2(global_features_class)

        mid_level_features = self.mid_level_features_1(x_model)
        mid_level_features = self.mid_level_features_2(mid_level_features)

        # Tile the (batch, 256) global vector over every spatial position of
        # the mid-level map, using its actual size instead of a hard-coded 28.
        height, width = mid_level_features.shape[-2:]
        global_features2 = global_features2[:, :, None, None].expand(-1, -1, height, width)

        fusion = torch.cat((mid_level_features, global_features2), dim=1)

        output = self.output_1(fusion)
        output = self.output_2(output)
        output = self.up_sample(output)

        output = self.output_3(output)
        output = self.output_4(output)
        output = self.up_sample(output)

        output = self.output_5(output)
        output = self.output_6(output)
        output = self.up_sample(output)

        return output, global_features_class

def test_colorization():
    """Smoke-test the generator on random input and dump images to the CWD.

    Writes aaa.png (the random input), 111.png / 222.png (the two predicted
    channels of the first sample) and 333.png (input concatenated with the
    prediction along the channel axis).
    """
    print('Test Colorization')
    noise = torch.randn((16, 1, 224, 224))
    noise_rgb = torch.cat([noise] * 3, dim=1)
    save_image(noise, 'aaa.png')
    noise_rgb = noise_rgb.to(config.DEVICE)

    net = Colorization(input_size=224).to(config.DEVICE)
    net.apply(initialize_weights)
    prediction, _ = net(noise_rgb)
    print(type(prediction))

    first_sample = prediction[0]
    print(first_sample.shape)
    for channel, filename in enumerate(('111.png', '222.png')):
        save_image(first_sample[channel], filename)

    noise = noise.to(config.DEVICE)
    save_image(torch.cat([noise, prediction], dim=1), '333.png')


def test_colorization_dataloader():
    """Run an untrained generator on one test batch and save the images.

    For the first batch only, writes ``<idx>.png`` (predicted ab channels) and
    ``<idx>_real.png`` (ground-truth ab channels) under ``config.OUTPUT_PATH``.
    """
    print('Test Colorization Dataloader')
    # Build the generator once: its constructor loads pretrained VGG16
    # weights, so re-creating it per batch is wasted work.
    colorization = Colorization(input_size=224).to(config.DEVICE)
    test_dataloader = DataLoader(
        ColorizeDataLoader(config.TEST_PATH), batch_size=config.BATCH_SIZE,
        shuffle=True, num_workers=4)
    os.makedirs(config.OUTPUT_PATH, exist_ok=True)

    for idx, (gray, real, _) in enumerate(tqdm(test_dataloader)):
        # Replicate L to 3 channels on the tensor side (no numpy round-trip).
        l_3 = torch.cat([gray, gray, gray], dim=1).to(config.DEVICE)
        with torch.no_grad():
            colored, _ = colorization(l_3)
        images_path = config.OUTPUT_PATH + str(idx) + '.png'
        print('real', real.shape)
        real_path = config.OUTPUT_PATH + str(idx) + '_real.png'
        save_image(colored, images_path)
        save_image(real, real_path)
        break
    print('Finished Data Loading')

In [None]:
import json
import pandas as pd

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.optim import Adam

In [None]:
def model(train_data, test_data, epochs, version=0.0):
    """Train the colorization GAN and save the generator and critic weights.

    Every batch performs one generator update followed by one critic update.

    Generator loss::

        MSE(pred_ab, true_ab)
        + 0.003 * KL(p_vgg || p_colorizer)   # class-distribution matching
        - 0.1 * mean(D(L, pred_ab))

    Critic loss::

        mean(D(L, true_ab)) - mean(D(L, pred_ab)) + gradient penalty

    The penalty is taken on per-sample random interpolates between the true
    and predicted ab channels.

    NOTE(review): with these signs both the generator and the critic are
    rewarded for raising mean(D(L, pred_ab)); confirm the intended
    Wasserstein sign convention.

    Args:
        train_data: Directory of training images (read by ColorizeDataLoader).
        test_data: Unused here; kept for interface compatibility. Evaluation
            is done separately by ``test()``.
        epochs: Number of passes over the training set.
        version: Unused; kept for interface compatibility.

    Side effects:
        Creates ``config.MODEL_DIR`` if needed and writes the critic and
        generator state dicts to ``config.MODEL_D`` / ``config.MODEL_C``.
    """
    # Load data
    train_dataloader = DataLoader(
        ColorizeDataLoader(train_data), batch_size=config.BATCH_SIZE,
        shuffle=True, num_workers=2, drop_last=True)

    # Load models
    discriminator = Discriminator(input_size=224).to(config.DEVICE)
    discriminator.apply(initialize_weights)
    colorization_model = Colorization(input_size=224).to(config.DEVICE)
    # Frozen VGG16 classifier used as the teacher for the KL loss. eval()
    # disables its dropout so the teacher distribution is deterministic.
    vgg_model_f = models.vgg16(pretrained=True).to(config.DEVICE)
    vgg_model_f.requires_grad_(False)
    vgg_model_f.eval()

    optimizer_g = Adam(
        colorization_model.parameters(), lr=config.LR, betas=(0.5, 0.999)
    )
    optimizer_d = Adam(
        discriminator.parameters(), lr=config.LR, betas=(0.5, 0.999)
    )

    # KLDivLoss takes log-probabilities as input and probabilities as target;
    # 'batchmean' sums over classes and averages over the batch (true KL).
    kl_divergence = nn.KLDivLoss(reduction='batchmean')
    mse = nn.MSELoss()

    for epoch in range(epochs):
        print(f'EPOCH {epoch} / {epochs}')
        print('-' * 30)

        for idx, (trainL, trainAB, _) in enumerate(tqdm(train_dataloader)):
            trainL = trainL.to(config.DEVICE)
            trainAB = trainAB.to(config.DEVICE)
            l_3 = torch.cat([trainL, trainL, trainL], dim=1)

            with torch.no_grad():
                pred_class_vgg = F.softmax(vgg_model_f(l_3), dim=1)

            # ---- Generator step ----
            optimizer_g.zero_grad()
            pred_AB, pred_class_c = colorization_model(l_3)
            # Freeze the critic so the generator loss leaves no gradients in
            # it, while gradients still flow *through* it back to pred_AB.
            discriminator.requires_grad_(False)
            dis_C = discriminator(torch.cat([trainL, pred_AB], dim=1))
            # pred_class_c is already a softmax output (Colorization ends its
            # class head with nn.Softmax), so take its log directly.
            KLD_loss = kl_divergence(
                torch.log(pred_class_c.clamp_min(1e-12)), pred_class_vgg
                ) * 0.003
            MSE_loss = mse(pred_AB, trainAB)
            W_loss = wasserstein_loss(dis_C, True) * 0.1
            g_loss = KLD_loss + MSE_loss + W_loss
            g_loss.backward()
            optimizer_g.step()
            discriminator.requires_grad_(True)

            # ---- Critic step ----
            optimizer_d.zero_grad()
            # Detach so the critic's loss does not update the generator.
            fake_AB = pred_AB.detach()
            dis_pred = discriminator(torch.cat([trainL, fake_AB], dim=1)).mean()
            dis_true = discriminator(torch.cat([trainL, trainAB], dim=1)).mean()

            # Interpolation weights are uniform in [0, 1) per sample (WGAN-GP).
            weights = torch.rand((trainAB.size(0), 1, 1, 1), device=config.DEVICE)
            averaged_samples = ((weights * trainAB) + ((1 - weights) * fake_AB)).requires_grad_(True)
            dis_avg = discriminator(torch.cat([trainL, averaged_samples], dim=1))

            W_loss_true = wasserstein_loss(dis_true, False)
            W_loss_pred = wasserstein_loss(dis_pred, True)
            # gradient_penalty_loss builds the graph (create_graph=True), so
            # the penalty actually contributes gradients to the critic.
            gp_loss_avg = gradient_penalty_loss(
                dis_avg, averaged_samples, config.GRADIENT_PENALTY_WEIGHT)
            d_loss = W_loss_true + W_loss_pred + gp_loss_avg
            d_loss.backward()
            optimizer_d.step()

            if config.CHECK_PER != -1 and idx % config.CHECK_PER == 0:
                print('\n')
                print(f"Epoch {epoch} - Batch {idx} - Loss G: {g_loss.item()} - Loss D: {d_loss.item()}")

    os.makedirs(config.MODEL_DIR, exist_ok=True)
    torch.save(discriminator.state_dict(), config.MODEL_D)
    torch.save(colorization_model.state_dict(), config.MODEL_C)

def train():
    """Train on ``config.TRAIN_PATH``, save the weights, then run ``test()``."""
    separator = '-' * 30

    print('Training Start')
    print(separator)
    model(config.TRAIN_PATH, config.TEST_PATH, config.EPOCHS)
    print(separator)
    print('Training Finished')
    print(separator)

    print('Testing Start')
    print(separator)
    test()
    print('Testing Finished')

    print(separator)
    print('All Finished')


def sample_images(test_data, colorizationModel):
    """Colorize every batch in ``test_data`` and write one PNG per image.

    Each image is rebuilt from its input L channel and the predicted ab
    channels (``reconstruct``: OpenCV Lab -> BGR) and saved as
    ``config.OUTPUT_PATH + "<batch>_<image>.png"``.

    Args:
        test_data: Iterable of (L, ab, original_shape) batches, e.g. a
            DataLoader over ColorizeDataLoader.
        colorizationModel: Generator returning ``(pred_ab, class_probs)``.
    """
    print('Sampling Images')
    os.makedirs(config.OUTPUT_PATH, exist_ok=True)
    for idx, (gray, _, _) in enumerate(tqdm(test_data)):
        l_3 = torch.cat([gray, gray, gray], dim=1).to(config.DEVICE)

        with torch.no_grad():
            colored, _ = colorizationModel(l_3)

        # Convert the whole batch once instead of once per image.
        gray_u8 = deprocess(gray.detach().cpu().numpy())
        colored_u8 = deprocess(colored.detach().cpu().numpy())
        # Use the real batch size (a final partial batch has fewer images) and
        # a per-image file name; one name per batch overwrote all but the last.
        for i in range(gray_u8.shape[0]):
            result = reconstruct(gray_u8[i], colored_u8[i])
            cv2.imwrite(f"{config.OUTPUT_PATH}{idx}_{i}.png", result)

    print('Sampling Finished')

def test():
    """Colorize the test split with the saved generator weights.

    Loads ``config.MODEL_C``, prints the shapes of one batch as a sanity
    check, then calls ``sample_images``, which writes PNGs under
    ``config.OUTPUT_PATH``.
    """
    path = config.MODEL_C
    test_dataloader = DataLoader(
        ColorizeDataLoader(config.TEST_PATH), batch_size=config.BATCH_SIZE,
        shuffle=True, num_workers=2, drop_last=True)

    # Sanity check: show the shapes of the first batch only.
    for idx, (grey_img, color_img, original_images_shape) in enumerate(test_dataloader):
        print(f"{idx} / {len(test_dataloader)}")
        print(f"gray shape: {grey_img.shape}")
        print(f"ori_ab shape: {color_img.shape}")
        print(f"{original_images_shape}")
        break

    colorizationModel = Colorization(input_size=224).to(config.DEVICE)
    # map_location remaps every stored tensor onto config.DEVICE, so weights
    # saved on a GPU also load on a CPU-only machine.
    colorizationModel.load_state_dict(torch.load(path, map_location=config.DEVICE))
    colorizationModel.eval()
    sample_images(test_dataloader, colorizationModel)

class Trainer:
    """Thin wrapper exposing the module-level ``train()`` as a static method."""

    def __init__(self):
        # Recorded for callers; train() itself reads config.DEVICE directly.
        self.device = config.DEVICE

    @staticmethod
    def train():
        return train()

class Tester:
    """Wrapper that runs the module-level ``test()`` routine."""

    def __init__(self):
        # Recorded for callers; test() itself reads config.DEVICE directly.
        self.device = config.DEVICE

    def test(self):
        return test()

class Train:
    """Notebook entry point: runs ``Trainer.train()`` then prints a completion line."""

    def __init__(self):
        self.trainer = Trainer()

    def train(self):
        trainer = self.trainer
        trainer.train()
        print('Training Finished')

# Make sure the directory that sampled images are written to exists;
# exist_ok avoids the separate exists-check and its race.
os.makedirs(config.OUTPUT_PATH, exist_ok=True)

## Training

Run the following cell to train the model. The weights are saved to the `models` directory, and the test step below also runs automatically once training finishes.

In [None]:
# Trains for config.EPOCHS epochs and saves the weights, then also runs test(),
# so the Testing cell below repeats the sampling step.
Train().train()

## Testing

Run the following cell to colorize the test images with the saved weights. The results are written as PNG files to `results/grayscale-dataset-data/` and are not displayed inline.

In [None]:
# Loads the weights from config.MODEL_C and writes colorized test images under config.OUTPUT_PATH.
Tester().test()