In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Input, Concatenate
from tensorflow.keras import Model
import matplotlib.pyplot as plt

In [None]:
def load_and_split_image(image_file):
    """Load a side-by-side JPEG and split it into an (input, target) pair.

    The left half of the image becomes the input and the right half the
    target; both halves are cast to float32 and scaled from [0, 255] to
    [-1, 1].

    NOTE: the halves are not resized; the models in this notebook take
    256x256x3 inputs, so source images must already match that size.

    Args:
        image_file: Path to a JPEG file (Python string or scalar string tensor).

    Returns:
        Tuple ``(input_image, target_image)`` of float32 tensors with equal
        width and 3 channels.
    """
    image = tf.io.read_file(image_file)
    # Force RGB output: with the default (channels=0) a grayscale JPEG would
    # decode to 1 channel, which the 3-channel models cannot accept.
    image = tf.image.decode_jpeg(image, channels=3)

    # Split into two equal halves. Slicing the right half as
    # [width:2*width] keeps both halves the same width even when the total
    # width is odd (the extra last column is dropped).
    width = tf.shape(image)[1] // 2
    input_image = image[:, :width, :]  # Left half
    target_image = image[:, width:2 * width, :]  # Right half

    # Normalize images to [-1, 1]
    input_image = tf.cast(input_image, tf.float32) / 127.5 - 1.0
    target_image = tf.cast(target_image, tf.float32) / 127.5 - 1.0

    return input_image, target_image

