# Imports

In [2]:
%%capture
!pip install datasets

In [3]:
from datasets import load_dataset
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
from tensorflow.keras.utils import plot_model
from tensorflow.keras.applications.vgg19 import preprocess_input

  from .autonotebook import tqdm as notebook_tqdm


Check the installed library versions

In [None]:
# Versions noted inline in the original cell: tf 2.18.0, cv2 4.11.0, numpy 2.0.2, PIL 11.1.0.
for _lib in (tf, cv2, np, Image):
    print(_lib.__version__)

# Load and Process the data

In [None]:
# Load examples 5000-9999 of the train split and keep only the two image columns used below.
ds = load_dataset("calibretaliation/colorization", split="train[5000:10000]")
_wanted_columns = ["original_image", "colorized_image"]
_unused_columns = [name for name in ds.column_names if name not in _wanted_columns]
ds = ds.remove_columns(_unused_columns)

In [None]:
def resize_and_preprocess(image, target_size=(32, 32)):
    """Convert an image to a 3-channel float32 tensor, resized and rescaled.

    Args:
        image: Anything ``np.array`` can convert to an (H, W) or (H, W, C) array.
            Pixel values are assumed to be on a 0-255 scale (the rescaling below
            maps 0 to -1 and 255 to 1) -- TODO confirm against the dataset.
        target_size: ``(height, width)`` passed to ``tf.image.resize``.

    Returns:
        A float32 tensor of shape ``(height, width, 3)`` holding ``x / 127.5 - 1``.
    """
    image = np.array(image)
    if image.ndim == 2:
        # No channel axis: replicate the single plane into three identical channels.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 1:
        # Explicit single channel: same replication, along the existing axis.
        image = np.repeat(image, 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        # Four channels are treated as RGBA; drop the alpha plane so every
        # example ends up with the 3 channels the models expect.
        image = image[..., :3]

    image = tf.convert_to_tensor(image, dtype=tf.float32)
    image = tf.image.resize(image, target_size)
    image = (image / 127.5) - 1

    return image

def process_dataset(example):
    """Preprocess both images of one dataset example to 64x64.

    Args:
        example: Mapping with 'colorized_image' and 'original_image' entries.

    Returns:
        Dict with the same two keys, each run through ``resize_and_preprocess``.
    """
    size = (64, 64)
    return {
        key: resize_and_preprocess(example[key], size)
        for key in ('colorized_image', 'original_image')
    }

# Run process_dataset (resize to 64x64, rescale) over every example with 4 worker processes.
ds = ds.map(process_dataset, num_proc=4)

In [None]:
# Materialize both columns as dense arrays. Requesting float32 explicitly avoids the
# float64 NumPy infers when a column yields nested Python lists, halving memory.
train_colored = np.array(list(ds['colorized_image']), dtype=np.float32)
train_bw = np.array(list(ds['original_image']), dtype=np.float32)

In [None]:
def view_image(generated_images, original_images, num_images=3):
    """Plot input images (top row, titled 'Original') above color images (bottom row).

    Pixel values are mapped to 0-255 with ``(x + 1) * 127.5``, i.e. inputs are
    expected to be scaled to [-1, 1].

    Args:
        generated_images: Indexable collection of color images for the bottom row.
        original_images: Indexable collection of images for the top row.
        num_images: Maximum number of columns; reduced to the number of images
            actually available so short inputs do not raise IndexError.
    """
    def to_uint8(img):
        # Clip before casting so slight overshoot cannot wrap around in uint8.
        return np.clip((np.asarray(img) + 1) * 127.5, 0, 255).astype(np.uint8)

    num_images = min(num_images, len(generated_images), len(original_images))
    if num_images <= 0:
        return

    plt.figure(figsize=(12, 4))

    # Top row: the inputs, drawn with a gray colormap (only affects 2-D images).
    for i in range(num_images):
        plt.subplot(2, num_images, i + 1)
        plt.imshow(to_uint8(np.squeeze(original_images[i])), cmap='gray')
        plt.axis('off')
        plt.title('Original')

    # Bottom row: the color images.
    for i in range(num_images):
        plt.subplot(2, num_images, num_images + i + 1)
        plt.imshow(to_uint8(generated_images[i]))
        plt.axis('off')
        plt.title('Generated')

    plt.tight_layout()
    plt.show()

# Sanity check: display three preprocessed pairs (indices 67-69).
view_image(train_colored[67:70], train_bw[67:70], num_images=3)

# DISCRIMINATOR

In [None]:
def discriminator(shape=(64, 64, 3)):
    """Build and compile the convolutional discriminator.

    Args:
        shape: Input image shape, ``(height, width, channels)``.

    Returns:
        A compiled ``tf.keras.Model`` named 'Discriminator' with one sigmoid
        output, binary cross-entropy loss, Adam(2e-4, beta_1=0.5) and an
        accuracy metric.
    """
    layers = tf.keras.layers
    image = layers.Input(shape=shape, name='input_image')

    # Stem: stride-1 convolution, no batch normalization.
    x = layers.Conv2D(64, (3, 3), strides=(1, 1), padding='same')(image)
    x = layers.LeakyReLU(alpha=0.2)(x)

    # Three stride-2 stages, each halving the spatial resolution.
    for filters in (128, 256, 256):
        x = layers.Conv2D(filters, (3, 3), strides=(2, 2), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(alpha=0.2)(x)

    # Classification head.
    x = layers.Flatten()(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(1, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=image, outputs=x, name='Discriminator')
    adam = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
    return model

# Build the discriminator for 64x64 3-channel inputs and print its layer summary.
d_model = discriminator(shape=(64, 64, 3))
d_model.summary()

In [None]:
# Draw the discriminator's layer graph, annotated with shapes, names and activations.
plot_model(d_model, show_shapes=True, show_layer_names=True, show_layer_activations=True, dpi=100)

In [None]:
def generate_samples(dataset, n_samples, is_real):
    """Draw a random batch (with replacement) from ``dataset`` plus matching labels.

    Args:
        dataset: Array indexable along its first axis.
        n_samples: Number of rows to draw.
        is_real: Label the batch with ones when true, zeros otherwise.

    Returns:
        ``(X, y)`` where ``X`` is the sampled rows and ``y`` has shape ``(n_samples, 1)``.
    """
    picks = np.random.randint(0, dataset.shape[0], n_samples)
    make_labels = np.ones if is_real else np.zeros
    return dataset[picks], make_labels((n_samples, 1))

# GENERATOR

In [None]:
def generator(input_shape=(64, 64, 3)):
    """Build the encoder-decoder generator with skip connections.

    Four stride-2 convolution stages downsample the input; four stride-2
    transposed-convolution stages upsample it again, with the first three
    concatenated to the encoder activation of matching resolution. A final
    tanh convolution produces a 3-channel image in [-1, 1].

    Args:
        input_shape: Input image shape, ``(height, width, channels)``.

    Returns:
        An uncompiled ``tf.keras.Model`` named 'Generator'.
    """
    layers = tf.keras.layers

    def encode(tensor, filters):
        # Conv (stride 2) -> BatchNorm -> LeakyReLU
        tensor = layers.Conv2D(filters, (3, 3), strides=(2, 2), padding='same')(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.LeakyReLU(alpha=0.2)(tensor)

    def decode(tensor, filters):
        # ConvTranspose (stride 2) -> BatchNorm -> Dropout -> ReLU
        tensor = layers.Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same')(tensor)
        tensor = layers.BatchNormalization()(tensor)
        tensor = layers.Dropout(0.2)(tensor)
        return layers.ReLU()(tensor)

    input_image = layers.Input(shape=input_shape, name='input_image')

    # Encoder
    enc1 = encode(input_image, 64)
    enc2 = encode(enc1, 128)
    enc3 = encode(enc2, 256)
    bottleneck = encode(enc3, 256)

    # Decoder; each of the first three stages is joined with its encoder counterpart.
    dec = decode(bottleneck, 256)
    dec = layers.Concatenate()([dec, enc3])

    dec = decode(dec, 256)
    dec = layers.Concatenate()([dec, enc2])

    dec = decode(dec, 128)
    dec = layers.Concatenate()([dec, enc1])

    dec = decode(dec, 64)

    output = layers.Conv2D(3, (3, 3), activation='tanh', padding='same')(dec)

    return tf.keras.Model(inputs=input_image, outputs=output, name='Generator')

In [None]:
# Build a throwaway generator just for the diagram: `g_model` is only created in the
# training cell further down, so referencing it here fails on a fresh kernel.
plot_model(generator(input_shape=(64, 64, 3)), show_shapes=True, show_layer_names=True, show_layer_activations=True, dpi=70)

# Training

In [None]:
def generate_generator_samples(g_model, dataset, n_samples, is_real=False):
    """Run the generator on a random batch of ``dataset`` and attach labels.

    Args:
        g_model: Object exposing ``predict(X, verbose=0)``.
        dataset: Array of generator inputs, indexable along its first axis.
        n_samples: Number of inputs to draw (with replacement).
        is_real: Label the outputs with ones when true, zeros otherwise.

    Returns:
        ``(gen_images, y)`` with ``y`` of shape ``(n_samples, 1)``.
    """
    chosen = np.random.randint(0, dataset.shape[0], n_samples)
    gen_images = g_model.predict(dataset[chosen], verbose=0)
    make_labels = np.ones if is_real else np.zeros
    return gen_images, make_labels((n_samples, 1))

# ImageNet-pretrained VGG19 without its classifier head, frozen, and truncated at
# 'block3_conv3' to serve as the feature extractor for the perceptual loss.
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=(64, 64, 3))
vgg.trainable = False
vgg_feature_extractor = tf.keras.Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)

@tf.function(reduce_retracing=True)
def perceptual_loss(y_true, y_pred):
    """Mean absolute difference between VGG19 features of two image batches.

    Both batches are expected in [-1, 1] (the dataset scaling and the
    generator's tanh output). They are mapped back to [0, 255] before
    ``preprocess_input``, which is the range that function expects; multiplying
    by 255 directly would feed it values in [-255, 255].

    Args:
        y_true: Target images, shape ``(batch, 64, 64, 3)``.
        y_pred: Generated images, same shape.

    Returns:
        Scalar tensor: mean absolute error over the extracted feature maps.
    """
    def to_vgg_input(images):
        # Undo the (x / 127.5) - 1 scaling, then apply VGG19's own preprocessing.
        return tf.keras.applications.vgg19.preprocess_input((images + 1.0) * 127.5)

    features_true = vgg_feature_extractor(to_vgg_input(y_true))
    features_pred = vgg_feature_extractor(to_vgg_input(y_pred))

    return tf.reduce_mean(tf.abs(features_true - features_pred))


def l1_loss(y_true, y_pred):
    """Return the mean absolute error between ``y_true`` and ``y_pred``."""
    abs_error = tf.abs(y_true - y_pred)
    return tf.reduce_mean(abs_error)
# from_logits=False: the loss is fed probabilities (the discriminator ends in a sigmoid).
bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)

def gan_loss(disc_generated_output):
    """Binary cross-entropy of the discriminator's output against all-ones labels."""
    real_labels = tf.ones_like(disc_generated_output)
    return bce(real_labels, disc_generated_output)
def combined_generator_loss(y_true, y_pred, disc_generated_output, lambda_l1=100.0, lambda_perc=10.0):
    """Weighted sum of the adversarial, L1 and perceptual generator losses.

    Args:
        y_true: Target images.
        y_pred: Generated images.
        disc_generated_output: Discriminator output for the generated images.
        lambda_l1: Weight of the L1 term.
        lambda_perc: Weight of the perceptual term.

    Returns:
        ``adversarial + lambda_l1 * l1 + lambda_perc * perceptual``.
    """
    pixel_term = l1_loss(y_true, y_pred)
    feature_term = perceptual_loss(y_true, y_pred)
    adversarial_term = gan_loss(disc_generated_output)

    return adversarial_term + (lambda_l1 * pixel_term) + (lambda_perc * feature_term)

def define_gan(g_model, d_model, shape):
    """Chain generator and discriminator into one compiled model for generator updates.

    Side effect: sets ``d_model.trainable = False`` on the passed-in discriminator.

    Args:
        g_model: Generator model.
        d_model: Discriminator model.
        shape: Shape of the generator input, ``(height, width, channels)``.

    Returns:
        A compiled ``tf.keras.Model`` named 'GAN' with outputs
        ``[generated_image, discriminator_score]``.
    """
    d_model.trainable = False

    gan_input = tf.keras.layers.Input(shape=shape)
    g_output = g_model(gan_input)
    d_output = d_model(g_output)
    model = tf.keras.Model(inputs=gan_input, outputs=[g_output, d_output], name='GAN')

    # Adam whose learning rate drops by 4% every 10,000 steps (staircase), with
    # gradient-norm clipping at 1.0.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        0.0002,
        decay_steps=10000,
        decay_rate=0.96,
        staircase=True,
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, beta_1=0.5, clipnorm=1.0)

    # NOTE(review): this image loss already includes an adversarial term via
    # d_model(y_pred), and the second output adds binary cross-entropy on the
    # discriminator score as well -- confirm the double counting is intended.
    def total_loss(y_true, y_pred):
        return combined_generator_loss(y_true, y_pred, d_model(y_pred))

    model.compile(
        loss=[total_loss, 'binary_crossentropy'],
        # Change the weights according to the image it's generating.
        loss_weights=[100, 1],
        optimizer=optimizer,
    )
    return model

