# Week 5: GANs


## 1. Description of the Problem/Data

**Problem Statement**: The goal is to transform input photographs into images that resemble the style of Claude Monet's paintings using a generative model, specifically Generative Adversarial Networks (GANs).

**Data Description**: The dataset consists of:

* **Monet Paintings**: 300 images of Monet’s artworks (256x256 pixels in JPEG and TFRecord format).
* **Photos**: 7028 real-world photos (256x256 pixels in JPEG and TFRecord format).

In [None]:
!pip -q install kaggle matplotlib seaborn scikit-image

In [None]:
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import sys
import tensorflow as tf
import pandas as pd
from tensorflow import keras
import seaborn as sns
import skimage
from skimage import feature
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2gray
from tensorflow.keras import layers, Model

In [None]:
# Report the TensorFlow build and how many GPUs it can see.
gpus = tf.config.list_physical_devices('GPU')
print("TensorFlow version:", tf.__version__)
print("Number of available GPUs:", len(gpus))

if len(gpus) > 0:
    try:
        # Turn on memory growth for every visible GPU.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Report the failure and carry on with the default configuration.
        print(e)
    else:
        print("Memory growth enabled for GPUs.")

In [None]:
# Install the Kaggle API token at ~/.kaggle/kaggle.json using portable Python
# calls, so the cell does not depend on $HOME being set or on the `cp` and
# `chmod` shell commands being available.
home_path = os.path.expanduser('~')
kaggle_dir = os.path.join(home_path, '.kaggle')
kaggle_token = os.path.join(kaggle_dir, 'kaggle.json')
os.makedirs(kaggle_dir, exist_ok=True)

# Copy the token from the working directory when one is present there.
if os.path.isfile('./kaggle.json'):
    shutil.copy('./kaggle.json', kaggle_token)

if os.path.isfile(kaggle_token):
    # Restrict the token to the current user (equivalent to `chmod 600`).
    os.chmod(kaggle_token, 0o600)
else:
    print(f"kaggle.json not found in the working directory or in {kaggle_dir}; "
          "the Kaggle download step will fail without it.")

In [None]:
!kaggle competitions download -c gan-getting-started


In [None]:
!unzip -q gan-getting-started.zip

## 2. Exploratory Data Analysis (EDA)
* Load the dataset using TensorFlow's tf.data API.
* Display a random sample of Monet paintings and photos to observe differences.
* Analyze color distributions, texture, and other artistic elements in Monet's paintings versus the photos.

