In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import cv2
import os

In [None]:
from keras.layers import Conv2D, Conv2DTranspose, Input,Dropout, ReLU,BatchNormalization,Concatenate,LeakyReLU,Identity
from keras.models import Model
from tensorflow.keras.utils import plot_model
import tensorflow as tf

In [None]:
# Fetch the face-photo half of the face2comics v2.0.0 dataset. `wget -nc` skips the
# download when the archive already exists and `tar --skip-old-files` leaves
# already-extracted files untouched, so re-running this cell is cheap.
# NOTE(review): presumably extracts into ./faces (x_folder below) — confirm.
!wget -nc https://github.com/Sxela/face2comics/releases/download/v2.0.0/face2comics_v2.0.0_by_Sxela_faces.tar
!tar --skip-old-files -xf face2comics_v2.0.0_by_Sxela_faces.tar

In [None]:
# Fetch the comic-style half of the face2comics v2.0.0 dataset (same idempotent
# download/extract pattern as the faces archive: -nc / --skip-old-files).
# NOTE(review): presumably extracts into ./comics (y_folder below) — confirm.
!wget -nc https://github.com/Sxela/face2comics/releases/download/v2.0.0/face2comics_v2.0.0_by_Sxela_comics.tar
!tar --skip-old-files -xf face2comics_v2.0.0_by_Sxela_comics.tar

In [None]:
# Paired data folders: model inputs are the real face photos, targets the comic versions.
x_folder, y_folder = "faces", "comics"

In [None]:
# Collect the .jpg paths of each folder in sorted order. Pairing is positional:
# the i-th face is matched with the i-th comic, so both folders are assumed to
# use matching file names. NOTE(review): only the counts are checked below;
# consider also asserting that the basenames agree.
x_files = sorted(os.path.join(x_folder, name) for name in os.listdir(x_folder) if name.endswith(".jpg"))
y_files = sorted(os.path.join(y_folder, name) for name in os.listdir(y_folder) if name.endswith(".jpg"))

assert len(x_files) == len(y_files), "Number of files in each folder must be the same"

In [None]:
# Side length in pixels that every image is resized to (see load_image); also the
# spatial size of the model inputs. NOTE(review): generator() stacks
# num_downsample + num_blocks stride-2 layers, so this should be divisible by
# 2 ** (num_downsample + num_blocks) for the skip connections to line up.
img_size = 128

In [None]:
# Function to load and preprocess images
def load_image(file_path):
    """Read a JPEG at ``file_path`` and return it as a float32 tensor.

    The file is decoded as 3-channel RGB, resized to ``img_size`` x ``img_size``
    and mapped linearly from [0, 255] to [-1, 1].
    """
    encoded = tf.io.read_file(file_path)
    decoded = tf.image.decode_jpeg(encoded, channels=3)
    resized = tf.cast(tf.image.resize(decoded, [img_size, img_size]), tf.float32)
    return (resized - 127.5) / 127.5


# Function to load and preprocess paired images
def load_pair(x_path, y_path):
    """Load an (input, target) image pair with identical preprocessing."""
    return load_image(x_path), load_image(y_path)

In [None]:
# Sanity check: show the first (face, comic) pair side by side.
# load_image returns values in [-1, 1], while imshow clips float images to
# [0, 1], so BOTH panels are rescaled with (x + 1) / 2 before display.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow((load_image(x_files[0]) + 1) / 2)
plt.title("Real Image")
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow((load_image(y_files[0]) + 1) / 2)
plt.title("Comic Book Image")
plt.axis('off')
plt.show()

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Hold out 20% for test, then 12.5% of the remaining 80% (10% of the total) for
# validation, leaving 70% for training. random_state=42 keeps the splits reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(x_files, y_files, random_state=42, test_size=0.2)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, random_state=42, test_size=0.125)

(len(X_train), len(X_val), len(X_test))