In [None]:
def create_generator():
    """Build the encoder/decoder generator with skip connections.

    Input and output are 256x256x3. Three stride-2 ReLU convolutions (64, 128,
    256 filters) downsample the image, a 512-filter stride-2 convolution forms
    the bottleneck, and three stride-2 transposed convolutions (256, 128, 64
    filters) upsample it again. Each decoder stage is concatenated with the
    encoder output of the same resolution. A final transposed convolution
    with ``tanh`` produces 3 channels in (-1, 1).

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    conv_opts = dict(kernel_size=4, strides=2, padding='same')
    inputs = Input(shape=(256, 256, 3))

    # Encoder: every stage halves the spatial size; keep each output so the
    # decoder can concatenate it back in (skip connections).
    skips = []
    x = inputs
    for n_filters in (64, 128, 256):
        x = Conv2D(n_filters, activation='relu', **conv_opts)(x)
        skips.append(x)

    # Bottleneck at 1/16 of the input resolution.
    x = Conv2D(512, activation='relu', **conv_opts)(x)

    # Decoder: mirror the encoder, pairing each stage with the deepest
    # not-yet-used encoder output.
    for n_filters, skip in zip((256, 128, 64), reversed(skips)):
        x = Conv2DTranspose(n_filters, activation='relu', **conv_opts)(x)
        x = Concatenate()([x, skip])

    outputs = Conv2DTranspose(3, activation='tanh', **conv_opts)(x)
    return Model(inputs=inputs, outputs=outputs)

In [None]:
def create_discriminator():
    """Build the conditional patch discriminator.

    Takes ``[input_image, candidate_image]`` (each 256x256x3), stacks them
    along the channel axis, and applies three stride-2 ReLU convolutions
    (64, 128, 256 filters). A final 1-filter convolution then emits a
    32x32x1 map of raw, un-activated scores.

    Returns:
        An uncompiled ``tf.keras.Model`` with two inputs.
    """
    conditioning = Input(shape=(256, 256, 3))
    candidate = Input(shape=(256, 256, 3))

    # Judge the candidate in the context of the image it was derived from.
    x = Concatenate()([conditioning, candidate])
    for n_filters in (64, 128, 256):
        x = Conv2D(n_filters, kernel_size=4, strides=2, padding='same', activation='relu')(x)

    # No activation: the outputs are logits, one per receptive-field patch.
    patch_logits = Conv2D(1, kernel_size=4, padding='same')(x)

    return Model(inputs=[conditioning, candidate], outputs=patch_logits)

In [None]:
# Generator loss - combines adversarial and L1 loss
def generator_loss(disc_output, gen_output, target, l1_weight=100):
    """Return the generator objective: adversarial loss + weighted L1 loss.

    Args:
        disc_output: Discriminator scores for the generated images. The final
            Conv2D in ``create_discriminator`` has no activation, so these
            are logits rather than probabilities.
        gen_output: Images produced by the generator.
        target: Ground-truth images paired with the generator inputs.
        l1_weight: Multiplier on the L1 reconstruction term (default 100).

    Returns:
        Scalar loss tensor.
    """
    # Adversarial loss - reward fooling the discriminator into predicting
    # "real" (1). from_logits=True applies the sigmoid inside the loss, which
    # matches the discriminator's un-activated output.
    adv_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(
            tf.ones_like(disc_output), disc_output, from_logits=True))

    # L1 loss - encourage pixel-wise similarity with the target
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

    return adv_loss + l1_weight * l1_loss

# Discriminator loss - distinguishes real vs. fake images
def discriminator_loss(disc_real, disc_fake):
    """Return the discriminator objective: BCE on real (1) plus fake (0) patches.

    Args:
        disc_real: Discriminator logits for (input, real target) pairs.
        disc_fake: Discriminator logits for (input, generated) pairs.

    Returns:
        Scalar loss tensor.
    """
    # The discriminator's final Conv2D has no activation, so its outputs are
    # logits; from_logits=True lets the loss apply the sigmoid itself.
    real_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(
            tf.ones_like(disc_real), disc_real, from_logits=True))
    fake_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(
            tf.zeros_like(disc_fake), disc_fake, from_logits=True))
    return real_loss + fake_loss

In [None]:
# One Adam optimizer per network (learning rate 2e-4, beta_1 = 0.5), so the
# generator and discriminator keep independent optimizer state.
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

@tf.function
def train_step(input_image, target):
    """Run one joint optimization step for the generator and discriminator.

    Uses the module-level ``generator``, ``discriminator``, ``gen_optimizer``
    and ``disc_optimizer``.

    Args:
        input_image: Batch of source images fed to the generator; also the
            conditioning input of the discriminator.
        target: Batch of ground-truth images paired with ``input_image``.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar loss tensors.
    """
    # Both tapes record the same forward pass; each is later differentiated
    # only with respect to its own network's variables.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_image = generator(input_image, training=True)

        real_scores = discriminator([input_image, target], training=True)
        fake_scores = discriminator([input_image, fake_image], training=True)

        g_loss = generator_loss(fake_scores, fake_image, target)
        d_loss = discriminator_loss(real_scores, fake_scores)

    g_vars = generator.trainable_variables
    d_vars = discriminator.trainable_variables

    # Compute both gradients before applying either update.
    g_grads = gen_tape.gradient(g_loss, g_vars)
    d_grads = disc_tape.gradient(d_loss, d_vars)

    gen_optimizer.apply_gradients(zip(g_grads, g_vars))
    disc_optimizer.apply_gradients(zip(d_grads, d_vars))

    return g_loss, d_loss

In [None]:
from google.colab import files
import zipfile
import io

# Upload the zip archive for dataset A
uploaded = files.upload()
if not uploaded:
    # files.upload() returns an empty dict when the dialog is cancelled;
    # fail with a clear message instead of an IndexError.
    raise RuntimeError("No file was uploaded; please upload a .zip archive.")

# Only the first uploaded file is used if several were selected
zip_filename = next(iter(uploaded))

# Extract into /content/ and record the archive's top-level entries: these
# (not the archive's file name) are the folders later cells must point at.
with zipfile.ZipFile(io.BytesIO(uploaded[zip_filename]), 'r') as zip_ref:
    zip_ref.extractall('/content/')
    top_level_entries = sorted({name.split('/', 1)[0] for name in zip_ref.namelist()})

print(f"'{zip_filename}' extracted to '/content/'; top-level entries: {', '.join(top_level_entries)}")

Saving trainA.zip to trainA.zip
Folder 'trainA' uploaded and extracted to '/content/'


In [None]:
# Upload the zip archive for dataset B
uploaded = files.upload()
if not uploaded:
    # files.upload() returns an empty dict when the dialog is cancelled;
    # fail with a clear message instead of an IndexError.
    raise RuntimeError("No file was uploaded; please upload a .zip archive.")

# Only the first uploaded file is used if several were selected
zip_filename = next(iter(uploaded))

# Extract into /content/ and record the archive's top-level entries: these
# (not the archive's file name) are the folders later cells must point at.
with zipfile.ZipFile(io.BytesIO(uploaded[zip_filename]), 'r') as zip_ref:
    zip_ref.extractall('/content/')
    top_level_entries = sorted({name.split('/', 1)[0] for name in zip_ref.namelist()})

print(f"'{zip_filename}' extracted to '/content/'; top-level entries: {', '.join(top_level_entries)}")

Saving trainB.zip to trainB.zip
Folder 'trainB' uploaded and extracted to '/content/'


In [None]:
import os

# Directories extracted from the uploaded archives: inputs (A) and targets (B).
image_path_raw, target_path_raw = (
    os.path.join('/content/', subdir) for subdir in ('trainA/', 'trainB/')
)

In [None]:
# Load input (A) and target (B) images as batched datasets.
#
# The two datasets are later paired positionally with Dataset.zip, so both
# must be read in the same order. With shuffle=True each loader draws its own
# random seed, so the n-th input batch would no longer match the n-th target
# batch. shuffle=False makes both loaders read files in alphanumeric order.
# NOTE(review): this assumes corresponding images in trainA/ and trainB/
# share the same sort order (e.g. identical file names); confirm for your
# data. Training therefore sees batches in a fixed order.
image_path = tf.keras.utils.image_dataset_from_directory(
    image_path_raw,
    label_mode=None,
    image_size=(256, 256),
    batch_size=32,
    shuffle=False,
)

target_path = tf.keras.utils.image_dataset_from_directory(
    target_path_raw,
    label_mode=None,
    image_size=(256, 256),
    batch_size=32,
    shuffle=False,
)

Found 10000 files.
Found 10000 files.


In [None]:
# Pair each input batch with the target batch at the same position.
# Passing a single tuple works on every TF2 release; the Dataset.zip(a, b)
# varargs form is only accepted from TF 2.13 onward.
train_dataset = tf.data.Dataset.zip((image_path, target_path))

# Preprocess the images: scale pixel values from [0, 255] to [-1, 1]
def preprocess(input_image, target_image):
    """Scale an (input, target) pair of images from [0, 255] to [-1, 1].

    [-1, 1] is the range the rest of the notebook assumes: the generator ends
    in ``tanh``, ``load_and_split_image`` uses the same scaling, and
    ``display_sample`` undoes it with ``x * 0.5 + 0.5``.

    Args:
        input_image: Float image (or batch) with values in [0, 255].
        target_image: Float image (or batch) with values in [0, 255].

    Returns:
        Tuple ``(input_image, target_image)`` rescaled to [-1, 1].
    """
    input_image = input_image / 127.5 - 1.0
    target_image = target_image / 127.5 - 1.0
    return input_image, target_image

# Normalize every (input, target) batch pair exactly once.
train_dataset = train_dataset.map(map_func=preprocess)

In [None]:
# Instantiate the two networks; train_step looks them up by these global names.
generator = create_generator()
discriminator = create_discriminator()

# Number of training epochs (adjust as needed).
epochs = 1

In [None]:
# train_dataset has already been mapped through `preprocess`; mapping it again
# here would rescale the pixels a second time, so only prefetching is added.
# In-memory .cache() is omitted on purpose: 2 x 10,000 float32 images of
# 256x256x3 is roughly 15.7 GB, which likely exceeds a standard Colab
# runtime's RAM.
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
# NOTE(review): this only constructs a mixed-precision policy; it is never
# activated (no mixed_precision.set_global_policy call), so training still
# runs in float32. Keras layers take their dtype policy when they are
# constructed, so activating it here would not affect the generator and
# discriminator that were already built above.
from tensorflow.keras import mixed_precision

policy = mixed_precision.Policy(name='mixed_float16')

In [22]:
import os

import tensorflow as tf

# Set up the logging interval and checkpoints to reduce runtime interruptions
log_interval = 10  # Log every 10 epochs
checkpoint_interval = 50  # Save checkpoint every 50 epochs

# Set up the summary writer
writer = tf.summary.create_file_writer('logs')

# Checkpoint both networks and their optimizers so training can be resumed.
checkpoint_prefix = os.path.join('./training_checkpoints', 'ckpt')
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
)

# Training loop: one epoch is one full pass over every batch of train_dataset.
for epoch in range(epochs):
    gen_loss = disc_loss = None
    for input_image, target in train_dataset:
        gen_loss, disc_loss = train_step(input_image, target)

    if gen_loss is None:
        raise ValueError("train_dataset yielded no batches; nothing to train on.")

    # Log the losses of the epoch's last batch, only at the configured
    # interval to reduce overhead.
    if epoch % log_interval == 0:
        with writer.as_default():
            tf.summary.scalar('gen_loss', gen_loss, step=epoch)
            tf.summary.scalar('disc_loss', disc_loss, step=epoch)
        print(f"Epoch {epoch}: Gen Loss = {float(gen_loss):.4f}, Disc Loss = {float(disc_loss):.4f}")

    # Save model checkpoint periodically
    if epoch % checkpoint_interval == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

writer.flush()

In [23]:
def display_sample(generator, test_input, test_target):
    """Show input, target and generator prediction side by side.

    Only the first element of each batch is plotted. Pixels are shown as
    ``x * 0.5 + 0.5``, which assumes the images are scaled to [-1, 1].

    Args:
        generator: Model mapping a batch of inputs to a batch of images.
        test_input: Batch of input images.
        test_target: Batch of ground-truth images paired with ``test_input``.
    """
    prediction = generator(test_input, training=False)
    panels = {
        'Input': test_input[0],
        'Target': test_target[0],
        'Generated': prediction[0],
    }

    _, axes = plt.subplots(1, 3, figsize=(12, 12))
    for ax, (title, picture) in zip(axes, panels.items()):
        ax.set_title(title)
        ax.imshow(picture * 0.5 + 0.5)
        ax.axis('off')

    plt.show()

In [24]:
# Preview one batch: the input, its ground truth and the generator's output.
sample_input, sample_target = next(iter(train_dataset))
display_sample(generator, sample_input, sample_target)