
# Image Colorization With GANs


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import tensorflow as tf
import numpy as np
import os

# Parameters
batch_size = 32
img_size = 128
train_val_split = 0.9  # Proporção de treino e validação
master_dir = '/content/drive/MyDrive/Mestrado/2024_2/Topicos_especiais_IA_ThiagoPX/NormalAndMonet'
test_dir = '/content/drive/MyDrive/Mestrado/2024_2/Topicos_especiais_IA_ThiagoPX/testFolder'

# Image loading and preprocessing
def load_and_preprocess_image(image_path):
    # Load RGB image
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, [img_size, img_size])
    img = img / 255.0  # Normalize to [0, 1]

    # Convert to grayscale
    gray_img = tf.image.rgb_to_grayscale(img)
    return gray_img, img

# Create train/validation dataset
image_files = [os.path.join(master_dir, fname) for fname in os.listdir(master_dir)]
train_size = int(train_val_split * len(image_files))
train_files = image_files[:train_size]
val_files = image_files[train_size:]

train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
val_dataset = tf.data.Dataset.from_tensor_slices(val_files)

# Create test dataset
test_files = [os.path.join(test_dir, fname) for fname in os.listdir(test_dir)]
test_dataset = tf.data.Dataset.from_tensor_slices(test_files)

# Map the preprocessing function
train_dataset = train_dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)

# Batch the datasets
train_dataset = train_dataset.batch(batch_size)
val_dataset = val_dataset.batch(batch_size)
test_dataset = test_dataset.batch(batch_size)

# Prefetch for performance optimization
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

# Print to check the datasets
print(f"Train Dataset: {len(list(train_dataset))} batches")
print(f"Validation Dataset: {len(list(val_dataset))} batches")
print(f"Test Dataset: {len(list(test_dataset))} batches")


Train Dataset: 29 batches
Validation Dataset: 4 batches
Test Dataset: 1 batches


In [3]:
def get_generator_model():
    inputs = tf.keras.layers.Input(shape=(img_size, img_size, 1))

    # Encoder com Dropout
    conv1 = tf.keras.layers.Conv2D(16, kernel_size=(5, 5), strides=1)(inputs)
    conv1 = tf.keras.layers.LeakyReLU()(conv1)
    conv1 = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), strides=1)(conv1)
    conv1 = tf.keras.layers.LeakyReLU()(conv1)
    conv1 = tf.keras.layers.Dropout(0.2)(conv1)
    conv1 = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), strides=1)(conv1)
    conv1 = tf.keras.layers.LeakyReLU()(conv1)

    conv2 = tf.keras.layers.Conv2D(32, kernel_size=(5, 5), strides=1)(conv1)
    conv2 = tf.keras.layers.LeakyReLU()(conv2)
    conv2 = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), strides=1)(conv2)
    conv2 = tf.keras.layers.LeakyReLU()(conv2)
    conv2 = tf.keras.layers.Dropout(0.2)(conv2)
    conv2 = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), strides=1)(conv2)
    conv2 = tf.keras.layers.LeakyReLU()(conv2)

    conv3 = tf.keras.layers.Conv2D(64, kernel_size=(5, 5), strides=1)(conv2)
    conv3 = tf.keras.layers.LeakyReLU()(conv3)
    conv3 = tf.keras.layers.Conv2D(128, kernel_size=(3, 3), strides=1)(conv3)
    conv3 = tf.keras.layers.LeakyReLU()(conv3)
    conv3 = tf.keras.layers.Dropout(0.2)(conv3)
    conv3 = tf.keras.layers.Conv2D(128, kernel_size=(3, 3), strides=1)(conv3)
    conv3 = tf.keras.layers.LeakyReLU()(conv3)

    # Bottleneck
    bottleneck = tf.keras.layers.Conv2D(256, kernel_size=(5, 5), strides=1, activation='tanh', padding='same')(conv3)

    # Decoder com Dropout
    concat_1 = tf.keras.layers.Concatenate()([bottleneck, conv3])
    conv_up_3 = tf.keras.layers.Conv2DTranspose(128, kernel_size=(3, 3), strides=1, activation='relu')(concat_1)
    conv_up_3 = tf.keras.layers.Conv2DTranspose(128, kernel_size=(3, 3), strides=1, activation='relu')(conv_up_3)
    conv_up_3 = tf.keras.layers.Dropout(0.2)(conv_up_3)
    conv_up_3 = tf.keras.layers.Conv2DTranspose(64, kernel_size=(5, 5), strides=1, activation='relu')(conv_up_3)

    concat_2 = tf.keras.layers.Concatenate()([conv_up_3, conv2])
    conv_up_2 = tf.keras.layers.Conv2DTranspose(64, kernel_size=(3, 3), strides=1, activation='relu')(concat_2)
    conv_up_2 = tf.keras.layers.Conv2DTranspose(64, kernel_size=(3, 3), strides=1, activation='relu')(conv_up_2)
    conv_up_2 = tf.keras.layers.Dropout(0.2)(conv_up_2)
    conv_up_2 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(5, 5), strides=1, activation='relu')(conv_up_2)

    concat_3 = tf.keras.layers.Concatenate()([conv_up_2, conv1])
    conv_up_1 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(3, 3), strides=1, activation='relu')(concat_3)
    conv_up_1 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(3, 3), strides=1, activation='relu')(conv_up_1)
    conv_up_1 = tf.keras.layers.Conv2DTranspose(3, kernel_size=(5, 5), strides=1, activation='relu')(conv_up_1)

    model = tf.keras.models.Model(inputs, conv_up_1)
    return model


