In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, MaxPooling2D, concatenate
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from skimage.color import rgb2lab, lab2rgb
from skimage.io import imread
from skimage.transform import resize

2024-08-01 06:21:07.904425: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-01 06:21:07.908376: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-01 06:21:07.920458: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-01 06:21:07.936298: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-01 06:21:07.940442: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-01 06:21:07.951989: I tensorflow/core/platform/cpu_feature_gu

In [2]:
# Directory paths: each split has a colour folder (targets) and a "black"
# folder holding the matching grayscale inputs, all under ./data.
train_color_dir, train_gray_dir, test_color_dir, test_gray_dir = [
    os.path.join('data', folder_name)
    for folder_name in ('train_color', 'train_black', 'test_color', 'test_black')
]

# Function to load and process images
def load_images(color_dir, gray_dir, target_size=(256, 256)):
    """Load paired colour / grayscale images from two directories.

    Files are paired by their position in *sorted* filename order, so the two
    directories should use corresponding names (e.g. identical filenames).
    Hidden files (leading '.') and sub-directories are ignored.

    Args:
        color_dir: directory holding the colour (target) images.
        gray_dir: directory holding the grayscale (input) images.
        target_size: (height, width) every image is resized to.

    Returns:
        (gray_images, color_images):
          * gray_images -- shape (N, H, W, 1), intensities as returned by
            skimage ``resize`` (floats; [0, 1] for integer-typed source files).
          * color_images -- shape (N, H, W, 3), CIE-LAB as returned by ``rgb2lab``.

    Raises:
        ValueError: if the two directories contain a different number of images.
    """
    # Local import: only needed when a "gray" file is actually stored as RGB(A).
    from skimage.color import rgb2gray

    color_files = _list_image_files(color_dir)
    gray_files = _list_image_files(gray_dir)
    # os.listdir order is arbitrary, so pairing relies on sorting; a count
    # mismatch means pairs would silently drift (zip truncates), so fail loudly.
    if len(color_files) != len(gray_files):
        raise ValueError(
            f"{color_dir!r} has {len(color_files)} images but "
            f"{gray_dir!r} has {len(gray_files)}; cannot pair them"
        )

    height, width = target_size
    color_images = []
    gray_images = []

    for color_file, gray_file in zip(color_files, gray_files):
        color_image = imread(os.path.join(color_dir, color_file))
        gray_image = imread(os.path.join(gray_dir, gray_file))

        # Colour target must be 3-channel RGB for rgb2lab: collapse L/LA to a
        # single plane, replicate single planes to RGB, and drop any alpha.
        if color_image.ndim == 3 and color_image.shape[-1] < 3:
            color_image = color_image[..., 0]
        if color_image.ndim == 2:
            color_image = np.stack([color_image] * 3, axis=-1)
        color_image = color_image[..., :3]

        # Gray input must be a single 2-D plane; otherwise the final reshape
        # would turn every extra channel into a bogus extra sample.
        if gray_image.ndim == 3:
            if gray_image.shape[-1] < 3:  # L or LA encoding
                gray_image = gray_image[..., 0]
            else:  # RGB / RGBA encoding of a gray image
                gray_image = rgb2gray(gray_image[..., :3])

        color_image = resize(color_image, target_size)
        gray_image = resize(gray_image, target_size)

        color_images.append(rgb2lab(color_image))
        gray_images.append(gray_image)

    # Explicit reshapes keep the (N, H, W, C) layout even when N == 0.
    color_images = np.asarray(color_images, dtype=np.float64).reshape(-1, height, width, 3)
    gray_images = np.asarray(gray_images, dtype=np.float64).reshape(-1, height, width, 1)

    return gray_images, color_images


def _list_image_files(directory):
    """Return the sorted names of regular, non-hidden files in ``directory``."""
    return sorted(
        name
        for name in os.listdir(directory)
        if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
    )

# Load train and test images (grayscale inputs, LAB colour targets)
X_train_gray, y_train_color = load_images(train_color_dir, train_gray_dir)
X_test_gray, y_test_color = load_images(test_color_dir, test_gray_dir)

# Sanity-check before slicing: inputs and targets must pair up one-to-one with
# the same spatial size, and targets must carry all three LAB channels.
assert X_train_gray.shape[:3] == y_train_color.shape[:3], (
    f"train inputs {X_train_gray.shape} / targets {y_train_color.shape} misaligned")
assert X_test_gray.shape[:3] == y_test_color.shape[:3], (
    f"test inputs {X_test_gray.shape} / targets {y_test_color.shape} misaligned")
assert y_train_color.shape[-1] == 3 and y_test_color.shape[-1] == 3, "targets must be 3-channel LAB"

# Extract the a*/b* channels (indices 1 and 2) from the LAB images and divide
# by 128 so they land roughly in [-1, 1], the range of the model's tanh output.
y_train_ab = y_train_color[:, :, :, 1:] / 128
y_test_ab = y_test_color[:, :, :, 1:] / 128

: 

