# In [None]:
import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import matplotlib.pyplot as plt

def load_image_pairs(dataset_dir):
    """Pair every ``bw_image_*`` file in *dataset_dir* with its ``color_image_*`` twin.

    A file named ``bw_image_<suffix>`` is paired with ``color_image_<suffix>``
    in the same directory. BW files without a colour twin are reported and
    skipped. Colour files are consumed only through their BW partner. Any
    other file is reported as unrelated.

    Args:
        dataset_dir: directory to scan (not recursive).

    Returns:
        ``(bw_images, color_images)``: two equally long lists of paths, where
        ``color_images[k]`` is the partner of ``bw_images[k]``. The lists are
        ordered by sorted filename, so the result is reproducible.
    """
    bw_prefix = "bw_image_"
    color_prefix = "color_image_"
    bw_images = []
    color_images = []
    # os.listdir order is arbitrary; sort so pairing order is deterministic.
    for filename in sorted(os.listdir(dataset_dir)):
        if filename.startswith(bw_prefix):
            bw_path = os.path.join(dataset_dir, filename)
            # Swap only the leading prefix. str.replace would also rewrite any
            # later "bw_image_" inside the suffix and miss the real twin.
            color_filename = color_prefix + filename[len(bw_prefix):]
            color_path = os.path.join(dataset_dir, color_filename)
            if os.path.exists(color_path):
                bw_images.append(bw_path)
                color_images.append(color_path)
            else:
                print(f"Colored image not found for: {filename}")
        elif filename.startswith(color_prefix):
            continue
        else:
            print(f"Skipping unrelated file: {filename}")
    return bw_images, color_images

def preprocess_image(image_path, size=(256,256)):
    """Load an image as RGB, resize it, and scale its pixels to [-1, 1].

    Args:
        image_path: path of the image file to read.
        size: target ``(width, height)`` handed to ``PIL.Image.resize``.

    Returns:
        A float32 array of shape ``(height, width, 3)`` with values in
        [-1, 1], or ``None`` if the file could not be opened or decoded. The
        error is printed rather than raised, so one bad file does not abort
        dataset loading.
    """
    try:
        # The context manager closes the underlying file handle. Pillow keeps
        # multi-frame files (GIF, TIFF, ...) open after load() otherwise.
        with Image.open(image_path) as raw:
            image = raw.convert('RGB').resize(size, Image.BILINEAR)
        # Map uint8 [0, 255] linearly onto [-1, 1].
        return np.asarray(image, dtype=np.float32) / 127.5 - 1.0
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return None

def preprocess_dataset(bw_images, color_images, size=(256,256)):
    """Load paired BW/colour images into two aligned, normalised arrays.

    Args:
        bw_images: iterable of paths to the grayscale input images.
        color_images: iterable of paths to the colour target images.
            ``color_images[k]`` must be the partner of ``bw_images[k]``.
        size: ``(width, height)`` forwarded to ``preprocess_image``.

    Returns:
        ``(inputs, targets)`` as float32 arrays of shape
        ``(N, height, width, 3)``. A pair is dropped when either image fails
        to load, so N may be smaller than the number of paths. When N == 0
        the arrays still have the full 4-D shape.

    Raises:
        ValueError: if the two path collections differ in length.
    """
    bw_images = list(bw_images)
    color_images = list(color_images)
    # zip() would silently truncate and hide a mis-paired dataset.
    if len(bw_images) != len(color_images):
        raise ValueError(
            f"Mismatched pair lists: {len(bw_images)} BW paths vs "
            f"{len(color_images)} colour paths"
        )
    inputs = []
    targets = []
    for bw_path, color_path in zip(bw_images, color_images):
        bw_im = preprocess_image(bw_path, size)
        color_im = preprocess_image(color_path, size)
        # Keep a pair only when both halves loaded, so inputs/targets stay aligned.
        if bw_im is not None and color_im is not None:
            inputs.append(bw_im)
            targets.append(color_im)
    if not inputs:
        # np.array([]) would be shape (0,) float64. Keep the image rank instead.
        width, height = size
        empty = np.empty((0, height, width, 3), dtype=np.float32)
        return empty, empty.copy()
    return np.array(inputs, dtype=np.float32), np.array(targets, dtype=np.float32)

