In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers, models
from keras.utils import load_img, img_to_array
import os
import matplotlib.pyplot as plt
from skimage import color

In [None]:
# Load one image file from disk and return it as an array
def load_and_preprocess_image(image_path, target_size=(128, 128)):
    """Read an image file as RGB, resize it, and convert it to an array.

    Args:
        image_path: Path of the image file to read.
        target_size: Value forwarded to ``load_img`` as ``target_size``.

    Returns:
        The result of ``img_to_array`` applied to the loaded image.
    """
    return img_to_array(
        load_img(image_path, color_mode='rgb', target_size=target_size)
    )

# Convert an RGB image to the Lab color space
def rgb_to_lab(rgb_image):
    """Scale ``rgb_image`` by 1/255 and convert it with ``color.rgb2lab``.

    Args:
        rgb_image: Array of RGB values; it is divided by 255.0 before the
            conversion, so callers are expected to pass 0-255 data.

    Returns:
        The Lab array returned by ``skimage.color.rgb2lab``.
    """
    unit_range_rgb = rgb_image / 255.0
    lab_image = color.rgb2lab(unit_range_rgb)
    return lab_image

# Convert a Lab image back to the RGB color space
def lab_to_rgb(lab_image):
    """Convert ``lab_image`` with ``color.lab2rgb`` and scale it by 255.

    Args:
        lab_image: Array of Lab values accepted by ``skimage.color.lab2rgb``.

    Returns:
        The converted RGB array multiplied by 255.0 (floating point).
    """
    unit_range_rgb = color.lab2rgb(lab_image)
    return unit_range_rgb * 255.0

In [None]:
# Custom data generator
def custom_image_generator(bw_dir, color_dir, batch_size):
    """Endlessly yield (L, ab) batches built from paired image directories.

    Files ending in ``.jpg`` or ``.png`` (case-insensitive) are listed from both
    directories, sorted, and paired by position. Every pass over the data
    reshuffles the pairs and yields them in chunks of ``batch_size`` (the last
    chunk of a pass may be smaller).

    Args:
        bw_dir: Directory containing the grayscale input images.
        color_dir: Directory containing the matching color images.
        batch_size: Maximum number of image pairs per yielded batch.

    Yields:
        Tuple ``(batch_bw_lab, batch_color_lab)``:
        * ``batch_bw_lab``: Lab L channel of the ``bw_dir`` images, normalized
          as ``L / 50 - 1`` with a trailing channel axis added.
        * ``batch_color_lab``: Lab a/b channels of the ``color_dir`` images,
          divided by 128.

    Raises:
        ValueError: If no image pairs are found. Without this check the
            ``while True`` loop would spin forever without yielding.
    """
    image_extensions = ('.jpg', '.png')
    try:
        # os.listdir returns entries in arbitrary order; sort both listings so
        # the i-th grayscale file is paired with the i-th color file.
        # NOTE(review): this assumes both directories use matching file names.
        bw_files = sorted(
            os.path.join(bw_dir, f)
            for f in os.listdir(bw_dir)
            if f.lower().endswith(image_extensions)
        )
        color_files = sorted(
            os.path.join(color_dir, f)
            for f in os.listdir(color_dir)
            if f.lower().endswith(image_extensions)
        )
    except UnicodeEncodeError:
        print("Error: File paths contain unsupported characters. Please ensure all file names use ASCII characters. Just make sure")
        # Ends the generator: the caller's next() raises StopIteration.
        return

    if len(bw_files) != len(color_files):
        # zip() below drops the surplus files; make that visible to the user.
        print(f"Warning: found {len(bw_files)} grayscale images but {len(color_files)} color images; extra files are ignored.")

    batch_paths = list(zip(bw_files, color_files))
    if not batch_paths:
        raise ValueError(
            f"No image pairs (.jpg/.png) found in {bw_dir!r} and {color_dir!r}"
        )

    while True:
        np.random.shuffle(batch_paths)
        for i in range(0, len(batch_paths), batch_size):
            batch_bw_paths, batch_color_paths = zip(*batch_paths[i:i+batch_size])

            batch_bw = np.array([load_and_preprocess_image(f) for f in batch_bw_paths])
            batch_color = np.array([load_and_preprocess_image(f) for f in batch_color_paths])

            # Convert to Lab color space: L from the grayscale image is the
            # model input, a/b from the color image is the target.
            batch_bw_lab = rgb_to_lab(batch_bw)[:, :, :, 0]
            batch_color_lab = rgb_to_lab(batch_color)[:, :, :, 1:]

            # Normalize
            batch_bw_lab = batch_bw_lab / 50.0 - 1.0
            batch_color_lab = batch_color_lab / 128.0

            # Reshape grayscale images to (batch, height, width, 1)
            batch_bw_lab = np.expand_dims(batch_bw_lab, axis=-1)

            yield batch_bw_lab, batch_color_lab