def summarize_performance(epoch, g_model, d_model, dataset_colored, dataset_bw, n_samples=100):
    """Print discriminator accuracy on real and generated batches and preview results.

    Args:
        epoch: Current epoch (accepted for interface compatibility; not used).
        g_model: Generator model.
        d_model: Compiled discriminator whose ``evaluate`` returns ``(loss, accuracy)``.
        dataset_colored: Array of real color images.
        dataset_bw: Array of generator inputs.
        n_samples: Number of images to draw for each accuracy estimate.
    """
    X_real, y_real = generate_samples(dataset_colored, n_samples, is_real=True)
    _, acc_real = d_model.evaluate(X_real, y_real, verbose=0)

    # Sample the inputs here (not via generate_generator_samples) so they are kept
    # and can be shown next to the images generated from them.
    ix = np.random.randint(0, dataset_bw.shape[0], n_samples)
    bw_inputs = dataset_bw[ix]
    X_fake = g_model.predict(bw_inputs, verbose=0)
    y_fake = np.zeros((n_samples, 1))
    _, acc_fake = d_model.evaluate(X_fake, y_fake, verbose=0)

    print(f"Accuracy real: {acc_real*100:.2f}%, Accuracy fake: {acc_fake*100:.2f}%")

    # Show up to three input/output pairs.
    n_show = min(3, n_samples)
    if n_show > 0:
        view_image(X_fake[:n_show], bw_inputs[:n_show], num_images=n_show)

