## Import

In [1]:
import tensorflow as tf
from tensorflow import keras
import gc
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU, Activation
from tensorflow.keras.layers import Concatenate, concatenate, Dropout, BatchNormalization, Flatten
from tensorflow.keras.applications import ResNet101
from tensorflow.keras.optimizers import Adam
from keras.initializers import RandomNormal
import numpy as np
import random
import cv2
import matplotlib.pyplot as plt
from skimage import color
import matplotlib.image as mpimg

## Data

In [2]:
class DataGenerator:
    """Samples filtered, normalized (L, AB) image pairs from memory-mapped arrays.

    The first file holds the L (lightness) channel of every image; the remaining
    files are AB shards. An image at position ``k`` of shard ``j`` is paired with
    the L image at position ``k + (total size of shards before j)``, i.e. the L
    array is indexed as if the AB shards were concatenated in order.
    """

    # Default dataset layout: L array first, then the AB shards in order.
    DEFAULT_FILE_PATHS = (
        '/kaggle/input/image-colorization/l/gray_scale.npy',
        '/kaggle/input/image-colorization/ab/ab/ab1.npy',
        '/kaggle/input/image-colorization/ab/ab/ab2.npy',
        '/kaggle/input/image-colorization/ab/ab/ab3.npy',
    )

    # Offsets/spans used to map the A and B channels onto [0, 1].
    A_OFFSET, A_SPAN = 42, 184.0   # A: [42, 226] -> [0, 1]
    B_OFFSET, B_SPAN = 20, 203.0   # B: [20, 223] -> [0, 1]

    # Extra fraction of candidates drawn per round to compensate for rejected images.
    OVERSAMPLE_FRACTION = 0.2

    def __init__(self, file_paths=None):
        """Memory-map the dataset files.

        Args:
            file_paths: optional sequence ``[l_path, ab_path_1, ...]``. When omitted,
                ``DEFAULT_FILE_PATHS`` is used.
        """
        self.file_paths = list(self.DEFAULT_FILE_PATHS if file_paths is None else file_paths)
        # Memory-map the files to avoid loading them into RAM as a whole.
        self.channel_memory_map = [np.load(path, mmap_mode='r') for path in self.file_paths]
        # Number of images in each AB shard (every array after the L array).
        self.AB_range_map = [shard.shape[0] for shard in self.channel_memory_map[1:]]

    def _is_valid_image(self, l_channel, ab_channel, ab_variance_threshold=0.005, l_contrast_threshold=0.1):
        """Check if the image has sufficient color and contrast to be valid for colorization.

        Args:
            l_channel: L channel of one image.
            ab_channel: AB channels of the same image.
            ab_variance_threshold: minimum variance of the raw AB values.
            l_contrast_threshold: minimum ``(max - min) / 255`` of the L channel.

        Returns:
            True when both checks pass, False otherwise.
        """
        # Variance is computed over the raw (un-normalized) AB values.
        if np.var(ab_channel) < ab_variance_threshold:
            return False  # No sufficient color information

        # Convert to Python floats before subtracting so the difference cannot be
        # affected by a narrow integer dtype.
        l_contrast = (float(np.max(l_channel)) - float(np.min(l_channel))) / 255.0
        if l_contrast < l_contrast_threshold:
            return False  # Too little contrast

        return True  # Image is valid

    def _normalize_channels(self, l_channel, ab_channel):
        """Normalize the AB channels to a [0, 1] range; the L channel is returned unchanged.

        Args:
            l_channel: L channel of one image (passed through as-is).
            ab_channel: array of shape (H, W, 2) holding the A and B channels.

        Returns:
            Tuple ``(l_channel, ab_normalized)`` where ``ab_normalized`` is float64
            with the same shape as ``ab_channel``.
        """
        # Cast first: subtracting the offset from an unsigned integer array would
        # wrap around for values below the offset instead of going negative.
        ab_float = np.asarray(ab_channel, dtype=np.float64)

        a_channel_normalized = (ab_float[:, :, 0] - self.A_OFFSET) / self.A_SPAN
        b_channel_normalized = (ab_float[:, :, 1] - self.B_OFFSET) / self.B_SPAN

        ab_channel_normalized = np.stack([a_channel_normalized, b_channel_normalized], axis=-1)
        return l_channel, ab_channel_normalized

    def create_batch(self, batch_size):
        """Yield exactly one batch of ``batch_size`` valid, normalized image pairs.

        This is a generator function that yields a single
        ``(l_batch, ab_batch)`` tuple of stacked arrays. Candidates are drawn at
        random (with replacement) from one randomly chosen AB shard per round
        until enough images pass ``_is_valid_image``.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 (raised on the first
                ``next()`` call, as with any generator function).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        # Oversample so that a single round usually survives the quality filter.
        n_candidates = batch_size + int(batch_size * self.OVERSAMPLE_FRACTION)
        valid_images = []

        while len(valid_images) < batch_size:
            # Select a random AB shard.
            idx = np.random.randint(0, len(self.AB_range_map))

            # Indices inside the shard, and the matching indices in the L array
            # (shifted by the sizes of all preceding shards).
            AB_indices = np.random.randint(0, self.AB_range_map[idx], n_candidates)
            L_indices = AB_indices + sum(self.AB_range_map[:idx])

            l_batch = self.channel_memory_map[0][L_indices]
            ab_batch = self.channel_memory_map[idx + 1][AB_indices]

            # Filter out low-quality images and normalize the survivors.
            for l_img, ab_img in zip(l_batch, ab_batch):
                if not self._is_valid_image(l_img, ab_img):
                    continue
                valid_images.append(self._normalize_channels(l_img, ab_img))
                if len(valid_images) == batch_size:
                    break  # Stop when we have enough valid images

        l_batch, ab_batch = zip(*valid_images)
        yield np.array(l_batch), np.array(ab_batch)

# Module-level DataGenerator instance. Constructing it memory-maps the dataset
# files; generator_wrapper below reads this object as a global.
data_generator = DataGenerator()

def generator_wrapper(batch_size):
    """Endlessly yield ``(l_batch, ab_batch)`` training pairs for the model.

    Every iteration asks the module-level ``data_generator`` for one fresh batch
    and appends a trailing channel axis to the L batch, so the L-channel serves
    as the model input and the AB-channels as the target.
    """
    while True:  # Never exhausts: one new batch per iteration
        batch_source = data_generator.create_batch(batch_size)
        luminance, chroma = next(batch_source)
        # Add the channel dimension expected by the model input.
        yield luminance[..., np.newaxis], chroma

## Model

In [10]:
class GAN:
    """Conditional GAN that predicts the AB colour channels of an image from its L channel.

    * Generator: a frozen ImageNet ResNet101 backbone fused with a strided
      convolutional encoder, decoded by transposed convolutions with skip
      connections; output is 2 channels in [0, 1] (sigmoid).
    * Discriminator: PatchGAN-style classifier over the (L, AB) pair.
    * ``model``: generator followed by the discriminator, used for the
      generator's adversarial update.

    Batches are pulled from the module-level ``generator_wrapper``.
    """

    def __init__(self, input_shape, target_shape):
        """Build all three models and the default data stream.

        Args:
            input_shape: shape of one L input, e.g. ``(224, 224, 1)``. The
                ResNet101 backbone is built for 224x224 inputs.
            target_shape: shape of one AB target, e.g. ``(224, 224, 2)``.
        """
        self.input_shape = input_shape
        self.target_shape = target_shape
        self.weight_init = RandomNormal(stddev=0.02)
        self.generator = self._build_generator()
        self.discriminator = self._build_discriminator()
        # Built last: it freezes the discriminator's non-BatchNorm layers.
        self.model = self._build_gan()
        self.patch_shape = self._get_patch_size()
        self.batch_size = 32
        self.data = generator_wrapper(batch_size=self.batch_size)

    def _encoder_block(self, input, filters, kernel, strides):
        """Conv2D -> BatchNormalization -> LeakyReLU(0.2)."""
        conv = Conv2D(filters, kernel, strides=strides, padding='same', kernel_initializer=self.weight_init)(input)
        conv = BatchNormalization()(conv)
        conv = LeakyReLU(alpha=0.2)(conv)
        return conv

    def _decoder_block(self, tensor, filters, strides, dropout_rate=None):
        """Conv2DTranspose(3x3) -> ReLU -> optional Dropout."""
        up = Conv2DTranspose(filters, (3, 3), strides=strides, padding='same',
                             kernel_initializer=self.weight_init)(tensor)
        up = Activation('relu')(up)
        if dropout_rate is not None:
            up = Dropout(dropout_rate)(up)
        return up

    def _build_generator(self):
        """Build the generator mapping an L image to 2 sigmoid-activated AB channels."""
        in_image = Input(shape=self.input_shape)
        # Replicate the single L channel three times to match the backbone's RGB input.
        in_backbone = Concatenate()([in_image, in_image, in_image])
        resnet = ResNet101(
            include_top=False,
            input_shape=(224, 224, 3),
            weights='imagenet'
        )
        # The pretrained backbone is used as a fixed feature extractor.
        for layer in resnet.layers:
            layer.trainable = False

        backbone = resnet(in_backbone)

        # Encoder: five stride-2 stages in total, so conv5_ can be concatenated
        # with the backbone's feature map.
        conv1 = self._encoder_block(in_image, 64, (3, 3), (2, 2))
        conv2 = self._encoder_block(conv1, 128, (3, 3), strides=(1, 1))
        conv3 = self._encoder_block(conv2, 128, (3, 3), strides=(2, 2))
        conv4 = self._encoder_block(conv3, 256, (3, 3), strides=(2, 2))
        conv4_ = self._encoder_block(conv4, 256, (3, 3), strides=(1, 1))
        conv5 = self._encoder_block(conv4_, 512, (3, 3), strides=(2, 2))
        conv5_ = self._encoder_block(conv5, 256, (3, 3), strides=(2, 2))

        # Fuse backbone and encoder features with a 1x1 convolution.
        conc = concatenate([backbone, conv5_])
        fusion = self._encoder_block(conc, 512, (1, 1), (1, 1))
        skip_fusion = concatenate([fusion, conv5_])

        # Decoder with (dropped-out) skip connections from conv5 and conv4_.
        decoder = self._decoder_block(skip_fusion, 1024, (2, 2), dropout_rate=0.25)
        skip_4 = concatenate([decoder, Dropout(0.25)(conv5)])
        decoder = self._decoder_block(skip_4, 512, (2, 2), dropout_rate=0.25)
        skip_3 = concatenate([decoder, Dropout(0.25)(conv4_)])
        decoder = self._decoder_block(skip_3, 256, (2, 2), dropout_rate=0.25)
        decoder = self._decoder_block(decoder, 128, (2, 2), dropout_rate=0.25)
        decoder = self._decoder_block(decoder, 64, (1, 1))
        decoder = self._decoder_block(decoder, 32, (2, 2))

        # Sigmoid keeps the predicted AB channels in [0, 1].
        output_layer = Conv2D(2, (1, 1), activation='sigmoid')(decoder)
        return Model(in_image, output_layer)

    def _build_discriminator(self):
        """Build and compile the PatchGAN discriminator over (L, AB) pairs."""
        init = RandomNormal(stddev=0.02)
        in_src_image = Input(shape=self.input_shape)
        in_target_image = Input(shape=self.target_shape)
        merged = Concatenate()([in_src_image, in_target_image])

        # First stage has no batch normalization.
        d = Conv2D(64, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(merged)
        d = LeakyReLU(alpha=0.2)(d)
        # Three further stride-2 stages, then one stride-1 stage.
        for filters, strides in ((128, (2, 2)), (256, (2, 2)), (512, (2, 2)), (512, (1, 1))):
            d = Conv2D(filters, (4, 4), strides=strides, padding='same', kernel_initializer=init)(d)
            d = BatchNormalization()(d)
            d = LeakyReLU(alpha=0.2)(d)
        # Patch gan: one real/fake probability per spatial patch.
        d = Conv2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
        patch_out = Activation('sigmoid')(d)

        model = Model([in_src_image, in_target_image], patch_out)
        opt = Adam(learning_rate=0.0002, beta_1=0.5)
        model.compile(loss='binary_crossentropy', optimizer=opt)
        return model

    def _build_gan(self):
        """Build and compile the combined generator -> discriminator model."""
        # Freeze the discriminator (except its BatchNormalization layers) so the
        # combined model primarily updates the generator.
        self._set_discriminator_trainable(False)
        in_src = Input(shape=self.input_shape)
        gen_out = self.generator(in_src)
        dis_out = self.discriminator([in_src, gen_out])
        model = Model(in_src, dis_out)
        opt = Adam(learning_rate=0.0002, beta_1=0.5)
        model.compile(loss='binary_crossentropy', optimizer=opt, metrics=[keras.metrics.BinaryAccuracy()])
        return model

    def pretrain_generator(self, steps_per_epoch=100, epochs=10):
        """Train the generator alone with a mean-absolute-error reconstruction loss.

        Args:
            steps_per_epoch: number of batches drawn from ``self.data`` per epoch.
            epochs: number of epochs.
        """
        self.generator.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.5),
            loss='mean_absolute_error',  # 'mse' is tracked as a metric below
            metrics=['mse', 'mae']
        )
        self.generator.fit(self.data, steps_per_epoch=steps_per_epoch, epochs=epochs)

    def _get_patch_size(self):
        """Return the spatial size (axis 1) of the discriminator's patch output."""
        return self.discriminator.output.shape[1]

    def _generate_real_samples(self):
        """Generate a batch of real samples from the data generator.

        Returns:
            ``([l_batch, ab_batch], true_labels)`` with one "real" (1.0) label per
            discriminator patch.
        """
        l_batch, ab_batch = next(self.data)
        # Size the labels from the batch actually received rather than from
        # self.batch_size, so the two can never disagree.
        true_labels = np.ones((len(l_batch), self.patch_shape, self.patch_shape, 1))
        return [l_batch, ab_batch], true_labels

    def _generate_fake_samples(self, l_batch):
        """Generate a batch of fake samples (predicted by the generator).

        Returns:
            ``(predictions, fake_labels)`` with one "fake" (0.0) label per
            discriminator patch.
        """
        predictions = self.generator.predict(l_batch, verbose=0)
        fake_labels = np.zeros((len(l_batch), self.patch_shape, self.patch_shape, 1))
        return predictions, fake_labels

    def _denormalize_ab_channel(self, ab_channel_normalized):
        """Denormalize AB channels back to their original value ranges.

        A: [0, 1] -> [42, 226]; B: [0, 1] -> [20, 223]. The channel axis is the
        last axis of ``ab_channel_normalized``.
        """
        a_channel = (ab_channel_normalized[..., 0] * 184.0) + 42
        b_channel = (ab_channel_normalized[..., 1] * 203.0) + 20
        return np.stack([a_channel, b_channel], axis=-1)

    @staticmethod
    def _compose_lab_uint8(l_plane, ab_planes):
        """Stack an (H, W) L plane and (H, W, 2) AB planes into a uint8 (H, W, 3) image.

        Values are rounded to the nearest integer and clipped to [0, 255] before
        the cast: a plain ``astype(np.uint8)`` truncates, which turns e.g.
        99.9999 (a float round-trip artefact) into 99.
        """
        lab = np.concatenate(
            [np.asarray(l_plane, dtype=np.float64)[..., np.newaxis],
             np.asarray(ab_planes, dtype=np.float64)],
            axis=-1,
        )
        return np.clip(np.rint(lab), 0, 255).astype(np.uint8)

    def display_results(self, num_images=5):
        """Display ground-truth and generated colourisations for one fresh batch.

        Args:
            num_images: number of image columns to show; capped at the batch size.
        """
        l_batch, ab_batch = next(self.data)
        predictions = self.generator.predict(l_batch, verbose=0)
        # Denormalize AB channels
        ab_batch = self._denormalize_ab_channel(ab_batch)
        predictions = self._denormalize_ab_channel(predictions)

        num_images = min(num_images, len(l_batch))
        if num_images < 1:
            return

        real_images = []
        generated_images = []
        for i in range(num_images):
            # Recombine L with the real / predicted AB channels, then convert
            # from LAB to RGB for display.
            lab_real = self._compose_lab_uint8(l_batch[i, :, :, 0], ab_batch[i])
            lab_pred = self._compose_lab_uint8(l_batch[i, :, :, 0], predictions[i])
            real_images.append(cv2.cvtColor(lab_real, cv2.COLOR_LAB2RGB))
            generated_images.append(cv2.cvtColor(lab_pred, cv2.COLOR_LAB2RGB))

        # squeeze=False keeps `axes` two-dimensional even for a single column.
        fig, axes = plt.subplots(2, num_images, figsize=(15, 6), squeeze=False)
        for i in range(num_images):
            # Top row: real images
            axes[0, i].imshow(real_images[i])
            axes[0, i].axis("off")
            axes[0, i].set_title(f"Real {i+1}")

            # Bottom row: generated images
            axes[1, i].imshow(generated_images[i])
            axes[1, i].axis("off")
            axes[1, i].set_title(f"Generated {i+1}")

        plt.show()
        # Release the figure; this method is called repeatedly during training.
        plt.close(fig)

    def _set_discriminator_trainable(self, trainable=True):
        """Toggle ``trainable`` on every discriminator layer except BatchNormalization."""
        for layer in self.discriminator.layers:
            if not isinstance(layer, BatchNormalization):
                layer.trainable = trainable

    def _train_step(self, L_real, AB_real, y_real, AB_fake, y_fake):
        """Run one discriminator update followed by one generator update.

        Args:
            L_real: batch of L inputs.
            AB_real: matching ground-truth (normalized) AB channels.
            y_real: patch labels of ones.
            AB_fake: generator predictions for ``L_real``.
            y_fake: patch labels of zeros.

        Returns:
            ``(d_loss_real, d_loss_fake, g_loss)`` as scalar tensors (mean binary
            cross-entropy over the batch and all patches).
        """
        bce = tf.keras.losses.binary_crossentropy

        # --- Discriminator update ---
        self._set_discriminator_trainable()
        with tf.GradientTape() as tape_d:
            pred_real = self.discriminator([L_real, AB_real], training=True)
            pred_fake = self.discriminator([L_real, AB_fake], training=True)
            # Reduce to scalars explicitly; a non-scalar target would be summed
            # implicitly by tape.gradient.
            d_loss_real = tf.reduce_mean(bce(y_real, pred_real))
            d_loss_fake = tf.reduce_mean(bce(y_fake, pred_fake))
            d_loss = d_loss_real + d_loss_fake

        d_vars = self.discriminator.trainable_variables
        grads_d = tape_d.gradient(d_loss, d_vars)
        self.discriminator.optimizer.apply_gradients(zip(grads_d, d_vars))

        # --- Generator update (through the combined model) ---
        self._set_discriminator_trainable(False)
        with tf.GradientTape() as tape_g:
            pred_generated = self.model(L_real, training=True)
            # The generator is rewarded when its output is classified as real.
            g_loss = tf.reduce_mean(bce(y_real, pred_generated))

        # NOTE(review): the discriminator's BatchNormalization layers are never
        # frozen, so their trainable weights are part of this list and are updated
        # by the generator step as well -- confirm this is intended.
        g_vars = self.model.trainable_variables
        grads_g = tape_g.gradient(g_loss, g_vars)
        self.model.optimizer.apply_gradients(zip(grads_g, g_vars))

        return d_loss_real, d_loss_fake, g_loss

    def train(self, n_epochs=100, batch_size=32, n_steps=10000, report_every=250):
        """Train the GAN using the data generator.

        Args:
            n_epochs: currently unused; kept so existing callers keep working.
            batch_size: batch size for the (re-created) data stream.
            n_steps: total number of training steps.
            report_every: print discriminator accuracy / generator loss and show
                sample images every this many steps.

        Raises:
            ValueError: if ``report_every`` is smaller than 1.
        """
        if report_every < 1:
            raise ValueError(f"report_every must be at least 1, got {report_every}")

        self.batch_size = batch_size
        self.data = generator_wrapper(batch_size=self.batch_size)

        for i in range(n_steps):
            # Real samples from the data generator, fake samples from the generator.
            [L_real, AB_real], y_real = self._generate_real_samples()
            AB_fake, y_fake = self._generate_fake_samples(L_real)

            d_loss_real, d_loss_fake, g_loss = self._train_step(L_real, AB_real, y_real, AB_fake, y_fake)

            if i % report_every == 0:
                # Discriminator predictions on real and fake samples
                d_real_pred = self.discriminator.predict([L_real, AB_real], verbose=0)
                d_fake_pred = self.discriminator.predict([L_real, AB_fake], verbose=0)

                # Fraction of patches classified correctly
                real_acc = np.mean(d_real_pred > 0.5)  # real samples classified as real
                fake_acc = np.mean(d_fake_pred < 0.5)  # fake samples classified as fake
                print(f'Real: {real_acc}, Fake: {fake_acc}, G-loss: {np.mean(g_loss)}')
                self.display_results()

    def save(self, path='generator.h5'):
        """Save the generator model to ``path`` (default: ``generator.h5``)."""
        self.generator.save(path)
        
# L-channel input: 224x224 with a single channel. The spatial size matches the
# (224, 224, 3) input the ResNet101 backbone is built with inside the generator.
inp_shape = (224,224,1)
# AB target: same spatial size, two channels.
out_shape = (224,224,2)
# Constructing the GAN builds the generator, discriminator and combined model.
gan = GAN(inp_shape, out_shape)

## Initial generator training

In [None]:
# Train the generator alone (mean-absolute-error loss) before adversarial
# training, then write it to generator.h5.
gan.pretrain_generator(steps_per_epoch=500, epochs=10)
gan.save()

## GAN Training

In [None]:
# Adversarial training with batches of 24. The save below writes to the same
# generator.h5 file as the pretraining cell, replacing that checkpoint.
gan.train(batch_size=24)
gan.save()