# Practical Assesment Task #1: Colorizing Images with Generative Adversarial Networks
Técnicas Generativas y Aprendizaje por Refuerzo - Curso 2024/2025



## Import Dependencies and Set General Parameters

In [3]:
%pip install tensorflow -q
%pip install scikit-image -q

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.0 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.0 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [4]:
# Import libraries
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime
from tensorflow.keras.layers import (
    Activation, AveragePooling2D, BatchNormalization, Conv2D, Dense,
    Dropout, Flatten, LeakyReLU, UpSampling2D)
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Sequential
from time import time
from skimage.color import rgb2lab, lab2rgb
from pathlib import Path
from tqdm import tqdm

from tensorflow.keras.models import Sequential

# Enable memory growth for GPU
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Constants
IMAGE_SIZE = 32
EPOCHS = 200 # Increase or decrease as appropriate
BATCH_SIZE = 128
SHUFFLE_BUFFER_SIZE = 100

WORKDIR = "/content/drive/My Drive/Colab Notebooks/Workspace/"

# Create directories
Path(WORKDIR).mkdir(parents=True, exist_ok=True)
Path(f"{WORKDIR}/results").mkdir(parents=True, exist_ok=True)

IndexError: list index out of range

## Define Dataset Loader

In [2]:
def generate_dataset(images, debug=False):
    """
    This function converts a list of RGB images to the LAB color space, and separates the L channel from the A and B channels.
    The images are normalized and reshaped to the appropriate dimensions for the model.

    Parameters:
    images (list): A list of RGB images
    debug (bool): If True, the function will display the original and converted images

    Returns:
    X (np.array): A numpy array of L channels of the images
    Y (np.array): A numpy array of A and B channels of the images
    """
    X = []
    Y = []

    for i in images:
        lab_image_array = rgb2lab(i / 255)  # Convert the image from RGB to LAB color space
        x = lab_image_array[:, :, 0]  # Get the L channel
        y = lab_image_array[:, :, 1:]  # Get the A and B channels
        y /= 128  # normalize

        if debug:  # If debug is True, display the original and converted images
            fig = plt.figure()
            fig.add_subplot(1, 2, 1)
            plt.imshow(i / 255)

            fig.add_subplot(1, 2, 2)
            plt.imshow(lab2rgb(np.dstack((x, y * 128))))
            plt.show()

        X.append(x.reshape(IMAGE_SIZE, IMAGE_SIZE, 1))  # Reshape the L channel and append to the list
        Y.append(y)  # Append the A and B channels to the list

    X = np.array(X, dtype=np.float32)  # Convert the list to a numpy array
    Y = np.array(Y, dtype=np.float32)  # Convert the list to a numpy array

    return X, Y


def load_data(force=False):
    """
    This function loads the CIFAR-10 dataset, processes it, and saves/loads the processed data to/from the disk.

    Parameters:
    force (bool): If True, the function will reprocess the data even if it already exists on the disk

    Returns:
    X_train, Y_train, X_test, Y_test (np.array): Numpy arrays of the training and testing data
    """
    is_saved_arrays_exist = os.path.isfile(os.path.join(WORKDIR, 'X_train.npy'))  # Check if the processed data already exists on the disk

    if not is_saved_arrays_exist or force:  # If the processed data does not exist or if force is True, process the data
        (train_images, _), (test_images, _) = cifar10.load_data()  # Load the CIFAR-10 dataset
        X_train, Y_train = generate_dataset(train_images)  # Process the training images
        X_test, Y_test = generate_dataset(test_images)  # Process the testing images
        print('Saving processed data to Drive')
        np.save(os.path.join(WORKDIR, 'X_train.npy'), X_train)  # Save the processed training data to the disk
        np.save(os.path.join(WORKDIR, 'Y_train.npy'), Y_train)  # Save the processed training data to the disk
        np.save(os.path.join(WORKDIR, 'X_test.npy'), X_test)  # Save the processed testing data to the disk
        np.save(os.path.join(WORKDIR, 'Y_test.npy'), Y_test)  # Save the processed testing data to the disk
    else:  # If the processed data exists and force is False, load the data from the disk
        print('Loading processed data from Drive')
        X_train = np.load(os.path.join(WORKDIR, 'X_train.npy'))  # Load the processed training data from the disk
        Y_train = np.load(os.path.join(WORKDIR, 'Y_train.npy'))  # Load the processed training data from the disk
        X_test = np.load(os.path.join(WORKDIR, 'X_test.npy'))  # Load the processed testing data from the disk
        Y_test = np.load(os.path.join(WORKDIR, 'Y_test.npy'))  # Load the processed testing data from the disk

    return X_train, Y_train, X_test, Y_test

## Load Dataset

In [3]:
# Load the processed training and testing data
X_train, Y_train, X_test, Y_test = load_data()

# Create TensorFlow datasets from the training and testing data
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, Y_test))

# Shuffle the training dataset and batch both datasets
# SHUFFLE_BUFFER_SIZE determines the randomness of the shuffling
# BATCH_SIZE determines the number of samples that will be propagated through the network at once
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 0us/step
Saving processed data to Drive


## Define Generator Model