In [4]:
def get_discriminator_model():
    layers = [
        tf.keras.layers.Conv2D(32, kernel_size=(7, 7), strides=1, activation='relu', input_shape=(img_size, img_size, 3)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Conv2D(32, kernel_size=(7, 7), strides=1, activation='relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, kernel_size=(5, 5), strides=1, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Conv2D(64, kernel_size=(5, 5), strides=1, activation='relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(128, kernel_size=(3, 3), strides=1, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Conv2D(128, kernel_size=(3, 3), strides=1, activation='relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(256, kernel_size=(3, 3), strides=1, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Conv2D(256, kernel_size=(3, 3), strides=1, activation='relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ]
    model = tf.keras.models.Sequential(layers)
    return model


In [5]:
import tensorflow as tf

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)
mse = tf.keras.losses.MeanSquaredError()

def discriminator_loss(real_output, fake_output):

    real_labels = tf.ones_like(real_output) - tf.random.uniform(shape=real_output.shape, maxval=0.1)
    fake_labels = tf.zeros_like(fake_output) + tf.random.uniform(shape=fake_output.shape, maxval=0.1)

    real_loss = cross_entropy(real_labels, real_output)
    fake_loss = cross_entropy(fake_labels, fake_output)

    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output, real_y):
    """
    Calcula a perda do gerador usando o erro médio quadrático (MSE).
    """
    real_y = tf.cast(real_y, tf.float32)
    return mse(fake_output, real_y)

learning_rate = 0.001
generator_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.5, beta_2=0.999)

generator = get_generator_model()
discriminator = get_discriminator_model()


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [6]:
import tensorflow as tf
from tqdm import tqdm

gen_loss_metric = tf.keras.metrics.Mean(name="gen_loss")
disc_loss_metric = tf.keras.metrics.Mean(name="disc_loss")

early_stopping_patience = 10
best_gen_loss = float('inf')
patience_counter = 0

@tf.function
def train_step(input_x, real_y):

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(input_x, training=True)

        real_output = discriminator(real_y, training=True)
        generated_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(generated_images, real_y)
        disc_loss = discriminator_loss(real_output, generated_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    gen_loss_metric.update_state(gen_loss)
    disc_loss_metric.update_state(disc_loss)

def train(dataset, epochs):
    """
    Loop principal de treinamento.
    """
    global best_gen_loss, patience_counter

    for epoch in range(epochs):
        gen_loss_metric.reset_state()
        disc_loss_metric.reset_state()

        print(f"\nEpoch {epoch+1}/{epochs}")
        progress_bar = tqdm(dataset, desc="Training", leave=True)

        for input_x, real_y in progress_bar:
            train_step(input_x, real_y)
            progress_bar.set_postfix({
                "Generator Loss": f"{gen_loss_metric.result():.4f}",
                "Discriminator Loss": f"{disc_loss_metric.result():.4f}"
            })

        print(f"Generator Loss: {gen_loss_metric.result():.4f}, Discriminator Loss: {disc_loss_metric.result():.4f}")

        current_gen_loss = gen_loss_metric.result()
        if current_gen_loss < best_gen_loss:
            best_gen_loss = current_gen_loss
            patience_counter = 0
            print("Validation improvement. Saving best model...\n")
            generator.save_weights("/content/drive/MyDrive/Mestrado/2024_2/Topicos_especiais_IA_ThiagoPX/best_generator.weights.h5")
            discriminator.save_weights("/content/drive/MyDrive/Mestrado/2024_2/Topicos_especiais_IA_ThiagoPX/best_discriminator.weights.h5")
        else:
            patience_counter += 1
            print(f"No improvement. Early stopping patience: {patience_counter}/{early_stopping_patience}")


        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered. Stopping training.")
            break



In [7]:
train(dataset=train_dataset, epochs=100)


Epoch 1/100


Training: 100%|██████████| 29/29 [01:11<00:00,  2.48s/it, Generator Loss=0.1043, Discriminator Loss=1.4481]


Generator Loss: 0.1043, Discriminator Loss: 1.4481
Validation improvement. Saving best model...


Epoch 2/100


Training: 100%|██████████| 29/29 [00:40<00:00,  1.41s/it, Generator Loss=0.0230, Discriminator Loss=4.0285]


Generator Loss: 0.0230, Discriminator Loss: 4.0285
Validation improvement. Saving best model...


Epoch 3/100


Training: 100%|██████████| 29/29 [00:40<00:00,  1.41s/it, Generator Loss=0.0176, Discriminator Loss=1.3863]


Generator Loss: 0.0176, Discriminator Loss: 1.3863
Validation improvement. Saving best model...


Epoch 4/100


Training: 100%|██████████| 29/29 [00:40<00:00,  1.41s/it, Generator Loss=0.0149, Discriminator Loss=1.3864]


Generator Loss: 0.0149, Discriminator Loss: 1.3864
Validation improvement. Saving best model...


Epoch 5/100


Training: 100%|██████████| 29/29 [00:40<00:00,  1.41s/it, Generator Loss=0.0145, Discriminator Loss=1.3863]


Generator Loss: 0.0145, Discriminator Loss: 1.3863
Validation improvement. Saving best model...


Epoch 6/100


Training: 100%|██████████| 29/29 [00:40<00:00,  1.41s/it, Generator Loss=0.0139, Discriminator Loss=1.3856]


Generator Loss: 0.0139, Discriminator Loss: 1.3856
Validation improvement. Saving best model...


Epoch 7/100


Training: 100%|██████████| 29/29 [00:27<00:00,  1.05it/s, Generator Loss=0.0164, Discriminator Loss=1.4311]


Generator Loss: 0.0164, Discriminator Loss: 1.4311
No improvement. Early stopping patience: 1/10

Epoch 8/100


Training:  17%|█▋        | 5/29 [00:05<00:28,  1.18s/it, Generator Loss=0.1468, Discriminator Loss=1.3861]


KeyboardInterrupt: 


## **4. Results**

We plotted the input, output and the original images respectively, from a part of the dataset to find out the results.


In [8]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def display_results(generator, test_dataset, img_size=256, output_size=(1024, 1024), weights_path="/content/drive/MyDrive/Mestrado/2024_2/Topicos_especiais_IA_ThiagoPX/best_generator.weights.h5"):
    try:
        generator.load_weights(weights_path)
        print(f"Pesos carregados com sucesso de: {weights_path}")
    except Exception as e:
        print(f"Erro ao carregar pesos: {e}")
        return

    # Iterate through the batches in the test dataset
    for batch_x, batch_y in test_dataset:
        y_pred = generator(batch_x).numpy() # Pass a batch of images to the generator

        # Iterate through images in the current batch
        for i, (input_img, target_img, output_img) in enumerate(zip(batch_x, batch_y, y_pred)):
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))

            # Convert input_img to NumPy array before reshaping
            input_img_np = input_img.numpy()
            input_resized = Image.fromarray(input_img_np.reshape((img_size, img_size)) * 255).resize(output_size)
            axes[0].imshow(input_resized, cmap='gray')
            axes[0].set_title("Imagem Cinza", fontsize=12)
            axes[0].axis("off")

            # Convert target_img to NumPy array before using astype
            target_img_np = target_img.numpy()
            target_resized = Image.fromarray((target_img_np * 255).astype('uint8')).resize(output_size)
            axes[1].imshow(target_resized)
            axes[1].set_title("Saída Objetivo", fontsize=12)
            axes[1].axis("off")

            output_resized = Image.fromarray((output_img * 255).astype('uint8')).resize(output_size)
            axes[2].imshow(output_resized)
            axes[2].set_title("Imagem Colorida NN", fontsize=12)
            axes[2].axis("off")

            plt.tight_layout()
            plt.show()

display_results(generator, test_dataset, img_size=img_size, output_size=(1024, 1024), weights_path="/content/drive/MyDrive/Mestrado/2024_2/Topicos_especiais_IA_ThiagoPX/best_generator.weights.h5")

Output hidden; open in https://colab.research.google.com to view.