In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose
from tensorflow.keras.models import Sequential
from skimage.color import rgb2lab, lab2rgb
from skimage import io
from tensorflow.keras.preprocessing.image import img_to_array, load_img, array_to_img
import os


In [None]:
# Cell 2: Function to Load Images
def load_images(image_folder, size=(256, 256), extensions=('.png', '.jpg', '.jpeg')):
    """Load every image file in ``image_folder`` as a normalized RGB array.

    Only regular files whose lower-cased name ends with one of ``extensions``
    are loaded, so stray entries such as ``.DS_Store``, ``README.txt`` or
    sub-directories no longer crash the loader. Files are visited in sorted
    order so the batch order is reproducible across runs and platforms.

    Args:
        image_folder: Directory containing the images.
        size: ``(height, width)`` each image is resized to on load.
        extensions: Lower-case filename suffixes treated as images.

    Returns:
        ``np.ndarray`` of shape ``(n_images, height, width, 3)`` with values
        scaled to ``[0, 1]``. When no images are found the result has shape
        ``(0, height, width, 3)`` rather than a shapeless ``(0,)`` array.
    """
    suffixes = tuple(extensions)
    images = []
    for filename in sorted(os.listdir(image_folder)):
        path = os.path.join(image_folder, filename)
        if not (os.path.isfile(path) and filename.lower().endswith(suffixes)):
            continue
        img = load_img(path, target_size=size)  # default color_mode='rgb' -> 3 channels
        images.append(img_to_array(img) / 255.0)  # Normalize the image to [0, 1]
    if not images:
        return np.empty((0, *size, 3), dtype=np.float32)
    return np.array(images)

In [None]:
# Cell 3: Function to Preprocess Images
def preprocess_images(images, l_scale=100.0, ab_scale=128.0):
    """Split RGB images into a normalized lightness input and chroma target.

    The model consumes the LAB ``L`` channel and predicts ``a``/``b``. Both are
    rescaled so training matches inference: ``L`` (0..100) is divided by
    ``l_scale`` to land in ``[0, 1]``, and ``a``/``b`` (roughly -128..127) are
    divided by ``ab_scale`` to land in ``[-1, 1]``, the range of the model's
    final ``tanh`` layer. The colorization loop undoes the chroma scaling by
    multiplying predictions by 128.

    Args:
        images: RGB batch of shape ``(n, h, w, 3)`` with values in ``[0, 1]``.
        l_scale: Divisor applied to the ``L`` channel.
        ab_scale: Divisor applied to the ``a`` and ``b`` channels.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(n, h, w, 1)`` (scaled L) and ``Y``
        has shape ``(n, h, w, 2)`` (scaled a, b).
    """
    lab_images = rgb2lab(images)  # Convert RGB to LAB color space
    # Slicing with ``:1`` keeps the channel axis, giving the CNN its (n, h, w, 1) input.
    X = lab_images[..., :1] / l_scale
    Y = lab_images[..., 1:] / ab_scale
    return X, Y

In [None]:
# Cell 4: Function to Build the Model
def build_model(input_shape=(256, 256, 1)):
    """Build and compile the encoder-decoder colorization CNN.

    The encoder halves the spatial resolution three times (``strides=2``) and
    the decoder restores it with three ``UpSampling2D`` steps. Height and width
    in ``input_shape`` should therefore be divisible by 8 for the output to
    match the input size. The final ``tanh`` layer emits 2 channels (``a`` and
    ``b``) in ``[-1, 1]``, so the chroma training targets must be scaled into
    that range.

    Args:
        input_shape: ``(height, width, 1)`` shape of the lightness input.

    Returns:
        A compiled ``Sequential`` model (Adam optimizer, MSE loss). Its summary
        is printed as a side effect.
    """
    model = Sequential([
        # keras.Input works in both tf.keras 2.x and Keras 3, whereas
        # InputLayer(input_shape=...) is deprecated in Keras 3.
        keras.Input(shape=input_shape),

        # Encoder: three stride-2 convolutions downsample by a total factor of 8
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        Conv2D(64, (3, 3), activation='relu', padding='same', strides=2),
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        Conv2D(128, (3, 3), activation='relu', padding='same', strides=2),
        Conv2D(256, (3, 3), activation='relu', padding='same'),
        Conv2D(256, (3, 3), activation='relu', padding='same', strides=2),

        # Decoder: three 2x upsamplings restore the input resolution
        UpSampling2D((2, 2)),
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        UpSampling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        UpSampling2D((2, 2)),
        Conv2D(32, (3, 3), activation='relu', padding='same'),

        # Final output layer: 2 channels for 'a' and 'b', tanh -> [-1, 1]
        Conv2D(2, (3, 3), activation='tanh', padding='same'),
    ])
    model.compile(optimizer='adam', loss='mean_squared_error')
    model.summary()
    return model