In [None]:
def create_dataset(x_files, y_files, batch_size=16, is_train=False):
    """Build a batched, prefetched tf.data pipeline of (input, target) image pairs.

    Args:
        x_files: paths of the input (face) images.
        y_files: paths of the target (comic) images, paired positionally with x_files.
        batch_size: pairs per batch; the final batch may be smaller.
        is_train: when True, the pair order is shuffled (and reshuffled on every
            iteration over the dataset, i.e. every epoch).

    Returns:
        A tf.data.Dataset yielding (x_batch, y_batch) float32 tensors with
        static per-image shape (img_size, img_size, 3) and values in [-1, 1].
    """
    dataset = tf.data.Dataset.from_tensor_slices((x_files, y_files))

    if is_train:
        # Shuffle the cheap path pairs instead of decoded images: a buffer that
        # covers the whole split gives a uniform shuffle at negligible memory
        # cost (a buffer of decoded 128x128x3 float pairs is ~400 MB per 1000).
        dataset = dataset.shuffle(buffer_size=max(len(x_files), 1))

    # load_pair uses only TensorFlow ops, so it runs natively in the tf.data
    # graph. Unlike tf.py_function this keeps static shapes and lets decoding
    # run in parallel; element order is still preserved (deterministic map).
    dataset = dataset.map(load_pair, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

train_dataset = create_dataset(X_train, Y_train, is_train=True)
test_dataset = create_dataset(X_test, Y_test)
val_dataset = create_dataset(X_val, Y_val)

In [None]:
def rmse(y_true, y_pred):
    """Root-mean-squared error over all elements of the two tensors (one scalar)."""
    return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true)))

def psnr(y_true, y_pred, max_val=2.0):
    """Peak signal-to-noise ratio in dB, one value per image in the batch.

    Args:
        y_true, y_pred: image tensors of the same shape and dtype.
        max_val: dynamic range of the pixel values. Images from ``load_image``
            and the generator's tanh output lie in [-1, 1], so the range is 2.0;
            using 1.0 for such images would understate PSNR by
            20*log10(2) ~= 6.02 dB. Pass ``max_val=1.0`` for images in [0, 1].
    """
    return tf.image.psnr(y_true, y_pred, max_val=max_val)

In [None]:
def CK(filters, kernel_size=(4, 4), strides=(2, 2), padding='same', use_batch_norm=True, downsample=True):
    """Return a callable that applies conv -> norm -> activation to a tensor.

    downsample=True  : strided Conv2D followed by LeakyReLU(0.2).
    downsample=False : Conv2DTranspose followed by ReLU.
    use_batch_norm=False replaces BatchNormalization with an Identity layer.
    """
    if downsample:
        conv_cls, activation = Conv2D, LeakyReLU(0.2)
    else:
        conv_cls, activation = Conv2DTranspose, ReLU()
    norm_cls = BatchNormalization if use_batch_norm else Identity

    def apply(tensor):
        tensor = conv_cls(filters, kernel_size, strides=strides, padding=padding)(tensor)
        tensor = norm_cls()(tensor)
        return activation(tensor)
    return apply

def CDK(filters, kernel_size=(4, 4), strides=(2, 2), padding='same', use_batch_norm=True, downsample=True, dropout_rate=0.5):
    """Return a callable applying conv -> norm -> Dropout(dropout_rate) -> ReLU.

    The conv is Conv2D when ``downsample`` is True, otherwise Conv2DTranspose;
    the activation is always ReLU. use_batch_norm=False replaces
    BatchNormalization with an Identity layer.
    """
    conv_cls = Conv2D if downsample else Conv2DTranspose
    norm_cls = BatchNormalization if use_batch_norm else Identity
    activation = ReLU()

    def apply(tensor):
        tensor = conv_cls(filters, kernel_size, strides=strides, padding=padding)(tensor)
        tensor = norm_cls()(tensor)
        tensor = Dropout(dropout_rate)(tensor)
        return activation(tensor)
    return apply