In [None]:
# tensorflow is already imported in the first cell; the repeated import lets
# this cell be run on its own.
import tensorflow as tf
# Clear the Keras backend session before the models are built below.
tf.keras.backend.clear_session()

In [None]:
# Build the generator network (U-Net architecture)
def build_generator():
    """Create the U-Net style generator.

    The model takes a (128, 128, 1) input, runs four encoder stages
    (64/128/256/512 filters, each followed by 2x2 max pooling), a
    1024-filter bottleneck, and four decoder stages that upsample with a
    transposed convolution and concatenate the matching encoder output.
    A final 1x1 convolution with ``tanh`` produces 2 output channels.

    Returns:
        The uncompiled Keras ``Model``.
    """
    def conv2d_block(input_tensor, n_filters, kernel_size=3, batchnorm=True):
        """Apply Conv2D -> (BatchNormalization) -> ReLU twice."""
        x = input_tensor
        for _ in range(2):
            x = layers.Conv2D(n_filters, kernel_size, padding='same')(x)
            if batchnorm:
                x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)
        return x

    inputs = layers.Input(shape=(128, 128, 1))

    # Encoder: remember each stage's output for the skip connections
    skip_tensors = []
    x = inputs
    for n_filters in (64, 128, 256, 512):
        x = conv2d_block(x, n_filters)
        skip_tensors.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = conv2d_block(x, 1024)

    # Decoder: upsample, merge with the mirrored encoder stage, refine
    for n_filters, skip in zip((512, 256, 128, 64), reversed(skip_tensors)):
        x = layers.Conv2DTranspose(n_filters, 2, strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = conv2d_block(x, n_filters)

    outputs = layers.Conv2D(2, 1, activation='tanh')(x)
    return models.Model(inputs=inputs, outputs=outputs)

In [None]:
def build_discriminator():
    """Create the discriminator network.

    A ``Sequential`` model over (128, 128, 3) inputs: four stride-2 3x3
    convolutions (64, 128, 256, 512 filters), each followed by
    ``LeakyReLU(alpha=0.2)``, then ``Flatten`` and a single sigmoid unit.

    Returns:
        The uncompiled Keras ``Sequential`` model.
    """
    model = tf.keras.Sequential()
    for stage, n_filters in enumerate((64, 128, 256, 512)):
        conv_kwargs = {'kernel_size': 3, 'strides': 2, 'padding': 'same'}
        if stage == 0:
            # Only the first layer declares the input shape.
            conv_kwargs['input_shape'] = (128, 128, 3)
        model.add(layers.Conv2D(n_filters, **conv_kwargs))
        model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation='sigmoid'))
    return model

In [None]:
# Build the combined generator + discriminator model
def build_gan(generator, discriminator):
    """Chain the generator and discriminator into a single model.

    Side effect: sets ``discriminator.trainable = False`` on the model that
    was passed in.

    Args:
        generator: Model mapping a (128, 128, 1) input to 2 output channels.
        discriminator: Model scoring the 3-channel concatenation of the
            input and the generator output.

    Returns:
        An uncompiled Keras ``Model`` from the grayscale input to the
        discriminator output.
    """
    discriminator.trainable = False

    gray_input = layers.Input(shape=(128, 128, 1))
    predicted_ab = generator(gray_input)
    # Stack the grayscale input with the generated ab channels
    merged = layers.Concatenate()([gray_input, predicted_ab])
    validity = discriminator(merged)
    return models.Model(gray_input, validity)

In [None]:
import time