In [4]:
def build_generator_model():
    """
    This function builds the generator model for the GAN. The generator is responsible for generating new, fake images.
    It uses a series of Conv2D, BatchNormalization, and UpSampling2D layers.

    Returns:
    model (Sequential): The generator model
    """
    model = Sequential()  # Initialize the model

    # Downsampling (Extracción de características)
    model.add(Conv2D(64, (3,3), strides=2, padding='same', input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1)))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Conv2D(128, (3,3), strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Conv2D(256, (3,3), strides=2, padding='same'))  # Ahora 4x4
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # Upsampling (Reconstrucción)
    model.add(UpSampling2D(size=(2,2)))  # 4x4 → 8x8
    model.add(Conv2D(128, (3,3), padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(UpSampling2D(size=(2,2)))  # 8x8 → 16x16
    model.add(Conv2D(64, (3,3), padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(UpSampling2D(size=(2,2)))  # 16x16 → 32x32
    model.add(Conv2D(32, (3,3), padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # Capa de salida
    model.add(Conv2D(2, (3,3), padding='same'))  # Salida con canales A y B
    model.add(Activation('tanh'))  # Salida en [-1,1]


    # COMPLETE

    # Suggested architecture:
    # Downsampling
        # Add a series of Conv2D layers followed by BatchNormalization layers
        # Conv2D layers are used for the convolution operation that extracts features from the input images
        # Conv2D layers with stride=1 preserve the dimensions of the image
        # Conv2D layers with stride=2 downsample the image (i.e., reduce the dimensions of the image by half)
        # BatchNormalization layers are used to normalize the activations of the previous layer at each batch
        # ReLU activation function is used to add non-linearity to the output of the previous layer
        # Padding is set to 'same' to preserve the dimensions of the image
        # Repeat this series of layers N times

    # Upsampling
        # UpSampling2D layer is used to increase the dimension of the image
        # Add as many UpSampling2D layers as Conv2D layers with stride=2 to upsample the image to its original dimensions
        # Add a series of Conv2D layers followed by BatchNormalization layers
        # BatchNormalization layers are used to normalize the activations of the previous layer at each batch
        # ReLU activation function is used to add non-linearity to the output of the previous layer

    # Output layer
        # Add a Conv2D layer with 2 filters to output the A and B LAB channels of the image
        # Activation layer is used to apply the tanh activation function to the output
        # tanh activation function is used because the pixels of the output images range from -1 to 1

    return model

## Define Discriminator Model

In [5]:
def build_discriminator_model():
    """
    This function builds the discriminator model for the GAN. The discriminator is responsible for distinguishing real images from fake ones.
    It is recommended to use a series of Conv2D, Dropout, AveragePooling2D, Flatten, Dense, LeakyReLU, BatchNormalization, and Activation layers to create the model.

    Returns:
    model (Sequential): The discriminator model
    """

    model = Sequential()  # Initialize the model

    # Bloques de convolución
    model.add(Conv2D(64, (3,3), strides=2, padding='same', input_shape=(IMAGE_SIZE, IMAGE_SIZE, 2)))  # Canales A y B en LAB
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))

    model.add(Conv2D(128, (3,3), strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))

    model.add(Conv2D(256, (3,3), strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))

    # Pooling y capas densas
    model.add(AveragePooling2D(pool_size=(2,2)))
    model.add(Flatten())
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization())
    model.add(Dropout(0.4))

    # Capa de salida
    model.add(Dense(1, activation='sigmoid'))  # Salida en [0,1] (real o falsa)
    # COMPLETE

    # Suggested architecture:
    # Convolution blocks
        # Add a series of Conv2D layers followed by Dropout
        # Conv2D layers are used for the convolution operation that extracts features from the input images
        # Dropout layers are used to prevent overfitting by randomly setting a fraction rate of input units to 0 at each update during training time
        # Repeat this series of layers N times

    # Pooling and dense layers
        # AveragePooling2D layer is used to downscale the image spatially
        # Flatten layer is used to convert the 2D matrix of features into a vector that can be fed into a fully connected neural network classifier
        # Dense layers are the regular deeply connected neural network layers
        # LeakyReLU is a type of activation function that allows a small gradient when the unit is not active
        # BatchNormalization layers are used to normalize the activations of the previous layer at each batch
        # Dropout layers are used to prevent overfitting by randomly setting a fraction rate of input units to 0 at each update during training time

    # Output layer
        # Add a Dense layer with 1 unit to output the probability of the input image being real or synthetic
        # Activation layer is used to apply the sigmoid activation function to the output
        # sigmoid activation function is used because the output is a probability between 0 and 1

    return model

## Define Generator and Discriminator Loss Functions

In [6]:
# Define the weight of the GAN loss
gan_loss_weight = 1

# Define the regularization parameter for the generator's L2 loss
l2_lambda = 150

# Define the loss function for the discriminator
# Binary Cross Entropy is used as the loss function since we are dealing with a binary classification problem (real vs fake images)
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(disc_real_output, disc_generated_output):
    """
    This function calculates the total loss for the discriminator.

    Parameters:
    disc_real_output (Tensor): The discriminator's prediction on the real images
    disc_generated_output (Tensor): The discriminator's prediction on the generated (fake) images

    Returns:
    total_disc_loss (Tensor): The total loss for the discriminator
    """

    # Cross entropy loss for real images (label 1 = real)
    real_loss = cross_entropy(tf.ones_like(disc_real_output), disc_real_output)

    # Cross entropy loss for generated (fake) images (label 0 = fake)
    fake_loss = cross_entropy(tf.zeros_like(disc_generated_output), disc_generated_output)

    # Total loss for the discriminator
    total_disc_loss = real_loss + fake_loss

    return total_disc_loss


def generator_loss(disc_generated_output, gen_output, target):
    """
    This function calculates the total loss for the generator.

    Parameters:
    disc_generated_output (Tensor): The discriminator's prediction on the generated (fake) images
    gen_output (Tensor): The generated (fake) images
    target (Tensor): The real images

    Returns:
    total_gen_loss (Tensor): The total loss for the generator
    gan_loss (Tensor): The GAN loss for the generator
    l2_loss (Tensor): The L2 loss for the generator
    """

    # GAN loss (objective: engañar al discriminador para que piense que la imagen generada es real)
    gan_loss = cross_entropy(tf.ones_like(disc_generated_output), disc_generated_output)

    # L2 Loss (Mean Squared Error entre la imagen generada y la imagen real)
    l2_loss = tf.reduce_mean(tf.square(target - gen_output))

    # Total Generator Loss (GAN loss ponderada + L2 regularización)
    total_gen_loss = gan_loss_weight * gan_loss + l2_lambda * l2_loss

    return total_gen_loss, gan_loss, l2_loss


## Build Generator and Discriminator Models

In [7]:
# Instantiate the generator and discriminator models
generator = build_generator_model()
discriminator = build_discriminator_model()

# Print out the model summaries
generator.summary()
discriminator.summary()

# Define the optimizers for the generator and discriminator
# Adam optimizer is used with a learning rate of 2e-4 and beta_1 (the exponential decay rate for the first moment estimates) of 0.5
# It is recommended to experiment with different values for the learning rate. Beta_1 value of 0.5 is recommended for GANs
# It is possible that the optimal learning rate for the generator and discriminator is different from each other
gen_learning_rate = 2e-4
gene_beta_1 = 0.5

disc_learning_rate = 2e-4
disc_beta_1 = 0.5

generator_optimizer = tf.keras.optimizers.Adam(gen_learning_rate, beta_1=gene_beta_1)
discriminator_optimizer = tf.keras.optimizers.Adam(disc_learning_rate, beta_1=disc_beta_1)

# Define the directory for storing the training checkpoints
checkpoint_dir = os.path.join(WORKDIR, 'training-checkpoints')
# Define the prefix for the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Create a checkpoint object which will be used to save and load the models and optimizers
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
# Create a manager for the checkpoint object, which will be responsible for saving and loading the checkpoints
# max_to_keep parameter is set to 3, meaning that only the 3 most recent checkpoints will be kept
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [8]:
# Define the directory and file for storing the TensorBoard summary logs
summary_log_file = os.path.join(
    WORKDIR, 'tf-summary', datetime.now().strftime("%Y%m%d-%H%M%S"))
# Create a summary writer for writing the summary logs
summary_writer = tf.summary.create_file_writer(summary_log_file)

@tf.function
def train_step(input_image, target, epoch):
    """
    This function performs one training step for the generator and discriminator.

    Parameters:
    input_image (Tensor): The input image
    target (Tensor): The target image
    epoch (int): The current epoch number

    Returns:
    gen_total_loss (Tensor): The total loss for the generator
    disc_loss (Tensor): The loss for the discriminator
    """

    # Open a GradientTape context for automatic differentiation
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_total_loss = 0  # Initialize the total loss for the generator
        gen_gan_loss = 0  # Initialize the GAN loss for the generator
        gen_l2_loss = 0  # Initialize the L2 loss for the generator
        disc_loss = 0  # Initialize the loss for the discriminator

        # Generate an image using the generator
        # (remember to set training=True, since during training time BatchNormalization layers are applied using the batch mean and variance)
        gen_output = generator(input_image, training=True)

        # Get the discriminator's predictions on the real and generated images
        # (remember to set training=True, since during training time BatchNormalization layers are applied using the batch mean and variance)
        disc_real_output = discriminator(target, training=True)
        disc_generated_output = discriminator(gen_output, training=True)

        # Calculate the losses for the generator and discriminator
        gen_total_loss, gen_gan_loss, gen_l2_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    # Calculate the gradients of the generator's and discriminator's losses with respect to the trainable variables of each model
    # Use the gradient() method of the GradientTape object of each model to calculate the gradients of the losses with respect to the trainable variables of the corresponding model
    gen_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Apply the gradients to the model's trainable variables
    # Use the apply_gradients() method of the optimizer object to apply the gradients
    # The apply_gradients() method takes a list of (gradient, variable) pairs
    # zip() function can be used to iterate over two lists simultaneously (i.e., the gradients and the model's trainable variables)
    generator_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))

    # Write the losses to the summary logs
    with summary_writer.as_default():
        tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
        tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
        tf.summary.scalar('gen_l2_loss', gen_l2_loss, step=epoch)
        tf.summary.scalar('disc_loss', disc_loss, step=epoch)

    return gen_total_loss, disc_loss


In [9]:
# The checkpoint manager will load the latest checkpoint if it exists
# This will be used to resume training if the training process was interrupted
# If no checkpoint exists, the models will be initialized from scratch

# Restore the latest checkpoint using the checkpoint manager
checkpoint.restore(manager.latest_checkpoint)

# If a checkpoint was found and restored, print a message indicating the checkpoint file
if manager.latest_checkpoint:
    print('Restored from {}'.format(manager.latest_checkpoint))

# If no checkpoint was found, print a message indicating that the models are being initialized from scratch
else:
    print('Initializing from scratch')

Initializing from scratch


In [None]:
# Loop over the epochs
for e in tqdm(range(EPOCHS)):
    # Record the start time of the epoch
    start_time = time()

    gen_loss_total = 0  # Initialize the total loss for the generator
    disc_loss_total = 0  # Initialize the total loss for the discriminator

    # Loop over the training dataset
    for input_image, target in train_dataset:
        # Perform one training step and get the generator and discriminator losses
        gen_loss, disc_loss = train_step(input_image, target, e)

        # Add the losses to the total losses
        gen_loss_total += gen_loss
        disc_loss_total += disc_loss

    # Calculate the time taken for the epoch
    time_taken = time() - start_time

    # If the epoch number is a multiple of 10, save a checkpoint
    if (e + 1) % 10 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

    # Print the epoch number, average generator loss, average discriminator loss, and time taken
    print(f'Epoch {e + 1}: gen loss: {gen_loss_total / len(train_dataset)}, '
          f'disc loss: {disc_loss_total / len(train_dataset)}, time: {time_taken:.2f}s')


  output, from_logits = _get_logits(
  0%|          | 1/200 [00:39<2:09:30, 39.05s/it]

Epoch 1: gen loss: 3.7049975395202637, disc loss: 1.6393542289733887, time: 39.04s


  1%|          | 2/200 [01:09<1:52:01, 33.95s/it]

Epoch 2: gen loss: 2.458028554916382, disc loss: 1.50822114944458, time: 30.35s


  2%|▏         | 3/200 [01:38<1:44:17, 31.76s/it]

Epoch 3: gen loss: 2.3558623790740967, disc loss: 1.4559720754623413, time: 29.16s


  2%|▏         | 4/200 [02:09<1:42:56, 31.51s/it]

Epoch 4: gen loss: 2.2841339111328125, disc loss: 1.4275315999984741, time: 31.12s


  2%|▎         | 5/200 [02:39<1:40:51, 31.03s/it]

Epoch 5: gen loss: 2.2325351238250732, disc loss: 1.4124574661254883, time: 30.18s


  3%|▎         | 6/200 [03:10<1:39:51, 30.89s/it]

Epoch 6: gen loss: 2.2036514282226562, disc loss: 1.3982446193695068, time: 30.60s


  4%|▎         | 7/200 [03:40<1:38:15, 30.55s/it]

Epoch 7: gen loss: 2.181772232055664, disc loss: 1.3944933414459229, time: 29.85s


  4%|▍         | 8/200 [04:10<1:37:49, 30.57s/it]

Epoch 8: gen loss: 2.1679811477661133, disc loss: 1.3899670839309692, time: 30.62s


  4%|▍         | 9/200 [04:40<1:36:44, 30.39s/it]

Epoch 9: gen loss: 2.159566879272461, disc loss: 1.3889145851135254, time: 29.99s


  5%|▌         | 10/200 [05:11<1:36:44, 30.55s/it]

Epoch 10: gen loss: 2.1507558822631836, disc loss: 1.3883802890777588, time: 30.67s


  6%|▌         | 11/200 [05:41<1:35:29, 30.31s/it]

Epoch 11: gen loss: 2.1426913738250732, disc loss: 1.3876473903656006, time: 29.77s


  6%|▌         | 12/200 [06:12<1:35:15, 30.40s/it]

Epoch 12: gen loss: 2.1354644298553467, disc loss: 1.3873909711837769, time: 30.59s


  6%|▋         | 13/200 [06:41<1:33:57, 30.15s/it]

Epoch 13: gen loss: 2.128920316696167, disc loss: 1.387600302696228, time: 29.57s


  7%|▋         | 14/200 [07:13<1:34:50, 30.59s/it]

Epoch 14: gen loss: 2.1214559078216553, disc loss: 1.3876855373382568, time: 31.62s


  8%|▊         | 15/200 [07:43<1:33:56, 30.47s/it]

Epoch 15: gen loss: 2.114452362060547, disc loss: 1.387860655784607, time: 30.18s


  8%|▊         | 16/200 [08:13<1:33:12, 30.40s/it]

Epoch 16: gen loss: 2.110053300857544, disc loss: 1.3876681327819824, time: 30.22s


  8%|▊         | 17/200 [08:43<1:32:06, 30.20s/it]

Epoch 17: gen loss: 2.1050384044647217, disc loss: 1.3872246742248535, time: 29.74s


  9%|▉         | 18/200 [09:14<1:31:53, 30.29s/it]

Epoch 18: gen loss: 2.0969276428222656, disc loss: 1.3876535892486572, time: 30.49s


 10%|▉         | 19/200 [09:55<1:41:02, 33.49s/it]

Epoch 19: gen loss: 2.091463804244995, disc loss: 1.3873876333236694, time: 40.94s


 10%|█         | 20/200 [10:25<1:38:04, 32.69s/it]

Epoch 20: gen loss: 2.08453369140625, disc loss: 1.3875373601913452, time: 30.62s


 10%|█         | 21/200 [10:56<1:35:15, 31.93s/it]

Epoch 21: gen loss: 2.0793697834014893, disc loss: 1.3880473375320435, time: 30.15s


 11%|█         | 22/200 [11:26<1:33:17, 31.45s/it]

Epoch 22: gen loss: 2.076314926147461, disc loss: 1.3875656127929688, time: 30.32s


 12%|█▏        | 23/200 [12:07<1:41:10, 34.30s/it]

Epoch 23: gen loss: 2.068007707595825, disc loss: 1.3875757455825806, time: 40.94s


 12%|█▏        | 24/200 [12:37<1:37:02, 33.09s/it]

Epoch 24: gen loss: 2.065131902694702, disc loss: 1.387483835220337, time: 30.26s


 12%|█▎        | 25/200 [13:07<1:33:36, 32.10s/it]

Epoch 25: gen loss: 2.0585598945617676, disc loss: 1.387560486793518, time: 29.79s


 13%|█▎        | 26/200 [13:38<1:32:29, 31.89s/it]

Epoch 26: gen loss: 2.0513556003570557, disc loss: 1.3874738216400146, time: 31.42s


 14%|█▎        | 27/200 [14:08<1:30:23, 31.35s/it]

Epoch 27: gen loss: 2.044816255569458, disc loss: 1.38800048828125, time: 30.08s


 14%|█▍        | 28/200 [14:38<1:28:44, 30.95s/it]

Epoch 28: gen loss: 2.03851056098938, disc loss: 1.3872010707855225, time: 30.03s


 14%|█▍        | 29/200 [15:08<1:27:15, 30.62s/it]

Epoch 29: gen loss: 2.03244948387146, disc loss: 1.3875691890716553, time: 29.81s


 15%|█▌        | 30/200 [15:39<1:26:32, 30.55s/it]

Epoch 30: gen loss: 2.0250394344329834, disc loss: 1.3873522281646729, time: 30.20s


 16%|█▌        | 31/200 [16:09<1:26:18, 30.64s/it]

Epoch 31: gen loss: 2.0183939933776855, disc loss: 1.3877674341201782, time: 30.86s


 16%|█▌        | 32/200 [16:50<1:34:27, 33.73s/it]

Epoch 32: gen loss: 2.012284517288208, disc loss: 1.3876967430114746, time: 40.94s


 16%|█▋        | 33/200 [17:21<1:31:09, 32.75s/it]

Epoch 33: gen loss: 2.0057613849639893, disc loss: 1.3876954317092896, time: 30.45s


 17%|█▋        | 34/200 [17:52<1:29:27, 32.34s/it]

Epoch 34: gen loss: 1.9959646463394165, disc loss: 1.3874255418777466, time: 31.37s


 18%|█▊        | 35/200 [18:23<1:27:45, 31.91s/it]

Epoch 35: gen loss: 1.9894241094589233, disc loss: 1.3873921632766724, time: 30.93s


 18%|█▊        | 36/200 [18:53<1:25:26, 31.26s/it]

Epoch 36: gen loss: 1.9802865982055664, disc loss: 1.387590765953064, time: 29.73s


 18%|█▊        | 37/200 [19:23<1:24:02, 30.93s/it]

Epoch 37: gen loss: 1.972350001335144, disc loss: 1.3874502182006836, time: 30.17s


 19%|█▉        | 38/200 [19:53<1:22:31, 30.56s/it]

Epoch 38: gen loss: 1.965073585510254, disc loss: 1.3872888088226318, time: 29.69s


 20%|█▉        | 39/200 [20:24<1:22:12, 30.64s/it]

Epoch 39: gen loss: 1.9559813737869263, disc loss: 1.3875027894973755, time: 30.81s


 20%|██        | 40/200 [20:54<1:21:16, 30.48s/it]

Epoch 40: gen loss: 1.9484177827835083, disc loss: 1.387494683265686, time: 29.82s


 20%|██        | 41/200 [21:24<1:20:33, 30.40s/it]

Epoch 41: gen loss: 1.9391851425170898, disc loss: 1.3876245021820068, time: 30.20s


 21%|██        | 42/200 [21:54<1:19:53, 30.34s/it]

Epoch 42: gen loss: 1.9316980838775635, disc loss: 1.3872089385986328, time: 30.19s


 22%|██▏       | 43/200 [22:24<1:19:04, 30.22s/it]

Epoch 43: gen loss: 1.9200798273086548, disc loss: 1.3877453804016113, time: 29.94s


 22%|██▏       | 44/200 [22:54<1:18:43, 30.28s/it]

Epoch 44: gen loss: 1.913784384727478, disc loss: 1.3871605396270752, time: 30.42s


 22%|██▎       | 45/200 [23:26<1:19:13, 30.67s/it]

Epoch 45: gen loss: 1.903648853302002, disc loss: 1.3877638578414917, time: 31.58s


 23%|██▎       | 46/200 [23:57<1:18:42, 30.67s/it]

Epoch 46: gen loss: 1.898617148399353, disc loss: 1.3875526189804077, time: 30.63s


 24%|██▎       | 47/200 [24:26<1:17:30, 30.40s/it]

Epoch 47: gen loss: 1.889061689376831, disc loss: 1.3871514797210693, time: 29.77s


 24%|██▍       | 48/200 [24:57<1:16:59, 30.39s/it]

Epoch 48: gen loss: 1.8825485706329346, disc loss: 1.387601375579834, time: 30.36s


 24%|██▍       | 49/200 [25:27<1:16:05, 30.23s/it]

Epoch 49: gen loss: 1.8733025789260864, disc loss: 1.387821078300476, time: 29.86s


 25%|██▌       | 50/200 [25:58<1:16:05, 30.44s/it]

Epoch 50: gen loss: 1.864298701286316, disc loss: 1.3875151872634888, time: 30.73s


 26%|██▌       | 51/200 [26:27<1:15:07, 30.25s/it]

Epoch 51: gen loss: 1.8553065061569214, disc loss: 1.3873761892318726, time: 29.81s


 26%|██▌       | 52/200 [26:58<1:14:35, 30.24s/it]

Epoch 52: gen loss: 1.8480098247528076, disc loss: 1.3876792192459106, time: 30.22s


 26%|██▋       | 53/200 [27:39<1:21:57, 33.45s/it]

Epoch 53: gen loss: 1.8400933742523193, disc loss: 1.3873746395111084, time: 40.94s


 27%|██▋       | 54/200 [28:09<1:19:14, 32.57s/it]

Epoch 54: gen loss: 1.8328487873077393, disc loss: 1.3872464895248413, time: 30.49s


 28%|██▊       | 55/200 [28:39<1:17:01, 31.87s/it]

Epoch 55: gen loss: 1.8259116411209106, disc loss: 1.387290596961975, time: 30.25s


 28%|██▊       | 56/200 [29:09<1:14:57, 31.24s/it]

Epoch 56: gen loss: 1.8183481693267822, disc loss: 1.3874889612197876, time: 29.74s


 28%|██▊       | 57/200 [29:50<1:21:23, 34.15s/it]

Epoch 57: gen loss: 1.809476375579834, disc loss: 1.3874295949935913, time: 40.94s


 29%|██▉       | 58/200 [30:22<1:19:29, 33.59s/it]

Epoch 58: gen loss: 1.8019359111785889, disc loss: 1.3874928951263428, time: 32.29s


 30%|██▉       | 59/200 [30:53<1:16:42, 32.64s/it]

Epoch 59: gen loss: 1.7946065664291382, disc loss: 1.387264370918274, time: 30.42s


 30%|███       | 60/200 [31:22<1:14:06, 31.76s/it]

Epoch 60: gen loss: 1.7893143892288208, disc loss: 1.387341856956482, time: 29.51s


 30%|███       | 61/200 [31:53<1:12:43, 31.39s/it]

Epoch 61: gen loss: 1.7809807062149048, disc loss: 1.3877849578857422, time: 30.50s


 31%|███       | 62/200 [32:23<1:11:21, 31.02s/it]

Epoch 62: gen loss: 1.7747136354446411, disc loss: 1.387382984161377, time: 30.17s


 32%|███▏      | 63/200 [32:53<1:10:18, 30.79s/it]

Epoch 63: gen loss: 1.7656737565994263, disc loss: 1.3873337507247925, time: 30.25s


 32%|███▏      | 64/200 [33:24<1:09:20, 30.60s/it]

Epoch 64: gen loss: 1.7579622268676758, disc loss: 1.387423038482666, time: 30.13s


 32%|███▎      | 65/200 [33:54<1:08:32, 30.47s/it]

Epoch 65: gen loss: 1.754616141319275, disc loss: 1.3869895935058594, time: 30.16s


 33%|███▎      | 66/200 [34:24<1:07:50, 30.37s/it]

Epoch 66: gen loss: 1.7456787824630737, disc loss: 1.3874919414520264, time: 30.16s


 34%|███▎      | 67/200 [34:54<1:07:17, 30.36s/it]

Epoch 67: gen loss: 1.7410191297531128, disc loss: 1.387588381767273, time: 30.32s


 34%|███▍      | 68/200 [35:35<1:13:46, 33.54s/it]

Epoch 68: gen loss: 1.7322431802749634, disc loss: 1.3874537944793701, time: 40.94s


 34%|███▍      | 69/200 [36:05<1:10:55, 32.48s/it]

Epoch 69: gen loss: 1.7275174856185913, disc loss: 1.3875343799591064, time: 30.02s


 35%|███▌      | 70/200 [36:36<1:09:17, 31.98s/it]

Epoch 70: gen loss: 1.7230799198150635, disc loss: 1.3872079849243164, time: 30.55s


 36%|███▌      | 71/200 [37:06<1:07:26, 31.37s/it]

Epoch 71: gen loss: 1.7164891958236694, disc loss: 1.3872785568237305, time: 29.93s


 36%|███▌      | 72/200 [37:36<1:06:17, 31.07s/it]

Epoch 72: gen loss: 1.712570309638977, disc loss: 1.38730788230896, time: 30.38s


 36%|███▋      | 73/200 [38:06<1:04:47, 30.61s/it]

Epoch 73: gen loss: 1.702734112739563, disc loss: 1.3874722719192505, time: 29.54s


 37%|███▋      | 74/200 [38:39<1:05:43, 31.30s/it]

Epoch 74: gen loss: 1.7007900476455688, disc loss: 1.3873214721679688, time: 32.89s


 38%|███▊      | 75/200 [39:09<1:04:29, 30.95s/it]

Epoch 75: gen loss: 1.693427562713623, disc loss: 1.3873372077941895, time: 30.16s


 38%|███▊      | 76/200 [39:39<1:03:33, 30.75s/it]

Epoch 76: gen loss: 1.6850945949554443, disc loss: 1.3875353336334229, time: 30.27s


 38%|███▊      | 77/200 [40:09<1:02:29, 30.49s/it]

Epoch 77: gen loss: 1.682045340538025, disc loss: 1.3876080513000488, time: 29.85s


 39%|███▉      | 78/200 [40:39<1:01:52, 30.43s/it]

Epoch 78: gen loss: 1.6805349588394165, disc loss: 1.3872801065444946, time: 30.31s


 40%|███▉      | 79/200 [41:20<1:07:44, 33.59s/it]

Epoch 79: gen loss: 1.6747374534606934, disc loss: 1.3873651027679443, time: 40.94s


 40%|████      | 80/200 [41:51<1:05:21, 32.68s/it]

Epoch 80: gen loss: 1.6662064790725708, disc loss: 1.3871358633041382, time: 30.37s


 40%|████      | 81/200 [42:22<1:03:45, 32.14s/it]

Epoch 81: gen loss: 1.6623152494430542, disc loss: 1.3875720500946045, time: 30.89s


 41%|████      | 82/200 [42:52<1:02:03, 31.55s/it]

Epoch 82: gen loss: 1.6569325923919678, disc loss: 1.3874343633651733, time: 30.15s


 42%|████▏     | 83/200 [43:22<1:00:56, 31.25s/it]

Epoch 83: gen loss: 1.6503446102142334, disc loss: 1.3870989084243774, time: 30.54s


 42%|████▏     | 84/200 [43:52<59:39, 30.85s/it]  

Epoch 84: gen loss: 1.6485432386398315, disc loss: 1.38784658908844, time: 29.93s


 42%|████▎     | 85/200 [44:23<59:09, 30.87s/it]

Epoch 85: gen loss: 1.6405937671661377, disc loss: 1.3873679637908936, time: 30.90s


 43%|████▎     | 86/200 [45:04<1:04:23, 33.89s/it]

Epoch 86: gen loss: 1.6350053548812866, disc loss: 1.3874417543411255, time: 40.95s


 44%|████▎     | 87/200 [45:35<1:02:15, 33.06s/it]

Epoch 87: gen loss: 1.6346244812011719, disc loss: 1.387006163597107, time: 31.10s


 44%|████▍     | 88/200 [46:06<1:00:07, 32.21s/it]

Epoch 88: gen loss: 1.6293925046920776, disc loss: 1.3876484632492065, time: 30.23s


 44%|████▍     | 89/200 [46:37<59:01, 31.91s/it]  

Epoch 89: gen loss: 1.6220396757125854, disc loss: 1.3874485492706299, time: 31.18s


 45%|████▌     | 90/200 [47:08<57:57, 31.62s/it]

Epoch 90: gen loss: 1.6204862594604492, disc loss: 1.3874502182006836, time: 30.64s


 46%|████▌     | 91/200 [47:38<56:35, 31.16s/it]

Epoch 91: gen loss: 1.617641568183899, disc loss: 1.3873310089111328, time: 30.08s


 46%|████▌     | 92/200 [48:19<1:01:21, 34.09s/it]

Epoch 92: gen loss: 1.6139642000198364, disc loss: 1.3873534202575684, time: 40.94s


 46%|████▋     | 93/200 [48:50<59:02, 33.11s/it]  

Epoch 93: gen loss: 1.608139157295227, disc loss: 1.3876116275787354, time: 30.81s


 47%|████▋     | 94/200 [49:24<59:04, 33.44s/it]

Epoch 94: gen loss: 1.6051746606826782, disc loss: 1.387412190437317, time: 34.22s


 48%|████▊     | 95/200 [49:54<57:01, 32.58s/it]

Epoch 95: gen loss: 1.600922703742981, disc loss: 1.3874377012252808, time: 30.58s


 48%|████▊     | 96/200 [50:24<55:12, 31.85s/it]

Epoch 96: gen loss: 1.593453288078308, disc loss: 1.3871852159500122, time: 30.13s


 48%|████▊     | 97/200 [50:54<53:41, 31.28s/it]

Epoch 97: gen loss: 1.5954186916351318, disc loss: 1.387518286705017, time: 29.94s


 49%|████▉     | 98/200 [51:25<52:42, 31.01s/it]

Epoch 98: gen loss: 1.5847903490066528, disc loss: 1.3873754739761353, time: 30.38s


 50%|████▉     | 99/200 [51:56<52:03, 30.93s/it]

Epoch 99: gen loss: 1.585731863975525, disc loss: 1.387172818183899, time: 30.73s


 50%|█████     | 100/200 [52:26<51:15, 30.76s/it]

Epoch 100: gen loss: 1.5805468559265137, disc loss: 1.3873276710510254, time: 30.04s


 50%|█████     | 101/200 [52:57<50:47, 30.78s/it]

Epoch 101: gen loss: 1.5788631439208984, disc loss: 1.387142539024353, time: 30.84s


 51%|█████     | 102/200 [53:27<49:47, 30.49s/it]

Epoch 102: gen loss: 1.5721873044967651, disc loss: 1.3872731924057007, time: 29.79s


 52%|█████▏    | 103/200 [53:58<49:32, 30.64s/it]

Epoch 103: gen loss: 1.5721672773361206, disc loss: 1.387092113494873, time: 31.01s


 52%|█████▏    | 104/200 [54:28<48:43, 30.45s/it]

Epoch 104: gen loss: 1.5669864416122437, disc loss: 1.3871904611587524, time: 29.99s


 52%|█████▎    | 105/200 [54:59<48:45, 30.79s/it]

Epoch 105: gen loss: 1.5642163753509521, disc loss: 1.3875176906585693, time: 31.58s


 53%|█████▎    | 106/200 [55:31<48:31, 30.97s/it]

Epoch 106: gen loss: 1.5586557388305664, disc loss: 1.3875035047531128, time: 31.35s


 54%|█████▎    | 107/200 [56:01<47:43, 30.79s/it]

Epoch 107: gen loss: 1.5545865297317505, disc loss: 1.3871517181396484, time: 30.36s


 54%|█████▍    | 108/200 [56:32<47:21, 30.89s/it]

Epoch 108: gen loss: 1.5535862445831299, disc loss: 1.3874001502990723, time: 31.11s


 55%|█████▍    | 109/200 [57:02<46:25, 30.61s/it]

Epoch 109: gen loss: 1.5497699975967407, disc loss: 1.38729989528656, time: 29.97s


 55%|█████▌    | 110/200 [57:33<46:03, 30.70s/it]

Epoch 110: gen loss: 1.544152021408081, disc loss: 1.3874722719192505, time: 30.54s


## Plot Results

In [None]:
# Define the number of samples to generate
n_samples = 150
# Generate colorized versions of the first n_samples grayscale images in the test set
Y_hat = generator(X_test[:n_samples])

# Define the number of rows in the grid, which is equal to the number of samples
num_rows = len(Y_hat)

# Define the number of columns in the grid and the size of each image
num_cols = 3  # Number of columns in the grid
img_size = 1  # Size of each image in the grid

# Create a grid of subplots with num_rows rows and num_cols columns
fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * img_size, num_rows * img_size))
fig.subplots_adjust(hspace=0.3, wspace=0.1)  # Adjust spacing between subplots