In [1]:
# U-Net Model Architecture
def unet_model(input_shape=(256, 256, 1)):
    """Build a 3-level U-Net mapping a grayscale image to two colour channels.

    Every stage applies two 3x3 ReLU convolutions, each followed by batch
    normalisation. The encoder (64 -> 128 -> 256 filters) halves the spatial
    size after each stage, a 512-filter bottleneck follows, and the decoder
    upsamples back, concatenating the matching encoder features at each level.
    A final 1x1 convolution with tanh emits 2 channels in [-1, 1].

    Args:
        input_shape: (height, width, channels) of the input. Height and width
            should be divisible by 8; otherwise the upsampled tensors will not
            match the skip-connection shapes.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def conv_block(tensor, filters):
        # Two conv -> batch-norm pairs; 'same' padding preserves spatial size.
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
        return tensor

    inputs = Input(shape=input_shape)

    # Encoder: keep each stage's output for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256):
        x = conv_block(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = conv_block(x, 512)

    # Decoder: consume the skips from deepest to shallowest.
    for filters, skip in zip((256, 128, 64), reversed(skips)):
        x = UpSampling2D(size=(2, 2))(x)
        x = concatenate([x, skip], axis=3)
        x = conv_block(x, filters)

    outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(x)
    return Model(inputs, outputs)

def psnr(y_true, y_pred):
    """Peak signal-to-noise ratio between target and predicted ab channels.

    Targets are ab / 128 and the model ends in tanh, so values lie in
    [-1, 1] and the peak-to-peak range passed as ``max_val`` is 2.0.
    Returns one PSNR value per image in the batch (Keras averages them).
    """
    # Targets may arrive as float64 (numpy) while predictions are float32;
    # tf.image.psnr requires matching dtypes.
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.image.psnr(y_true, y_pred, max_val=2.0)


model = unet_model()

# Compile and Train
model.compile(optimizer='adam', loss='mse', metrics=[psnr])

# Train the model; stop once validation loss has not improved for 10 epochs
# and roll back to the best weights seen.
model.fit(
    X_train_gray, y_train_ab,
    epochs=50,
    validation_data=(X_test_gray, y_test_ab),
    callbacks=[EarlyStopping(patience=10, restore_best_weights=True)]
)

# Evaluate the Model
print("Evaluate on test data")
results = model.evaluate(X_test_gray, y_test_ab, batch_size=16)
print("Test loss, Test PSNR:", results)

# Predicted ab channels (scaled to [-1, 1]) for every test image
predictions = model.predict(X_test_gray)

def display_results(bw_images, color_images, predictions, n=5):
    """Show a 3-row grid: grayscale input, colorized prediction, ground truth.

    Args:
        bw_images: grayscale model inputs, shape (N, H, W, 1) or (N, H, W).
            NOTE(review): assumed to be floats in [0, 1], as ``load_images``
            yields for integer image files -- confirm for other sources.
        color_images: ground-truth images in CIE-LAB with all three channels
            (e.g. ``y_test_color``); the ab-only ``y_test_ab`` will not work
            because the ground-truth row is rendered with ``lab2rgb``.
        predictions: model outputs, ab channels scaled by 1/128, shape (N, H, W, 2).
        n: number of examples to show; clamped to the number of inputs.
    """
    n = min(n, len(bw_images))
    if n == 0:
        return

    # squeeze=False keeps a 2-D axes grid even when n == 1.
    fig, axes = plt.subplots(3, n, figsize=(20, 10), squeeze=False)
    for i in range(n):
        height, width = bw_images[i].shape[:2]
        gray = np.reshape(bw_images[i], (height, width))

        # lab2rgb expects L in [0, 100], not the [0, 1] gray intensity the
        # model consumes; derive L by treating the gray image as sRGB.
        lightness = rgb2lab(np.dstack((gray, gray, gray)))[..., 0]
        colorized = lab2rgb(np.dstack((lightness, predictions[i] * 128)))

        panels = (
            (gray, "Original B&W", {"cmap": "gray"}),
            (colorized, "Colorized", {}),
            (lab2rgb(color_images[i]), "Ground Truth", {}),
        )
        for row, (image, title, imshow_kwargs) in enumerate(panels):
            ax = axes[row, i]
            ax.imshow(image, **imshow_kwargs)
            ax.set_title(title)
            ax.axis("off")

    fig.tight_layout()
    plt.show()

# Show the first few test images: grayscale input, colorized output, ground truth.
# Pass the full LAB targets (not the ab-only y_test_ab): the ground-truth row
# is rendered with lab2rgb, which needs all three LAB channels.
display_results(X_test_gray, y_test_color, predictions)

2024-08-01 05:04:41.931059: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-01 05:04:43.560102: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-01 05:04:44.006109: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-01 05:04:45.083923: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-01 05:04:45.349752: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-01 05:04:46.472813: I tensorflow/core/platform/cpu_feature_gu

Found 5000 images belonging to 1 classes.
Found 5000 images belonging to 1 classes.
Found 739 images belonging to 1 classes.
Found 739 images belonging to 1 classes.


: 