def build_generator():
    """Build the encoder/decoder used as the pix2pix generator.

    Input and output are 256x256x3. Three 4x4 stride-2 convolutions with
    LeakyReLU(0.2) take the spatial size from 256 down to 32. Two 4x4 stride-2
    transposed convolutions with ReLU and a final tanh transposed convolution
    bring it back to 256, so outputs lie in [-1, 1].

    Note: there are no U-Net skip connections between encoder and decoder.
    """
    source = layers.Input(shape=(256,256,3))
    features = source
    # Encoder: layers are created in the same order as a hand-unrolled version.
    for filters in (64, 128, 256):
        features = layers.Conv2D(filters, 4, strides=2, padding='same')(features)
        features = layers.LeakyReLU(alpha=0.2)(features)
    # Decoder.
    for filters in (128, 64):
        features = layers.Conv2DTranspose(filters, 4, strides=2, padding='same')(features)
        features = layers.ReLU()(features)
    rgb = layers.Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')(features)
    return models.Model(source, rgb, name="pix2pix_generator")

def build_discriminator():
    """Build the conditional pix2pix discriminator.

    Takes ``[condition, candidate]``, two 256x256x3 images, and concatenates
    them along channels. Three 4x4 stride-2 convolutions with LeakyReLU(0.2)
    reduce the spatial size to 32x32. A final stride-1 sigmoid convolution
    emits a 32x32x1 grid of real/fake probabilities.
    """
    condition = layers.Input(shape=(256,256,3))
    candidate = layers.Input(shape=(256,256,3))
    h = layers.Concatenate()([condition, candidate])
    for filters in (64, 128, 256):
        h = layers.Conv2D(filters, 4, strides=2, padding='same')(h)
        h = layers.LeakyReLU(alpha=0.2)(h)
    patch_scores = layers.Conv2D(1, 4, strides=1, padding='same', activation='sigmoid')(h)
    return models.Model([condition, candidate], patch_scores, name="pix2pix_discriminator")

def build_patch_based_model():
    """Build a small encoder/decoder that colourises 64x64x3 tiles.

    Two 4x4 stride-2 convolutions with LeakyReLU(0.2) go from 64 to 16. Two
    4x4 stride-2 transposed convolutions with ReLU go back to 64. A 3x3 tanh
    convolution projects to 3 channels in [-1, 1].
    """
    tile = layers.Input(shape=(64,64,3))
    y = tile
    for filters in (64, 128):
        y = layers.Conv2D(filters, 4, strides=2, padding='same')(y)
        y = layers.LeakyReLU(alpha=0.2)(y)
    for filters in (128, 64):
        y = layers.Conv2DTranspose(filters, 4, strides=2, padding='same')(y)
        y = layers.ReLU()(y)
    colored = layers.Conv2D(3, 3, strides=1, padding='same', activation='tanh')(y)
    return models.Model(tile, colored, name="patch_based_model")

def build_custom_model():
    """Build a two-branch encoder/decoder that fuses its branches at 64x64.

    Branch 1:
        Three 4x4 stride-2 convs with LeakyReLU(0.2), taking 256 down to 32,
        then one 4x4 stride-2 transposed conv with ReLU back up to 64x64x256.
    Branch 2:
        A stride-1 3x3 conv, then two stride-2 3x3 convs, all with ReLU,
        taking 256 down to 64x64x256.

    The concatenated 64x64x512 tensor is decoded to 256x256x3 by two stride-2
    transposed convs and a 3x3 tanh conv.
    """
    image_in = layers.Input(shape=(256,256,3))

    coarse = image_in
    for filters in (64, 128, 256):
        coarse = layers.Conv2D(filters, 4, strides=2, padding='same')(coarse)
        coarse = layers.LeakyReLU(alpha=0.2)(coarse)
    coarse = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(coarse)
    coarse = layers.ReLU()(coarse)

    fine = image_in
    for filters, stride in ((64, 1), (128, 2), (256, 2)):
        fine = layers.Conv2D(filters, 3, strides=stride, padding='same')(fine)
        fine = layers.ReLU()(fine)

    fused = layers.Concatenate()([coarse, fine])
    for filters in (128, 64):
        fused = layers.Conv2DTranspose(filters, 4, strides=2, padding='same')(fused)
        fused = layers.ReLU()(fused)
    image_out = layers.Conv2D(3, 3, strides=1, padding='same', activation='tanh')(fused)
    return models.Model(image_in, image_out, name="custom_fused_model")

# Shared loss objects for generator_loss / discriminator_loss.
# from_logits=False because the discriminator's last layer already applies a
# sigmoid, so it outputs probabilities rather than raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)
# Mean absolute error (L1) between generated and target images.
l1_loss = tf.keras.losses.MeanAbsoluteError()