# Loop over the grayscale images, original color images, and colorized images
for row, (x, y, y_hat) in enumerate(zip(X_test[:n_samples], Y_test[:n_samples], Y_hat)):

    # Convert the original color image from Lab to RGB
    orig_lab = np.dstack((x, y * 128))
    orig_rgb = lab2rgb(orig_lab)

    # Convert the grayscale image from Lab to RGB
    grayscale_lab = np.dstack((x, np.zeros((IMAGE_SIZE, IMAGE_SIZE, 2))))
    grayscale_rgb = lab2rgb(grayscale_lab)

    # Convert the colorized image from Lab to RGB
    predicted_lab = np.dstack((x, y_hat * 128))
    predicted_rgb = lab2rgb(predicted_lab)

    # Display the grayscale image in the first column of the current row
    ax = axes[row, 0]  # Get the appropriate axis for the current subplot
    ax.axis('off')  # Turn off axis labels
    ax.imshow(grayscale_rgb)
    ax.set_title('Grayscale')

    # Display the original color image in the second column of the current row
    ax = axes[row, 1]  # Move to the next column for original RGB
    ax.axis('off')
    ax.imshow(orig_rgb)
    ax.set_title('Original')

    # Display the colorized image in the third column of the current row
    ax = axes[row, 2]  # Move to the next column for predicted colorized image
    ax.axis('off')
    ax.imshow(predicted_rgb)
    ax.set_title('Predicted')

    # Print the current row number to the console
    sys.stdout.flush()
    sys.stdout.write('\r{} / {}'.format(row + 1, num_rows))

# Hide empty subplots if any
for ax_row in axes:
    for ax in ax_row:
        if not ax.has_data():
            ax.axis('off')

# Adjust the padding between subplots
plt.tight_layout(pad=0.5)
# Save the figure as an image file
plt.savefig(os.path.join(WORKDIR, 'results', 'image_grid.png'))
# Display the figure
plt.show()

## Save the Generator and Discriminator Trained Models

In [None]:
# Save the checkpoints of the models and optimizers so that they can be loaded later
# Save the generator model in the SavedModel format
tf.saved_model.save(generator, os.path.join(WORKDIR, "generator-saved-model"))

# Save the discriminator model in the SavedModel format
tf.saved_model.save(discriminator, os.path.join(WORKDIR, "disciminator-saved-model"))