def train_gan(generator, discriminator, gan, epochs, steps_per_epoch, train_generator):
    """Run the alternating discriminator / generator training loop.

    For every step a batch is pulled from ``train_generator``; the
    discriminator is trained on the real (L + true ab) and fake
    (L + generated ab) inputs, then the combined ``gan`` model is trained
    with "real" labels. Per-epoch averages are printed and a preview image
    is saved with ``save_generated_images`` using the last batch of the epoch.

    Args:
        generator: Generator model; its ``predict`` is used to create fakes.
        discriminator: Compiled discriminator; ``train_on_batch`` must return
            ``[loss, accuracy]``.
        gan: Compiled combined model; ``train_on_batch`` must return a scalar
            loss.
        epochs: Number of epochs to run.
        steps_per_epoch: Number of batches drawn per epoch.
        train_generator: Iterator yielding ``(bw_batch, color_batch)``.

    Returns:
        None. Progress is printed and preview images are written to disk.
    """
    start_time = time.time()

    for epoch in range(epochs):
        epoch_start_time = time.time()
        print(f"Epoch {epoch+1}/{epochs}")

        d_losses = []
        d_accuracies = []
        g_losses = []
        # Stay None when steps_per_epoch is 0 so the epoch-end preview can be
        # skipped instead of failing on an unbound variable.
        bw_batch = color_batch = None

        for step in range(steps_per_epoch):
            # Get a batch of images
            bw_batch, color_batch = next(train_generator)
            real_labels = np.ones((bw_batch.shape[0], 1))
            fake_labels = np.zeros((bw_batch.shape[0], 1))

            # Generate colorized images; verbose=0 keeps predict() from
            # printing a progress bar on every training step.
            generated_images = generator.predict(bw_batch, verbose=0)

            # Concatenate real and fake inputs for discriminator
            real_input = np.concatenate([bw_batch, color_batch], axis=-1)
            fake_input = np.concatenate([bw_batch, generated_images], axis=-1)

            # Train the discriminator
            d_loss_real = discriminator.train_on_batch(real_input, real_labels)
            d_loss_fake = discriminator.train_on_batch(fake_input, fake_labels)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train the generator (through the combined model) to be scored
            # as "real"
            g_loss = gan.train_on_batch(bw_batch, real_labels)

            d_losses.append(d_loss[0])
            d_accuracies.append(d_loss[1])
            g_losses.append(g_loss)

            if step % 10 == 0:
                print(f"  Step {step+1}/{steps_per_epoch} [D loss: {d_loss[0]:.4f} | D accuracy: {100 * d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")

        # Calculate average losses and accuracy for the epoch (NaN when the
        # epoch ran no steps, avoiding numpy's empty-mean warning)
        avg_d_loss = np.mean(d_losses) if d_losses else float('nan')
        avg_d_accuracy = np.mean(d_accuracies) if d_accuracies else float('nan')
        avg_g_loss = np.mean(g_losses) if g_losses else float('nan')

        epoch_time = time.time() - epoch_start_time
        total_time = time.time() - start_time

        print(f"Epoch {epoch+1}/{epochs} completed in {epoch_time:.2f} seconds")
        print(f"Average D loss: {avg_d_loss:.4f} | Average D accuracy: {100 * avg_d_accuracy:.2f}% | Average G loss: {avg_g_loss:.4f}")
        print(f"Total training time: {total_time:.2f} seconds")

        # Save generated images at the end of each epoch
        if bw_batch is not None:
            save_generated_images(generator, bw_batch, color_batch, epoch + 1)

    print("Training completed.")
    print(f"Total training time: {time.time() - start_time:.2f} seconds")