def _unpack_losses(result, count):
    """Return the first ``count`` values of a ``train_on_batch`` result as floats, zero-padded.

    Accepts a bare scalar, a 0-d array, or a list/array of values.
    """
    values = [float(v) for v in np.ravel(np.asarray(result, dtype=float))[:count]]
    return values + [0.0] * (count - len(values))


def train_gan(g_model, d_model, gan_model, dataset_colored, dataset_bw, epochs=50, batch_size=128,
              checkpoint_dir='./training_checkpoints', log_interval=10, checkpoint_interval=10):
    """Alternate discriminator and generator updates, checkpointing periodically.

    Each step draws one random index batch, used for both arrays so that
    ``dataset_colored[i]`` and ``dataset_bw[i]`` are treated as a pair.

    Args:
        g_model: Generator model.
        d_model: Compiled discriminator.
        gan_model: Compiled combined model (see ``define_gan``).
        dataset_colored: Array of target color images.
        dataset_bw: Array of generator inputs, row-aligned with ``dataset_colored``.
        epochs: Number of epochs.
        batch_size: Samples per training step.
        checkpoint_dir: Directory for ``tf.train.CheckpointManager``.
        log_interval: Print losses every this many epochs.
        checkpoint_interval: Save a checkpoint and preview every this many epochs.

    Raises:
        ValueError: If an interval is < 1, if ``dataset_bw`` has fewer rows than
            ``dataset_colored``, or if the dataset is smaller than ``batch_size``.
    """
    if log_interval < 1 or checkpoint_interval < 1:
        raise ValueError("log_interval and checkpoint_interval must be >= 1")

    n_examples = dataset_colored.shape[0]
    if dataset_bw.shape[0] < n_examples:
        raise ValueError(
            f"dataset_bw has {dataset_bw.shape[0]} rows but dataset_colored has {n_examples}; "
            "indices sampled from dataset_colored would be out of range"
        )

    batch_per_epoch = n_examples // batch_size
    if batch_per_epoch == 0:
        # Without this guard no step would run and the loss variables printed
        # below would be undefined.
        raise ValueError(
            f"batch_size ({batch_size}) is larger than the dataset ({n_examples} examples)"
        )

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)

    # NOTE(review): these optimizers are only stored in the checkpoint. The
    # train_on_batch calls below use the optimizers the models were compiled
    # with, so the optimizer state saved here is not the training state.
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

    checkpoint = tf.train.Checkpoint(
        generator_optimizer=g_optimizer,
        discriminator_optimizer=d_optimizer,
        generator=g_model,
        discriminator=d_model
    )

    # Checkpoint manager to keep multiple checkpoints
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=checkpoint_dir,
        max_to_keep=5  # Keep last 5 checkpoints
    )

    # Restore the latest checkpoint if exists
    latest_checkpoint = checkpoint_manager.latest_checkpoint
    if latest_checkpoint:
        checkpoint.restore(latest_checkpoint).expect_partial()
        print(f"Resuming training from {latest_checkpoint}")
    else:
        print("Initializing from scratch.")

    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        for batch in range(batch_per_epoch):
            idx = np.random.randint(0, n_examples, batch_size)
            real_images = dataset_colored[idx]
            bw_images = dataset_bw[idx]

            # Generate fake samples
            generated_images = g_model.predict(bw_images, verbose=0)

            # Train discriminator; keep only the loss (first value) of each result.
            d_loss_real = _unpack_losses(d_model.train_on_batch(real_images, real_labels), 1)[0]
            d_loss_fake = _unpack_losses(d_model.train_on_batch(generated_images, fake_labels), 1)[0]

            # Train generator through the combined model; components the model
            # does not report are shown as 0.
            g_result = gan_model.train_on_batch(bw_images, [real_images, real_labels])
            g_total, g_custom, g_bce = _unpack_losses(g_result, 3)

        # Report the losses of the last batch of the epoch.
        if (epoch + 1) % log_interval == 0:
            print(f"Epoch: {epoch+1}/{epochs}")
            print(f"D1 Loss: {d_loss_real:.4f}, D2 Loss: {d_loss_fake:.4f}")
            print(f"G Total Loss: {g_total:.4f}, G Custom Loss: {g_custom:.4f}, G BCE Loss: {g_bce:.4f}")

        # Save a checkpoint and preview three random results.
        if (epoch + 1) % checkpoint_interval == 0:
            save_path = checkpoint_manager.save()
            print(f"Checkpoint saved: {save_path}")

            idx = np.random.randint(0, dataset_bw.shape[0], 3)
            generated_images = g_model.predict(dataset_bw[idx], verbose=0)
            view_image(generated_images, dataset_bw[idx])

    # Final checkpoint
    save_path = checkpoint_manager.save()
    print(f"Training complete. Final checkpoint saved: {save_path}")