In [None]:
def generator(input_nc, output_nc, ngf, num_blocks=1, num_downsample=3):
    """Build a U-Net-style encoder/decoder generator.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        ngf: filters of the first conv layer; encoder layer k uses
            ngf * 2**k filters, with k capped at num_downsample.
        num_blocks: extra stride-2 layers at the capped width
            ngf * 2**num_downsample (num_blocks - 1 of them are mirrored in the
            decoder by dropout CDK blocks).
        num_downsample: number of times the filter count doubles.

    The network has num_downsample + num_blocks stride-2 convolutions (stem
    plus contracting loop) mirrored by as many transposed convolutions, so
    img_size must be divisible by 2 ** (num_downsample + num_blocks).

    Returns:
        A keras Model mapping (img_size, img_size, input_nc) images to
        (img_size, img_size, output_nc) images with tanh output in [-1, 1].
    """
    inputs = Input(shape=(img_size, img_size, input_nc))

    # Stem: stride-2 conv without batch norm
    x = CK(ngf, use_batch_norm=False)(inputs)

    # Contracting path: keep each feature map together with its width exponent
    # so the decoder can mirror the encoder's filter counts.
    skips = []
    for i in range(num_downsample + num_blocks - 1):
        expo = min(i + 1, num_downsample)
        x = CK(ngf * (2 ** expo))(x)
        skips.append((x, expo))

    skips.reverse()  # deepest feature map first

    # NOTE(review): on the first decoder step `x` *is* skips[0] (the bottleneck),
    # so it is concatenated with itself; kept to preserve the architecture.
    for skip, _ in skips[:num_blocks - 1]:
        x = Concatenate()([x, skip])
        x = CDK(ngf * (2 ** num_downsample), downsample=False)(x)

    for skip, expo in skips[num_blocks - 1:]:
        x = Concatenate()([x, skip])
        # Upsample to half the width of this skip, which equals the width of the
        # next (shallower) skip, so filter counts shrink symmetrically back to ngf.
        # (Previously this used the stale contracting-loop index `i`, giving every
        # decoder layer the same, largest width.)
        x = CK(ngf * (2 ** (expo - 1)), downsample=False)(x)

    output = Conv2DTranspose(output_nc, (4, 4), activation='tanh', padding="same", strides=(2, 2))(x)
    return Model(inputs=inputs, outputs=output)

In [None]:
def discriminator(input_nc, output_nc, ngf, num_blocks=1, num_downsample=3):
    """Build a conditional discriminator over (input, target) image pairs.

    The two images are concatenated on the channel axis and passed through a
    stride-2 stem (no batch norm), ``num_downsample`` stride-2 CK blocks with
    ngf * 2**k filters (k = 1..num_downsample), and a final stride-2,
    1-filter conv with sigmoid. The output is a spatial map of probabilities
    of side img_size / 2**(num_downsample + 2), not a single score.

    ``num_blocks`` is accepted for signature parity with ``generator`` but is
    not used.
    """
    inp = Input(shape=[img_size, img_size, input_nc], name='input_image')
    tar = Input(shape=[img_size, img_size, output_nc], name='target_image')

    pair = Concatenate()([inp, tar])
    features = CK(ngf, use_batch_norm=False)(pair)

    for level in range(1, num_downsample + 1):
        features = CK(ngf * 2 ** level)(features)

    patch_scores = Conv2D(1, (4, 4), strides=(2, 2), padding="same", activation='sigmoid')(features)
    return Model(inputs=[inp, tar], outputs=patch_scores)

In [None]:
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