def generator_loss(disc_generated_output, gen_output, target, l1_weight=100):
    """Pix2pix generator objective: adversarial BCE plus weighted L1 reconstruction.

    Args:
        disc_generated_output: discriminator probabilities for generated images.
        gen_output: images produced by the generator.
        target: ground-truth images aligned with ``gen_output``.
        l1_weight: multiplier on the L1 term (the weight previously hard-coded
            as 100). It trades colour fidelity against adversarial realism.

    Returns:
        ``gan_loss + l1_weight * l1`` as a tensor.
    """
    # The generator wants every discriminator patch to be judged "real" (ones).
    gan_loss = cross_entropy(tf.ones_like(disc_generated_output), disc_generated_output)
    l1 = l1_loss(target, gen_output)
    return gan_loss + l1_weight * l1

def discriminator_loss(disc_real_output, disc_generated_output):
    """Binary cross-entropy of the discriminator on real vs. generated pairs.

    Real pairs are scored against all-ones labels and generated pairs against
    all-zeros labels. The two terms are summed.
    """
    real_labels = tf.ones_like(disc_real_output)
    fake_labels = tf.zeros_like(disc_generated_output)
    return (cross_entropy(real_labels, disc_real_output)
            + cross_entropy(fake_labels, disc_generated_output))

# Adam optimizers used by train_pix2pix_step; beta_1=0.5 replaces Keras'
# default of 0.9.
# NOTE(review): the generator learns at 5e-4 while the discriminator uses 2e-4
# (the pix2pix paper uses 2e-4 for both) — confirm the asymmetry is intended.
generator_optimizer = optimizers.Adam(5e-4, beta_1=0.5)
discriminator_optimizer = optimizers.Adam(2e-4, beta_1=0.5)

def extract_patches(images, patch_size=64):
    """Cut every image into non-overlapping ``patch_size`` x ``patch_size`` tiles.

    Tiles are taken row-major (left to right, then top to bottom), image by
    image. A right/bottom border too narrow for a full tile is dropped.

    Args:
        images: iterable of arrays shaped ``(h, w, ...)``.
        patch_size: tile edge length in pixels.

    Returns:
        ``np.array`` of all tiles, shape ``(N, patch_size, patch_size, ...)``.
        If no tile fits, this is an empty array of shape ``(0,)``.
    """
    collected = []
    for img in images:
        rows, cols = img.shape[0:2]
        # Only start offsets where a full tile still fits inside the image.
        last_top = rows - patch_size
        last_left = cols - patch_size
        collected.extend(
            img[top:top + patch_size, left:left + patch_size]
            for top in range(0, last_top + 1, patch_size)
            for left in range(0, last_left + 1, patch_size)
        )
    return np.array(collected)

def stitch_patches(patches, original_shape):
    """Reassemble non-overlapping square tiles into an array of ``original_shape``.

    This is the inverse of the row-major tiling done by ``extract_patches``:
    tiles are placed left to right, then top to bottom. Only positions where a
    full tile fits are filled. A right/bottom border narrower than one tile
    (which ``extract_patches`` drops) stays zero.

    Args:
        patches: array shaped ``(N, p, p, ...)``.
        original_shape: output shape, ``(h, w)`` or ``(h, w, c)``.

    Returns:
        A float64 array of shape ``original_shape``.

    Raises:
        ValueError: if ``patches`` holds fewer tiles than the grid needs.
    """
    patch_size = patches.shape[1]
    h, w = original_shape[:2]
    # Iterate only over full-tile origins. Partial border positions used to be
    # visited too, consuming patches that were never extracted (IndexError /
    # broadcast error when h or w is not a multiple of patch_size).
    row_starts = range(0, h - patch_size + 1, patch_size)
    col_starts = range(0, w - patch_size + 1, patch_size)
    needed = len(row_starts) * len(col_starts)
    if len(patches) < needed:
        raise ValueError(
            f"stitch_patches needs {needed} patches for shape {tuple(original_shape)}, "
            f"got {len(patches)}"
        )
    stitched = np.zeros(original_shape)
    patch_index = 0
    for i in row_starts:
        for j in col_starts:
            stitched[i:i+patch_size, j:j+patch_size] = patches[patch_index]
            patch_index += 1
    return stitched