In [None]:
# Training configuration
epochs = 100  # set as you need
batch_size = 128
shape = (64, 64, 3)  # (height, width, channels) of the model inputs

Start training

In [None]:
# import time

# Build fresh models, wire them into the combined GAN, and run the training loop.
g_model = generator(shape)
d_model = discriminator(shape)
gan_model = define_gan(g_model, d_model, shape)

# start_time = time.time() # Uncomment (with the import above) to time the run
train_gan(g_model, d_model, gan_model, train_colored, train_bw, epochs=epochs, batch_size=batch_size)
# end_time = time.time()

In [None]:
# To check the time
# elapsed_time = end_time - start_time
# print(f"Total training time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")

# Testing

In [None]:
def load_checkpoint(checkpoint_dir, input_shape=(64, 64, 3)):
    """Rebuild the generator and restore its weights from the latest checkpoint.

    Args:
        checkpoint_dir: Directory managed by ``tf.train.CheckpointManager``.
        input_shape: Generator input shape; must match the trained model.

    Returns:
        The generator model. If no checkpoint is found, a message is printed
        and the returned model keeps its freshly initialized weights.
    """
    # Recreate the generator model with the same architecture
    g_model = generator(input_shape=input_shape)
    # Track only the generator; everything else in the checkpoint is ignored.
    checkpoint = tf.train.Checkpoint(generator=g_model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=checkpoint_dir,
        max_to_keep=5
    )
    latest_checkpoint = checkpoint_manager.latest_checkpoint
    if latest_checkpoint:
        # The training checkpoint also holds the discriminator and optimizers;
        # expect_partial() marks leaving those unrestored as intentional.
        checkpoint.restore(latest_checkpoint).expect_partial()
        print(f"Restored generator from checkpoint: {latest_checkpoint}")
    else:
        print("No checkpoint found!")
    return g_model