In [None]:
def save_generated_images(generator, bw_images, color_images, epoch, examples=5):
    """Save a grid comparing input, generated and ground-truth images.

    Each row shows the grayscale input, the generator's colorization and the
    original color image. The figure is written to
    ``generated_images_epoch_{epoch}.png`` in the working directory.

    Args:
        generator: Model whose ``predict`` maps normalized L batches to
            normalized ab channels.
        bw_images: Normalized L channel batch, shape (N, H, W, 1), as produced
            by ``custom_image_generator`` (``L / 50 - 1``).
        color_images: Normalized ab batch, shape (N, H, W, 2) (``ab / 128``).
        epoch: Label used in the output file name.
        examples: Maximum number of rows to draw.

    Returns:
        None. Returns early without writing a file when there is nothing to
        plot.
    """
    # Ensure we don't try to display more images than we have in the batch
    examples = min(examples, bw_images.shape[0])
    if examples < 1:
        return

    generated_images = generator.predict(bw_images[:examples], verbose=0)

    def _to_display_rgb(l_norm, ab_norm):
        """Undo the training normalization and return RGB floats in [0, 1]."""
        # Inverse of custom_image_generator's scaling: L/50 - 1 and ab/128.
        lab = np.concatenate([(l_norm + 1.0) * 50.0, ab_norm * 128.0], axis=-1)
        # lab_to_rgb returns 0-255 floats; imshow expects floats in [0, 1].
        return np.clip(lab_to_rgb(lab) / 255.0, 0.0, 1.0)

    # squeeze=False always gives a 2-D axes array, even for a single row.
    fig, axs = plt.subplots(examples, 3, figsize=(15, 5*examples), squeeze=False)
    for i in range(examples):
        current_ax = axs[i]

        # Display black and white image
        current_ax[0].imshow(bw_images[i, :, :, 0], cmap='gray')
        current_ax[0].axis('off')
        current_ax[0].set_title('Input (Grayscale)')

        # Display generated color image
        current_ax[1].imshow(_to_display_rgb(bw_images[i], generated_images[i]))
        current_ax[1].axis('off')
        current_ax[1].set_title('Generated')

        # Display original color image
        current_ax[2].imshow(_to_display_rgb(bw_images[i], color_images[i]))
        current_ax[2].axis('off')
        current_ax[2].set_title('Ground Truth')

    plt.tight_layout()
    plt.savefig(f"generated_images_epoch_{epoch}.png")
    plt.close(fig)

In [None]:
# Set up paths
# The dataset root defaults to the original local folder but can be overridden
# with the COLORIZATION_DATA_DIR environment variable, so the notebook runs on
# other machines without editing this cell.
DATA_DIR = os.environ.get(
    "COLORIZATION_DATA_DIR",
    r"C:\Users\USER\Downloads\Capstone Redesigned\data2",
)
train_bw_path = os.path.join(DATA_DIR, "train_black")
train_color_path = os.path.join(DATA_DIR, "train_color")
test_bw_path = os.path.join(DATA_DIR, "test_black")
test_color_path = os.path.join(DATA_DIR, "test_color")


In [None]:
# Set training parameters
# batch_size: maximum number of image pairs per batch from the data generators
batch_size = 8
# epochs: number of outer-loop iterations in train_gan
epochs = 5
# steps_per_epoch: batches drawn from the training generator in each epoch
steps_per_epoch = 25  # Adjust based on your dataset size

In [None]:
# Create data generators
# custom_image_generator is a generator function, so nothing is read from disk
# here; the directories are first listed on the first next() call.
train_generator = custom_image_generator(train_bw_path, train_color_path, batch_size)
test_generator = custom_image_generator(test_bw_path, test_color_path, batch_size)


In [None]:
# Build and compile models
generator = build_generator()
discriminator = build_discriminator()

# Compile the discriminator BEFORE build_gan() marks it non-trainable.
# Keras honours the `trainable` state that is in effect when compile() is
# called; compiling after build_gan() would freeze the standalone
# discriminator, so discriminator.train_on_batch() would not update it.
# NOTE(review): this relies on tf.keras 2.x compile-time semantics - confirm
# against the installed Keras version.
discriminator.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

# build_gan() sets discriminator.trainable = False, so the combined model
# compiled below only updates the generator's weights.
gan = build_gan(generator, discriminator)

gan.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5),
            loss='binary_crossentropy')

# Train the GAN
train_gan(generator, discriminator, gan, epochs, steps_per_epoch, train_generator)

# Save the generator model
generator.save("colorization_generator.h5")

In [None]:
# Generate and save some test images
# Draw one batch from the test generator and write a comparison grid to
# "generated_images_epoch_final.png". save_generated_images caps `examples`
# at the number of images in the batch.
test_bw_batch, test_color_batch = next(test_generator)
save_generated_images(generator, test_bw_batch, test_color_batch, epoch="final", examples=10)

In [None]:
#test cases