@tf.function
def train_pix2pix_step(input_image, target, generator, discriminator):
    """Run one adversarial optimisation step of pix2pix on a single batch.

    Args:
        input_image: batch of conditioning (BW) images; the generator's input.
        target: batch of ground-truth colour images aligned with ``input_image``.
        generator: model called as ``generator(input_image)``.
        discriminator: model called as ``discriminator([input_image, image])``.

    Returns:
        ``(gen_loss, disc_loss)`` tensors for this batch.

    NOTE(review): the models are plain Python arguments of a ``tf.function``,
    so each distinct (generator, discriminator) pair gets its own trace. The
    visible caller always passes the same pair.
    """
    # Record the forward pass on two tapes so each loss is differentiated
    # w.r.t. its own network from one shared set of activations.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)
        # Same condition paired once with the real target, once with the fake.
        disc_real_output = discriminator([input_image, target], training=True)
        disc_gen_output = discriminator([input_image, gen_output], training=True)
        gen_loss = generator_loss(disc_gen_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_gen_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    # Both networks are updated from the same forward pass (simultaneous update).
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss

def train_pix2pix(dataset, generator, discriminator, epochs=100):
    """Train the pix2pix generator/discriminator pair for ``epochs`` passes.

    Args:
        dataset: re-iterable source of ``(input_image, target)`` batches, e.g.
            a ``tf.data.Dataset``. It is iterated once per epoch.
        generator: generator model forwarded to ``train_pix2pix_step``.
        discriminator: discriminator model forwarded to ``train_pix2pix_step``.
        epochs: number of passes over ``dataset``.

    Raises:
        ValueError: if ``dataset`` yields no batches during an epoch. This
            covers an empty dataset, or a one-shot iterator that is already
            exhausted.
    """
    for epoch in range(epochs):
        print(f"Starting epoch {epoch+1}/{epochs} ...")
        g_total = 0.0
        d_total = 0.0
        steps = 0
        for input_image, target in dataset:
            g_loss, d_loss = train_pix2pix_step(input_image, target, generator, discriminator)
            g_total += float(g_loss)
            d_total += float(d_loss)
            steps += 1
        # Without this guard an empty epoch would crash on unbound names, or
        # silently re-report the previous epoch's last losses.
        if steps == 0:
            raise ValueError(f"train_pix2pix: dataset yielded no batches in epoch {epoch+1}")
        # Report epoch means rather than whichever batch happened to be last.
        print(
            f"Epoch {epoch+1} completed. Mean Generator Loss: {g_total / steps:.4f} "
            f"Mean Disc Loss: {d_total / steps:.4f}"
        )

def train_patch_based_model(input_imgs, target_imgs, model, epochs=100):
    """Fit *model* on aligned 64x64 tiles cut from the input and target images.

    Both image sets are tiled with ``extract_patches`` (non-overlapping,
    row-major). The model is then compiled with Adam(2e-4, beta_1=0.5) and MAE
    loss, and trained with batch size 32.
    """
    patch_size = 64
    in_tiles = extract_patches(input_imgs, patch_size=patch_size)
    out_tiles = extract_patches(target_imgs, patch_size=patch_size)
    print("Training patch–based model on patches with shapes:", in_tiles.shape, out_tiles.shape)
    tile_optimizer = optimizers.Adam(2e-4, beta_1=0.5)
    model.compile(optimizer=tile_optimizer, loss='mae')
    model.fit(in_tiles, out_tiles, batch_size=32, epochs=epochs)

def train_custom_model(input_imgs, target_imgs, model, epochs=100):
    """Fit *model* on whole images with MAE loss, Adam(5e-4) and batch size 1."""
    print("Training custom fused model on full images:", input_imgs.shape, target_imgs.shape)
    model.compile(loss='mae', optimizer=optimizers.Adam(5e-4))
    model.fit(x=input_imgs, y=target_imgs, epochs=epochs, batch_size=1)

def compute_ssim(pred, target):
    """Return the mean SSIM of *pred* vs. *target*, scaled by 100.

    Both inputs are expected in [-1, 1], the normalisation used throughout
    this file. They are mapped to [0, 1] and compared with
    ``tf.image.ssim(..., max_val=1.0)``. The result is a float32 scalar
    tensor. SSIM itself can be negative, so the scaled value can be too.
    """
    def _to_unit_range(x):
        return tf.convert_to_tensor((x + 1) / 2.0, dtype=tf.float32)

    per_image = tf.image.ssim(_to_unit_range(pred), _to_unit_range(target), max_val=1.0)
    return 100.0 * tf.reduce_mean(per_image)

def test_models(test_bw_path, test_color_path, generator, patch_model, custom_model):
    """Colourise one BW/colour pair with all three models, print SSIM, and plot.

    Loads both images at 256x256. If either fails to load, an error is printed
    and the function returns. Otherwise it runs the generator and the fused
    model on the full image, and the patch model on 64x64 tiles that are
    stitched back together. It prints each output's SSIM (x100) against the
    colour image and shows input plus outputs side by side.
    """
    test_input = preprocess_image(test_bw_path, size=(256,256))
    test_target = preprocess_image(test_color_path, size=(256,256))
    if test_input is None or test_target is None:
        print("Error: Unable to load test images.")
        return

    batch = np.expand_dims(test_input, axis=0)
    pix2pix_output = generator.predict(batch)[0]
    tiles = extract_patches(batch, patch_size=64)
    patch_based_output = stitch_patches(patch_model.predict(tiles), test_input.shape)
    custom_output = custom_model.predict(batch)[0]

    scores = (
        ("Pix2Pix model", compute_ssim(pix2pix_output, test_target)),
        ("Patch–based model", compute_ssim(patch_based_output, test_target)),
        ("Custom fused model", compute_ssim(custom_output, test_target)),
    )
    for label, score in scores:
        print(f"{label} SSIM accuracy: {score.numpy():.2f}%")

    panels = (
        ("Input Image", test_input),
        ("Pix2Pix Output", pix2pix_output),
        ("Patch-based Output", patch_based_output),
        ("Custom Fused Output", custom_output),
    )
    plt.figure(figsize=(15,5))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.title(title)
        # Undo the [-1, 1] normalisation for display.
        plt.imshow((image + 1) / 2)
        plt.axis('off')
    plt.show()

if __name__ == '__main__':
    import sys  # only the script entry point needs it

    dataset_dir = "smalldatasetmanga"
    print("Dataset directory exists:", os.path.exists(dataset_dir))
    # Exit with a clear message instead of a FileNotFoundError traceback from listdir.
    if not os.path.isdir(dataset_dir):
        sys.exit(f"Dataset directory not found or not a directory: {dataset_dir}")
    print("Files in dataset directory:", os.listdir(dataset_dir))
    bw_files, color_files = load_image_pairs(dataset_dir)
    print("Number of BW images:", len(bw_files), "and colored images:", len(color_files))
    inputs, targets = preprocess_dataset(bw_files, color_files, size=(256,256))
    print("Inputs shape:", inputs.shape, "Targets shape:", targets.shape)
    # Shuffling, fitting and patch extraction all need at least one pair.
    if len(inputs) == 0:
        sys.exit("No usable bw_image_*/color_image_* pairs were loaded; nothing to train on.")
    pix2pix_generator = build_generator()
    pix2pix_discriminator = build_discriminator()
    patch_based_model = build_patch_based_model()
    custom_fused_model = build_custom_model()
    # A buffer as large as the dataset gives a uniform reshuffle every epoch.
    # A small fixed buffer only mixes neighbouring elements.
    pix2pix_dataset = (
        tf.data.Dataset.from_tensor_slices((inputs, targets))
        .shuffle(len(inputs))
        .batch(1)
    )
    print("Training Pix2Pix model ...")
    train_pix2pix(pix2pix_dataset, pix2pix_generator, pix2pix_discriminator, epochs=100)
    print("Training patch–based model ...")
    train_patch_based_model(inputs, targets, patch_based_model, epochs=100)
    print("Training custom fused model ...")
    train_custom_model(inputs, targets, custom_fused_model, epochs=100)
    # Persist the trained models before evaluation/plotting, so an error there
    # (or a blocking plt.show()) cannot lose the whole training run.
    # NOTE(review): Keras 3 treats .h5 as a legacy format; consider ".keras".
    pix2pix_generator.save("pix2pix_generator_256.h5")
    pix2pix_discriminator.save("pix2pix_discriminator_256.h5")
    patch_based_model.save("patch_based_model_256.h5")
    custom_fused_model.save("custom_fused_model_256.h5")
    print("Models saved.")
    # NOTE(review): load_image_pairs also picks up bw_image_6/color_image_6 when
    # both exist, so this scores a training pair, not held-out data.
    test_bw_path = os.path.join(dataset_dir, "bw_image_6.png")
    test_color_path = os.path.join(dataset_dir, "color_image_6.png")
    test_models(test_bw_path, test_color_path, pix2pix_generator, patch_based_model, custom_fused_model)