def colorize_images(generator, black_and_white_images):
    """Run the generator on one image or a batch of images.

    Args:
        generator: Object exposing ``predict(batch)``.
        black_and_white_images: Array-like of shape ``(H, W, C)`` for a single
            image or ``(N, H, W, C)`` for a batch. Lists are accepted.

    Returns:
        Whatever ``generator.predict`` returns for the batched input.
    """
    # Coerce to an array first so plain lists (which have no .ndim) work too.
    black_and_white_images = np.asarray(black_and_white_images)

    # A single image gets a leading batch axis.
    if black_and_white_images.ndim == 3:
        black_and_white_images = black_and_white_images[np.newaxis, ...]

    # Predict colorized images
    colorized_images = generator.predict(black_and_white_images)

    return colorized_images

In [None]:
# Restore the generator from the latest checkpoint and colorize the first ten
# inputs. Note these come from the training array, not a held-out set.
checkpoint_dir = './training_checkpoints'
generator_infer = load_checkpoint(checkpoint_dir)
test_bw_images = train_bw[:10]
colorized_images = colorize_images(generator_infer, test_bw_images)
view_image(colorized_images, test_bw_images, num_images=10)

Saving and loading the model

In [None]:
# Export the trained generator as a SavedModel under ./model.
tf.saved_model.save(g_model, './model')

# Load it back and fetch the 'serving_default' signature as a callable.
loaded_model = tf.saved_model.load('./model')
infer = loaded_model.signatures["serving_default"]

Test the model

In [None]:
# Show the argument names and specs the signature expects.
print(infer.structured_input_signature)

input_tensor = tf.convert_to_tensor(train_bw[10:20], dtype=tf.float32)
# NOTE(review): the keyword used here must match the input name printed above -- confirm.
output = infer(inputs=input_tensor)
# print(output.keys())
# The signature returns a dict; take its first entry (presumably the only one).
generated_image = output[next(iter(output.keys()))].numpy()
view_image(generated_image, train_bw[10:20], num_images=10)