In [None]:
# Cell 5: Function to Train the Model
def train_model(model, X, Y, epochs=50, batch_size=16,
                save_path="image_colorization_model.keras", **fit_kwargs):
    """Fit ``model`` on ``(X, Y)`` and save it to ``save_path``.

    Args:
        model: A compiled Keras model.
        X: Model inputs.
        Y: Training targets.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.
        save_path: Destination file passed to ``model.save``.
        **fit_kwargs: Extra keyword arguments forwarded to ``model.fit``
            (e.g. ``validation_split``, ``callbacks``).

    Returns:
        The object returned by ``model.fit`` (a Keras ``History``), so loss
        curves can be inspected or plotted after training.
    """
    history = model.fit(X, Y, epochs=epochs, batch_size=batch_size, **fit_kwargs)
    model.save(save_path)
    return history

In [None]:
# Cell 6: Model Training and Saving
# Seed Python, NumPy and the TensorFlow backend so weight initialization and
# batch shuffling are reproducible under Restart & Run All.
SEED = 42
keras.utils.set_random_seed(SEED)

# Path to your image folder
image_folder = 'images/'  # Replace with your dataset path
images = load_images(image_folder)
# Fail early with a clear message instead of an opaque shape error downstream.
assert len(images) > 0, f"no images found in {image_folder!r}"
X, Y = preprocess_images(images)

# Build and train the model
model = build_model()
train_model(model, X, Y, epochs=100, batch_size=16)

In [None]:
# Cell 7: Image Colorization Loop
# Create directory for output images if it doesn't exist
output_folder = 'colorized_images/'
os.makedirs(output_folder, exist_ok=True)


def colorize_file(model, path, size=(256, 256)):
    """Colorize one image file with the trained model.

    The model input is built with ``preprocess_images``, the same function used
    to build the training inputs, so inference sees exactly the lightness
    representation the network was trained on. A grayscale-mode PIL load would
    give luma/255, which is not the LAB ``L`` channel. The full-scale LAB ``L``
    (0..100) is kept for reconstruction. The predicted ``a``/``b`` (``tanh``
    output in ``[-1, 1]``) are rescaled by 128.

    Args:
        model: Trained colorization model.
        path: Path of the image file to colorize.
        size: ``(height, width)`` the image is resized to.

    Returns:
        ``(lightness, colorized)``: the ``L`` channel of shape ``(h, w)`` and
        the colorized ``uint8`` RGB image of shape ``(h, w, 3)``.
    """
    rgb = img_to_array(load_img(path, target_size=size)) / 255.0
    lightness = rgb2lab(rgb)[:, :, 0]
    model_input, _ = preprocess_images(rgb[np.newaxis])
    predicted_ab = model.predict(model_input, verbose=0)

    # Combine the original L channel with the predicted a/b channels
    lab = np.zeros(rgb.shape)
    lab[:, :, 0] = lightness
    lab[:, :, 1:] = predicted_ab[0] * 128  # Rescale ab channels from [-1, 1]

    colorized = lab2rgb(lab)
    # lab2rgb yields floats in [0, 1]; round (rather than truncate) to uint8.
    colorized = np.clip(np.round(colorized * 255), 0, 255).astype(np.uint8)
    return lightness, colorized


# Loop through all images in the folder (sorted for a deterministic order)
for filename in sorted(os.listdir(image_folder)):
    # Ensure the file is an image
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue
    lightness, colorized_image = colorize_file(model, os.path.join(image_folder, filename))

    # Save the colorized image
    output_filename = os.path.join(output_folder, f'colorized_{filename}')
    io.imsave(output_filename, colorized_image)

    # Optionally display the lightness input next to the colorized result
    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 5))
    ax_in.set_title("Original Image")
    ax_in.imshow(lightness, cmap='gray', vmin=0, vmax=100)
    ax_in.axis('off')
    ax_out.set_title("Colorized Image")
    ax_out.imshow(colorized_image)
    ax_out.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when many images are processed