In [None]:
def load_images_from_folder(folder, limit=None):
    """Load the image files found in *folder* as arrays.

    Files are visited in sorted filename order so repeated runs return the
    same images in the same order. Entries whose extension is not a known
    image extension are skipped instead of being handed to the decoder.

    Args:
        folder: Path of the directory to read.
        limit: Optional maximum number of images to load. ``None`` (the
            default) loads every image in the folder, as before.

    Returns:
        A list with one array per image, as returned by ``plt.imread``.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff')
    filenames = sorted(
        name for name in os.listdir(folder)
        if name.lower().endswith(image_extensions)
    )
    if limit is not None:
        # Only decode what the caller asked for; a negative limit loads nothing.
        filenames = filenames[:max(limit, 0)]
    return [plt.imread(os.path.join(folder, name)) for name in filenames]

def plot_images(images, title):
    """Show up to nine images in a 3x3 grid, each captioned with *title*."""
    grid_rows, grid_cols = 3, 3
    plt.figure(figsize=(10, 10))
    for position, picture in enumerate(images[:grid_rows * grid_cols], start=1):
        plt.subplot(grid_rows, grid_cols, position)
        plt.imshow(picture)
        plt.title(title)
        plt.axis('off')
    plt.show()
    
# Folders (relative to the working directory) holding the extracted JPEGs.
monet_files = 'monet_jpg'
photo_files = 'photo_jpg'

# Keep the first 6 images returned for each domain as a visual sample.
monet_images = load_images_from_folder(monet_files)[:6]
photo_images = load_images_from_folder(photo_files)[:6]

for sample, caption in ((monet_images, 'Monet Paintings'), (photo_images, 'Photos')):
    plot_images(sample, caption)

In [None]:
def plot_color_histograms(images, title):
    """Plot the mean per-channel intensity histogram over *images*.

    For each of the first three channels (drawn in red, green and blue) a
    256-bin histogram over the range (0, 256) is computed per image; the
    per-image histograms are averaged and drawn as a single line.
    """
    for channel, color in enumerate(('red', 'green', 'blue')):
        per_image_counts = []
        for image in images:
            counts, _ = np.histogram(image[:, :, channel], bins=256, range=(0, 256))
            per_image_counts.append(counts)
        plt.plot(np.mean(per_image_counts, axis=0), color=color)
    plt.title(f'Color Histograms for {title}')
    plt.xlabel('Intensity Value')
    plt.ylabel('Frequency')
    plt.show()

def analyze_texture(images, title):
    """Print the mean GLCM texture statistics of *images*.

    Each image is converted to grayscale, scaled by 255 and cast to uint8,
    then a normalized, symmetric gray-level co-occurrence matrix is computed
    for distance 1 and angle 0. The mean of every property over all images
    is printed, prefixed with *title*.
    """
    property_names = ('contrast', 'dissimilarity', 'homogeneity', 'ASM', 'energy', 'correlation')
    values_by_property = {name: [] for name in property_names}

    for image in images:
        gray_levels = (rgb2gray(image) * 255).astype('uint8')
        glcm = graycomatrix(gray_levels, distances=[1], angles=[0], symmetric=True, normed=True)
        for name in property_names:
            values_by_property[name].append(graycoprops(glcm, name)[0, 0])

    # One summary line per property, in the order listed above.
    for name, values in values_by_property.items():
        print(f'{title} - Average {name}: {np.mean(values)}')
        
# Load the first 100 images of each domain for the analysis. The preview lists
# `monet_images` / `photo_images` hold only 6 images each, so slicing those
# with [:100] would compute the statistics on 6 images, not 100.
# Only the selected files are decoded, in sorted filename order.
analysis_sample_size = 100
monet_images_subset = [
    plt.imread(os.path.join(monet_files, name))
    for name in sorted(os.listdir(monet_files))[:analysis_sample_size]
]
photo_images_subset = [
    plt.imread(os.path.join(photo_files, name))
    for name in sorted(os.listdir(photo_files))[:analysis_sample_size]
]

plot_color_histograms(monet_images_subset, 'Monet Paintings')
plot_color_histograms(photo_images_subset, 'Photos')

analyze_texture(monet_images_subset, 'Monet Paintings')
analyze_texture(photo_images_subset, 'Photos')

## 3. Model Building and Training
**Model Choice:** Use a CycleGAN architecture, which is effective for image-to-image translation tasks without needing paired examples.

Training Steps:

1. Build the Generator and Discriminator Models:
* The generator should transform a photo to a Monet-style image.
* The discriminator distinguishes between generated images and real Monet paintings.
2. Set Up Loss Functions:
* Adversarial loss (to train generators and discriminators).
* Cycle consistency loss (to ensure that the original photo can be recovered from the generated image).
3. Compile the Model:
* Use TensorFlow and Keras to set up the training loops.
4. Train the Model:
* Use the tf.data API for efficient data handling.


In [None]:
class InstanceNormalization(layers.Layer):
    """Instance normalization with a learned per-channel scale and offset.

    Each sample is normalized over axes 1 and 2 (the spatial axes of the
    channels-last image tensors this notebook feeds it), separately for every
    channel, then multiplied by ``scale`` and shifted by ``offset``.

    Args:
        epsilon: Small constant added to the variance before taking the
            reciprocal square root, for numerical stability.
        **kwargs: Standard Keras ``Layer`` keyword arguments (``name``,
            ``dtype``, ``trainable``, ...). Accepting them lets Keras rebuild
            the layer from its saved configuration.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create one scale and one offset weight per channel (last axis)."""
        self.scale = self.add_weight(
            name='scale',
            shape=input_shape[-1:],
            initializer=tf.random_normal_initializer(1., 0.02),
            trainable=True)

        self.offset = self.add_weight(
            name='offset',
            shape=input_shape[-1:],
            initializer='zeros',
            trainable=True)

    def call(self, x):
        """Normalize *x* per sample and channel, then apply scale and offset."""
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        inv = tf.math.rsqrt(variance + self.epsilon)
        normalized = (x - mean) * inv
        return self.scale * normalized + self.offset

    def get_config(self):
        """Return the layer config, including ``epsilon``.

        Models containing this layer are saved with ``model.save``; some
        tf.keras versions refuse to serialize a layer whose ``__init__`` takes
        extra arguments unless ``get_config`` is overridden.
        """
        config = super().get_config()
        config.update({'epsilon': self.epsilon})
        return config

In [None]:
def build_generator():
    """Build the 256x256x3 -> 256x256x3 image-to-image generator.

    Layout: a 7x7 stem convolution, two stride-2 downsampling convolutions,
    nine residual blocks at 256 filters, two stride-2 transposed
    convolutions, and a 7x7 tanh output convolution. Every convolution except
    the last is followed by InstanceNormalization.

    Note: borders are padded with zeros (``ZeroPadding2D``) ahead of the
    unpadded convolutions, not with reflection padding.

    Returns:
        A Keras ``Model`` mapping the input image tensor to the tanh output.
    """
    def norm_relu(tensor):
        # InstanceNormalization followed by ReLU, as used after most convs.
        tensor = InstanceNormalization()(tensor)
        return layers.ReLU()(tensor)

    inputs = layers.Input(shape=(256, 256, 3))

    # Stem: pad by 3 so the unpadded 7x7 convolution keeps the spatial size.
    x = layers.ZeroPadding2D(padding=3)(inputs)
    x = norm_relu(layers.Conv2D(64, 7, use_bias=False)(x))

    # Downsampling: two stride-2 convolutions.
    for filters in (128, 256):
        x = layers.Conv2D(filters, 3, strides=2, padding='same', use_bias=False)(x)
        x = norm_relu(x)

    # Residual blocks: two padded 3x3 convolutions added back onto the input.
    for _ in range(9):
        shortcut = x
        y = layers.ZeroPadding2D(padding=1)(shortcut)
        y = layers.Conv2D(256, 3, use_bias=False)(y)
        y = norm_relu(y)
        y = layers.ZeroPadding2D(padding=1)(y)
        y = layers.Conv2D(256, 3, use_bias=False)(y)
        y = InstanceNormalization()(y)
        x = layers.add([shortcut, y])

    # Upsampling: two stride-2 transposed convolutions.
    for filters in (128, 64):
        x = layers.Conv2DTranspose(filters, 3, strides=2, padding='same', use_bias=False)(x)
        x = norm_relu(x)

    # Output: zero-pad by 3, then an unpadded 7x7 convolution with tanh.
    x = layers.ZeroPadding2D(padding=3)(x)
    outputs = layers.Conv2D(3, 7, use_bias=False, activation='tanh')(x)
    return Model(inputs, outputs)

In [None]:
def build_discriminator():
    """Build the convolutional discriminator for 256x256x3 images.

    Three stride-2 4x4 convolutions (64, 128, 256 filters) are followed by a
    stride-1 4x4 convolution with 512 filters and a final single-filter 4x4
    convolution. The output is a map of raw scores (no sigmoid is applied).
    All but the first and last convolutions are followed by
    InstanceNormalization; hidden activations are LeakyReLU(0.2).

    Returns:
        A Keras ``Model`` mapping an image tensor to the score map.
    """
    inputs = layers.Input(shape=(256, 256, 3))

    # First stage keeps its bias and is not normalized.
    x = layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    # Remaining hidden stages: convolution -> instance norm -> LeakyReLU.
    for filters, strides in ((128, 2), (256, 2), (512, 1)):
        x = layers.Conv2D(filters, 4, strides=strides, padding='same', use_bias=False)(x)
        x = InstanceNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

    outputs = layers.Conv2D(1, 4, padding='same')(x)
    return Model(inputs, outputs)


In [None]:
class CycleGAN(Model):
    """CycleGAN wrapper that trains two generators and two discriminators.

    Naming convention: ``m_gen`` produces Monet-style images (photo -> Monet)
    and ``p_gen`` produces photos (Monet -> photo). ``m_disc`` scores
    Monet-domain images and ``p_disc`` scores photo-domain images.
    """

    def __init__(self, monet_generator, photo_generator, monet_discriminator, photo_discriminator):
        super(CycleGAN, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator

    def compile(self, m_gen_optimizer, p_gen_optimizer, m_disc_optimizer, p_disc_optimizer, gen_loss_fn, disc_loss_fn, cycle_loss_fn, identity_loss_fn):
        """Store one optimizer per network and the loss callables.

        Args:
            m_gen_optimizer, p_gen_optimizer: Optimizers for the generators.
            m_disc_optimizer, p_disc_optimizer: Optimizers for the
                discriminators.
            gen_loss_fn: ``fn(target, prediction)``; called with all-ones
                targets and the discriminator scores of generated images.
            disc_loss_fn: ``fn(target, prediction)``; called with all-ones
                targets on real-image scores and all-zeros targets on
                generated-image scores.
            cycle_loss_fn: ``fn(real, cycled)`` reconstruction loss.
            identity_loss_fn: ``fn(real, same)`` identity-mapping loss.
        """
        super(CycleGAN, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def _discriminator_loss(self, real_scores, fake_scores):
        """Average the loss of scoring real images as 1 and fakes as 0."""
        real_loss = self.disc_loss_fn(tf.ones_like(real_scores), real_scores)
        fake_loss = self.disc_loss_fn(tf.zeros_like(fake_scores), fake_scores)
        return 0.5 * (real_loss + fake_loss)

    def train_step(self, data):
        """Run one optimization step on a ``(real_monet, real_photo)`` batch.

        Returns:
            A dict with the total loss of each generator and the loss of each
            discriminator for this batch.
        """
        real_monet, real_photo = data

        with tf.GradientTape(persistent=True) as tape:
            # photo -> Monet -> photo. The Monet generator must produce
            # fake_monet so that its adversarial loss depends on its weights.
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # Monet -> photo -> Monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # Identity mapping: a generator fed an image already in its
            # target domain should return it unchanged.
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # Discriminator output
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # Generator adversarial loss: generated images should score as real.
            gen_monet_loss = self.gen_loss_fn(tf.ones_like(disc_fake_monet), disc_fake_monet)
            gen_photo_loss = self.gen_loss_fn(tf.ones_like(disc_fake_photo), disc_fake_photo)

            # Total cycle consistency loss
            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet) + self.cycle_loss_fn(real_photo, cycled_photo)

            # Total identity loss
            total_identity_loss = self.identity_loss_fn(real_monet, same_monet) + self.identity_loss_fn(real_photo, same_photo)

            # Total generator loss
            total_monet_gen_loss = gen_monet_loss + total_cycle_loss + total_identity_loss
            total_photo_gen_loss = gen_photo_loss + total_cycle_loss + total_identity_loss

            # Discriminator loss: real scores towards 1, fake scores towards 0.
            monet_disc_loss = self._discriminator_loss(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self._discriminator_loss(disc_real_photo, disc_fake_photo)

        # Calculate the gradients for generators and discriminators
        monet_gen_gradients = tape.gradient(total_monet_gen_loss, self.m_gen.trainable_variables)
        photo_gen_gradients = tape.gradient(total_photo_gen_loss, self.p_gen.trainable_variables)
        monet_disc_gradients = tape.gradient(monet_disc_loss, self.m_disc.trainable_variables)
        photo_disc_gradients = tape.gradient(photo_disc_loss, self.p_disc.trainable_variables)
        # A persistent tape keeps its resources until it is dropped.
        del tape

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_gen_gradients, self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_gen_gradients, self.p_gen.trainable_variables))
        self.m_disc_optimizer.apply_gradients(zip(monet_disc_gradients, self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_disc_gradients, self.p_disc.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

In [None]:
# Build the four networks: one generator and one discriminator per domain.
monet_generator = build_generator()
photo_generator = build_generator()
monet_discriminator = build_discriminator()
photo_discriminator = build_discriminator()

# Exponentially decaying learning-rate schedule shared by all optimizers.
initial_learning_rate = 1e-4
decay_steps = 1000
decay_rate = 0.90
learning_rate_fn = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=decay_steps, decay_rate=decay_rate
)


def _new_adam():
    """Return a fresh Adam optimizer driven by the shared schedule."""
    return tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)


def _mean_absolute_difference(reference, reconstruction):
    """Return the mean absolute element-wise difference of two tensors."""
    return tf.reduce_mean(tf.abs(reference - reconstruction))


cyclegan = CycleGAN(monet_generator, photo_generator, monet_discriminator, photo_discriminator)

# Each network gets its own optimizer instance.
cyclegan.compile(
    m_gen_optimizer=_new_adam(),
    p_gen_optimizer=_new_adam(),
    m_disc_optimizer=_new_adam(),
    p_disc_optimizer=_new_adam(),
    gen_loss_fn=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    disc_loss_fn=tf.keras.losses.MeanSquaredError(),
    cycle_loss_fn=_mean_absolute_difference,
    identity_loss_fn=_mean_absolute_difference,
)


In [None]:
# Define a function to parse TFRecord examples
import glob
def _parse_image_function(proto):
    """Decode one serialized TFRecord example into a normalized image.

    Args:
        proto: Scalar string tensor holding a serialized ``tf.train.Example``
            with a JPEG-encoded ``image`` bytes feature.

    Returns:
        A float tensor of shape (256, 256, 3) scaled from the 0-255 pixel
        range to [-1, 1].
    """
    features = {'image': tf.io.FixedLenFeature([], tf.string)}
    parsed_features = tf.io.parse_single_example(proto, features)
    # Force three channels so a non-RGB JPEG cannot break the reshape below.
    image = tf.image.decode_jpeg(parsed_features['image'], channels=3)
    image = tf.image.resize(image, [256, 256])
    # Pin the static shape so downstream layers see a fully defined tensor.
    image = tf.reshape(image, [256, 256, 3])
    # Normalize the image to [-1, 1]
    image = (image / 127.5) - 1
    return image

def load_dataset(monet_tfrecords_path, photo_tfrecords_path):
    """Build the paired (Monet, photo) dataset from TFRecord files.

    Both record sets are parsed with ``_parse_image_function``; their element
    shapes and the shape of one sample image from each are printed as a
    sanity check.

    Args:
        monet_tfrecords_path: TFRecord file path(s) with Monet paintings.
        photo_tfrecords_path: TFRecord file path(s) with photos.

    Returns:
        A ``tf.data.Dataset`` yielding ``(monet_image, photo_image)`` tuples.
        NOTE(review): ``Dataset.zip`` presumably stops at the shorter input,
        so surplus records of the longer set are never paired; confirm this
        is intended.
    """
    def parsed(paths):
        # Read the serialized examples and decode each one into an image.
        return tf.data.TFRecordDataset(paths).map(_parse_image_function)

    monet_dataset = parsed(monet_tfrecords_path)
    photo_dataset = parsed(photo_tfrecords_path)

    # Static element shapes as declared by the parsing function.
    print("Monet dataset shape:", monet_dataset.element_spec.shape)
    print("Photo dataset shape:", photo_dataset.element_spec.shape)

    # Shape of one actual decoded image from each dataset.
    for message, dataset in (("Monet image shape:", monet_dataset), ("Photo image shape:", photo_dataset)):
        for image in dataset.take(1):
            print(message, image.shape)

    return tf.data.Dataset.zip((monet_dataset, photo_dataset))

# Specify the correct paths to your TFRecord files
monet_tfrecord_dir = './monet_tfrec/'
photo_tfrecord_dir = './photo_tfrec/'

# glob returns paths in arbitrary order; sort so the record order (and
# therefore the Monet/photo pairing) is reproducible between runs.
monet_tfrecord_files = sorted(glob.glob(monet_tfrecord_dir + '*.tfrec'))
photo_tfrecord_files = sorted(glob.glob(photo_tfrecord_dir + '*.tfrec'))

# Fail fast with a clear message instead of building an empty dataset.
for directory, found in ((monet_tfrecord_dir, monet_tfrecord_files),
                         (photo_tfrecord_dir, photo_tfrecord_files)):
    if not found:
        raise FileNotFoundError(
            f"No .tfrec files found in {directory!r}; was the competition archive extracted?"
        )

batch_size = 1

train_dataset = load_dataset(monet_tfrecord_files, photo_tfrecord_files)

# Show the first (Monet, photo) pair as a sanity check.
for monet, photo in train_dataset.take(1):
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    for axis, image, caption in zip(axs, (photo, monet), ("Photo", "Monet")):
        axis.imshow((image.numpy() + 1) / 2)  # Rescale from [-1, 1] to [0, 1]
        axis.set_title(caption)
        axis.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
! rm -rf *.h5

In [None]:
import time

start_time = time.time()

epochs = 250
batch_size = 1
# NOTE(review): this rebinds `train_dataset`; re-running the cell would batch
# an already-batched dataset.
train_dataset = train_dataset.batch(batch_size)
# Count batches by streaming through the dataset once. Building a list of
# every decoded batch just to take its length would hold the whole training
# set in memory at the same time.
steps_per_epoch = sum(1 for _ in train_dataset)
if steps_per_epoch == 0:
    raise ValueError("train_dataset is empty; check that the TFRecord files were found.")

# Early stopping: stop after `patience` consecutive epochs without a new
# lowest summed loss.
patience = 3
best_loss = float('inf')
wait = 0

# Show sample translations every this many epochs.
display_epoch_interval = 5

for epoch in range(epochs):
    start_epoch_time = time.time()
    epoch_losses = []

    step = 0
    for image_monet, image_photo in train_dataset:
        losses = cyclegan.train_step((image_monet, image_photo))
        epoch_losses.append(losses)

        step += 1
        # Optional per-step progress logging (disabled):
#         if step % 100 == 0:
#             elapsed_time = time.time() - start_epoch_time
#             remaining_steps = steps_per_epoch - step
#             estimated_time_remaining = (elapsed_time / step) * remaining_steps
            
#             print(f"Epoch {epoch + 1}/{epochs}, Step {step}/{steps_per_epoch}, "
#                   f"Monet Generator Loss: {losses['monet_gen_loss']:.4f}, "
#                    f"Photo Generator Loss: {losses['photo_gen_loss']:.4f}, "
#                   f"Monet Discriminator Loss: {losses['monet_disc_loss']:.4f}, "
#                    f"Photo Discriminator Loss: {losses['photo_disc_loss']:.4f}, "
#                    f"Elapsed Time: {elapsed_time:.2f}s, "
#                   f"Estimated Time Remaining: {estimated_time_remaining:.2f}s")

    # Average each loss over all steps of this epoch.
    epoch_losses = {k: sum(l[k] for l in epoch_losses) / len(epoch_losses) for k in epoch_losses[0]}
    epoch_elapsed_time = time.time() - start_epoch_time
    total_elapsed_time = time.time() - start_time

    # Calculate the total loss for the current epoch
    total_loss = epoch_losses['monet_gen_loss'] + epoch_losses['photo_gen_loss'] + epoch_losses['monet_disc_loss'] + epoch_losses['photo_disc_loss']

    # Check if the current loss is better than the best loss
    if total_loss < best_loss:
        best_loss = total_loss
        wait = 0
        print(f"\nNew best loss found at epoch {epoch + 1}: {best_loss:.4f}")

        # Save the best model checkpoint
        monet_generator.save('best_monet_generator.h5')
        photo_generator.save('best_photo_generator.h5')
        monet_discriminator.save('best_monet_discriminator.h5')
        photo_discriminator.save('best_photo_discriminator.h5')
    else:
        wait += 1
        if wait >= patience:
            print(f"\nEarly stopping at epoch {epoch + 1}")
            break

    print(f"\nEpoch {epoch + 1}/{epochs} completed, "
          f"Monet Generator Loss: {epoch_losses['monet_gen_loss']:.4f}, "
          f"Photo Generator Loss: {epoch_losses['photo_gen_loss']:.4f}, "
          f"Monet Discriminator Loss: {epoch_losses['monet_disc_loss']:.4f}, "
          f"Photo Discriminator Loss: {epoch_losses['photo_disc_loss']:.4f}, "
          f"Epoch Time: {epoch_elapsed_time:.2f}s, "
          f"Total Elapsed Time: {total_elapsed_time:.2f}s\n")

    # Display sample generated images every few epochs
    if (epoch + 1) % display_epoch_interval == 0:
        monet_batch, photo_batch = next(iter(train_dataset.take(1)))

        # Generate sample Monet-style images from photo images
        fake_monet_batch = monet_generator(photo_batch, training=False)

        # Generate sample photo-style images from Monet images
        fake_photo_batch = photo_generator(monet_batch, training=False)

        # Rescale from [-1, 1] to [0, 1] for display.
        photo_batch = (photo_batch + 1) / 2
        fake_monet_batch = (fake_monet_batch + 1) / 2
        monet_batch = (monet_batch + 1) / 2
        fake_photo_batch = (fake_photo_batch + 1) / 2

        # Display the samples
        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
        axs[0, 0].imshow(photo_batch[0])
        axs[0, 0].set_title("Real Photo")
        axs[0, 1].imshow(fake_monet_batch[0])
        axs[0, 1].set_title("Generated Monet")
        axs[1, 0].imshow(monet_batch[0])
        axs[1, 0].set_title("Real Monet")
        axs[1, 1].imshow(fake_photo_batch[0])
        axs[1, 1].set_title("Generated Photo")
        plt.tight_layout()
        plt.show()

# Load the best model checkpoint
monet_generator.load_weights('best_monet_generator.h5')
photo_generator.load_weights('best_photo_generator.h5')
monet_discriminator.load_weights('best_monet_discriminator.h5')
photo_discriminator.load_weights('best_photo_discriminator.h5')

## Generate Images

In [None]:
# Restore the best checkpoint saved during training into each network.
for _network, _checkpoint in (
    (monet_generator, 'best_monet_generator.h5'),
    (photo_generator, 'best_photo_generator.h5'),
    (monet_discriminator, 'best_monet_discriminator.h5'),
    (photo_discriminator, 'best_photo_discriminator.h5'),
):
    _network.load_weights(_checkpoint)

In [None]:
from PIL import Image
import zipfile
import os

# Load the pre-trained CycleGAN model weights
monet_generator.load_weights('best_monet_generator.h5')
photo_generator.load_weights('best_photo_generator.h5')
monet_discriminator.load_weights('best_monet_discriminator.h5')
photo_discriminator.load_weights('best_photo_discriminator.h5')

# Set the input directory for photo images
photo_dir = './photo_jpg'

# Set the output directory for generated images
output_dir = 'generated_monet_images'
os.makedirs(output_dir, exist_ok=True)

# Set the batch size for generating images
batch_size = 1

# Sorted so that output index N always corresponds to the same input photo.
# NOTE: this rebinds `photo_files`, which the EDA cell uses as a folder name.
photo_files = sorted(f for f in os.listdir(photo_dir) if f.endswith('.jpg'))

# Create a dataset that reads, decodes and normalizes each photo to [-1, 1].
photo_dataset = tf.data.Dataset.from_tensor_slices(photo_files)
photo_dataset = photo_dataset.map(lambda f: tf.image.decode_jpeg(tf.io.read_file(tf.strings.join([photo_dir, f], separator=os.path.sep))))
photo_dataset = photo_dataset.map(lambda x: (tf.cast(x, tf.float32) / 127.5) - 1)
photo_dataset = photo_dataset.batch(batch_size)

# Generate Monet-style images
for i, photo_batch in enumerate(photo_dataset):
    generated_images = monet_generator(photo_batch, training=False)

    # Rescale the pixel values to [0, 255]. Round to the nearest level (a bare
    # uint8 cast truncates) and clip so no value can wrap around.
    pixel_batch = np.clip(np.rint((generated_images.numpy() + 1) * 127.5), 0, 255).astype(np.uint8)

    for j, pixels in enumerate(pixel_batch):
        image_index = i * batch_size + j

        # Convert the array to a PIL image
        image = Image.fromarray(pixels)

        # Save the image as a PNG file
        image_path = os.path.join(output_dir, f'monet_image_{image_index}.png')
        image.save(image_path)

# Create a zip file containing the generated images (stored flat, by filename)
zip_filename = 'generated_monet_images.zip'
with zipfile.ZipFile(zip_filename, 'w') as zip_file:
    for root, _, files in os.walk(output_dir):
        for file in files:
            file_path = os.path.join(root, file)
            zip_file.write(file_path, file)

print(f'Generated images are saved in {output_dir} and zipped in {zip_filename}')