In [None]:
def discriminator_loss(real_output, fake_output):
    """Sum of BCE pushing real scores toward 1 and fake scores toward 0.

    Scores are probabilities (from_logits=False), matching the sigmoid output
    of the discriminator.
    """
    bce = BinaryCrossentropy(from_logits=False)
    loss_on_real = bce(tf.ones_like(real_output), real_output)
    loss_on_fake = bce(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Adversarial generator loss: BCE of the fake scores against an all-ones target."""
    bce = BinaryCrossentropy(from_logits=False)
    all_real = tf.ones_like(fake_output)
    return bce(all_real, fake_output)

def l1_loss(real_image, generated_image):
    """Mean absolute per-element difference between two image tensors."""
    abs_diff = tf.abs(real_image - generated_image)
    return tf.reduce_mean(abs_diff)

In [None]:
# Separate Adam optimizers (learning rate 2e-4, beta_1=0.5) for the two networks;
# LAMBDA weights the L1 reconstruction term in the generator objective.
generator_optimizer = Adam(2e-4, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = Adam(2e-4, beta_1=0.5, beta_2=0.999)
LAMBDA = 100


@tf.function
def train_step(input_image, target_image, generator, discriminator):
    """Run one simultaneous optimisation step for generator and discriminator.

    The discriminator scores the real pair (input, target) and the fake pair
    (input, generated). The generator minimises adversarial BCE plus
    LAMBDA * L1 to the target; the discriminator minimises its real/fake BCE.

    Returns:
        (gen_loss, disc_loss) scalar tensors for this batch.
    """
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_image = generator(input_image, training=True)

        score_real = discriminator([input_image, target_image], training=True)
        score_fake = discriminator([input_image, fake_image], training=True)

        adversarial_term = generator_loss(score_fake)
        gen_loss = adversarial_term + LAMBDA * l1_loss(target_image, fake_image)
        disc_loss = discriminator_loss(score_real, score_fake)

    # Compute both gradients before applying either update.
    g_grads = g_tape.gradient(gen_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))

    return gen_loss, disc_loss

In [None]:
def evaluate_model(validation_dataset, generator_model, discriminator_model):
    """Compute the mean generator loss and RMSE over a validation dataset.

    The generator loss mirrors the objective optimised in ``train_step``:
    adversarial BCE on the discriminator's scores for the *generated* pair
    (input, generated) plus ``LAMBDA`` times the L1 error to the target, so it
    is directly comparable with the training generator loss. Both models run
    in inference mode (``training=False``).

    Args:
        validation_dataset: iterable of (input_image, target_image) batches.
        generator_model: maps input batches to generated batches.
        discriminator_model: scores [input, image] pairs.

    Returns:
        (avg_val_gen_loss, avg_val_rmse) scalar tensors, each the unweighted
        mean of the per-batch values.

    Raises:
        ValueError: if ``validation_dataset`` yields no batches.
    """
    val_gen_loss_total = 0.0
    val_rmse_total = 0.0
    num_batches = 0

    for input_image, target_image in validation_dataset:
        generated_image = generator_model(input_image, training=False)

        # The adversarial term must score the *generated* pair; scoring the
        # real target pair would not depend on the generator at all.
        fake_output = discriminator_model([input_image, generated_image], training=False)

        val_gen_loss = generator_loss(fake_output) + LAMBDA * l1_loss(target_image, generated_image)
        val_rmse = rmse(target_image, generated_image)

        val_gen_loss_total += val_gen_loss
        val_rmse_total += val_rmse
        num_batches += 1

    if num_batches == 0:
        raise ValueError("validation_dataset yielded no batches")

    avg_val_gen_loss = val_gen_loss_total / num_batches
    avg_val_rmse = val_rmse_total / num_batches
    return avg_val_gen_loss, avg_val_rmse

In [None]:
# Example usage
input_nc = 3  # Number of input channels (e.g., RGB)
output_nc = 3  # Number of output channels (e.g., RGB)
ngf = 64  # Number of generator filters in first conv layer

generator_model     = generator(input_nc, output_nc, ngf, num_blocks=3, num_downsample=4)
discriminator_model = discriminator(input_nc,output_nc, ngf, num_blocks=3, num_downsample=4)

generator_model.compile(optimizer=generator_optimizer, loss=generator_loss)
discriminator_model.compile(optimizer=discriminator_optimizer, loss=discriminator_loss)

In [None]:
# Print the generator's layer-by-layer architecture and parameter counts.
generator_model.summary()

In [None]:
# Draw the generator graph. NOTE: keras plot_model needs the optional pydot and
# graphviz packages installed.
plot_model(generator_model)

In [None]:
# Print the discriminator's layer-by-layer architecture and parameter counts.
discriminator_model.summary()

In [None]:
# Draw the discriminator graph (two named inputs: input_image and target_image).
# NOTE: keras plot_model needs the optional pydot and graphviz packages installed.
plot_model(discriminator_model)

In [None]:
# Lists to store metrics
# Per-epoch history, filled by the training loop and read by the metrics plot.
epochs_list, gen_losses, disc_losses, val_gen_losses, val_rmses, val_psnrs = ([] for _ in range(6))

In [None]:
# Training loop
epochs = 200
for epoch in range(epochs):
    gen_loss_total = 0
    disc_loss_total = 0
    num_batches = 0

    for batch in train_dataset:
        input_image, target_image = batch
        gen_loss, disc_loss = train_step(input_image, target_image, generator_model, discriminator_model)

        # Accumulate training losses
        gen_loss_total += gen_loss.numpy().item()
        disc_loss_total += disc_loss.numpy().item()
        num_batches += 1

    val_gen_loss, val_rmse = evaluate_model(val_dataset, generator_model, discriminator_model)

    # Compute average training losses
    gen_loss = gen_loss_total / num_batches
    disc_loss = disc_loss_total / num_batches

    # Collect training losses
    epochs_list.append(epoch + 1)
    gen_losses.append(gen_loss)
    disc_losses.append(disc_loss)

    val_gen_losses.append(val_gen_loss.numpy().item())
    val_rmses.append(val_rmse.numpy().item())

    print(f"Epoch {epoch+1}/{epochs} - Gen Loss: {gen_loss:.4f}, D Loss: {disc_loss:.4f}")
    # Validation step
    print(f"Val Gen Loss: {val_gen_loss.numpy().item():.4f}, RMSE: {val_rmse.numpy().item():.4f}")
    print()

In [None]:
# Plot metrics
plt.figure(figsize=(12, 6))

# Plot Generator and Discriminator Loss
plt.subplot(1, 2, 1)
plt.plot(epochs_list, gen_losses, label='Generator Loss')
plt.plot(epochs_list, disc_losses, label='Discriminator Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Losses')
plt.legend()

# Plot Validation Metrics
plt.subplot(1, 2, 2)
plt.plot(epochs_list, val_gen_losses, label='Validation Gen Loss')
plt.plot(epochs_list, val_rmses, label='Validation RMSE')
plt.plot(epochs_list, val_psnrs, label='Validation PSNR')
plt.xlabel('Epochs')
plt.ylabel('Metric Value')
plt.title('Validation Metrics')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
def display_random_images(val_dataset, generator_model, num_samples=3):
    """Plot input / target / prediction triplets for random validation examples.

    Batches are shuffled and the first example of each of ``num_samples``
    batches is shown, one row per example. Pixel values are mapped from
    [-1, 1] to [0, 1] for ``imshow``.

    Args:
        val_dataset: batched dataset of (input_image, target_image) pairs whose
            length is known (``len()`` must work on it).
        generator_model: model mapping input batches to predictions in [-1, 1].
        num_samples: rows to draw; fewer are drawn if the dataset has fewer
            batches.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    shuffled = val_dataset.shuffle(buffer_size=max(len(val_dataset), 1))

    rows = []
    for input_batch, target_batch in shuffled.take(num_samples):
        # Predict only the example being shown. In inference mode (batch norm
        # uses moving statistics, dropout is off) this equals predicting the
        # whole batch and taking row 0.
        input_image = input_batch[:1]
        prediction = generator_model(input_image, training=False)
        rows.append((input_image.numpy()[0], target_batch.numpy()[0], prediction.numpy()[0]))

    if not rows:
        raise ValueError("val_dataset yielded no batches to display")

    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(len(rows), 3, figsize=(15, len(rows) * 5), squeeze=False)
    for row_axes, images in zip(axes, rows):
        for ax, image, title in zip(row_axes, images, ("Input", "Target", "Prediction")):
            ax.imshow((image + 1) / 2)  # [-1, 1] -> [0, 1] for display
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Example usage
display_random_images(val_dataset, generator_model, num_samples=3)