<a href="https://colab.research.google.com/github/anupamav/NVS-GAN/blob/main/NVS_GAN_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorflow-addons

In [None]:
import tensorflow as tf
import numpy as np
import os
import pathlib
from PIL import Image
import random
import tensorflow_addons as tfa
import csv

from tensorflow.keras.layers import Input, DepthwiseConv2D, Conv2D,Concatenate, Activation,Conv2DTranspose,Flatten, add,  concatenate
from tensorflow.keras.layers import Reshape, Conv2DTranspose, BatchNormalization, UpSampling2D, Add, Layer, SeparableConv2D
from tensorflow.keras.layers import Dense, Input, ReLU, Lambda, LeakyReLU, ELU
from tensorflow.keras import datasets, layers
from tensorflow.keras.models import Model
from tensorflow.keras.losses import MeanSquaredError

from matplotlib import pyplot as plt
from IPython import display

In [None]:
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
data = np.load('/content/drive/MyDrive/train_car_256.npy')

In [None]:
# Unpack the dataset dimensions; the last two axes are ignored here
# (presumably image width and channels, given the square-image assumption below — TODO confirm)
n_models, n_angles, image_size, _, _ = data.shape

# Define training parameters
batch_size = 18  # Number of samples in each training batch
buffer_size = 1000 # Buffer size passed to Dataset.shuffle
steps_per_epoch = (n_models * n_angles * n_angles) // batch_size  # Number of batches needed to cover every (model, source, target) combination once
num_iterations = 1000 # Number of training iterations; the training loop consumes one batch per iteration

# - `n_models`: The number of different models in the dataset.
# - `n_angles`: The number of angles for each model.
# - `image_size`: The size of the images (assuming square images).
# - `batch_size`: The number of samples in each training batch. Adjust based on memory constraints.
# - `steps_per_epoch`: The number of steps to complete one pass over the data. It's calculated by dividing the total number of possible combinations by the batch size.
# - `num_iterations`: The number of training iterations (batches) to run. This is NOT a number of epochs.

In [None]:
output_folder = '/content/drive/MyDrive/NVS_GAN/V1'
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Directory where the training logs (TensorBoard summaries) are written
log_dir = os.path.join(output_folder, 'logs')
# Directory where the trained generator model is saved
generator_model_path = os.path.join(output_folder, 'generator_model')
# Directory where the test images are saved
image_path = os.path.join(output_folder, 'image_test')

In [None]:
# Create one-hot encoded angles
one_hot_angles = tf.one_hot(np.arange(n_angles), depth=n_angles, dtype=tf.float32)

In [None]:
# Define generator function
def data_generator():
    """
    Endlessly yield every (source view, target view) pairing of every model.

    Models are visited in order; for each model every source angle is paired
    with every target angle (including itself). Once all combinations have
    been produced the iteration restarts from the first model, so this
    generator never terminates on its own.

    Yields:
        tuple: ``(source_image, target_image, transformation_one_hot,
        source_angle_one_hot, target_angle_one_hot)``. The transformation
        encodes ``(target_idx - source_idx) mod n_angles``.
    """
    while True:
        for model_idx in range(n_models):
            # All views of the current model, indexed by angle
            views = data[model_idx]

            for src_idx in range(n_angles):
                src_image = views[src_idx]
                src_one_hot = one_hot_angles[src_idx]

                for dst_idx in range(n_angles):
                    # Relative rotation from source to target, wrapped into [0, n_angles)
                    rotation_idx = (dst_idx - src_idx) % n_angles

                    yield (
                        src_image,
                        views[dst_idx],
                        one_hot_angles[rotation_idx],
                        src_one_hot,
                        one_hot_angles[dst_idx],
                    )

In [None]:
# Element dtypes and per-sample shapes produced by data_generator, in yield order:
# source image, target image, transformation one-hot, source one-hot, target one-hot
output_types = (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32)
output_shapes = (
    tf.TensorShape((image_size, image_size, 3)),
    tf.TensorShape((image_size, image_size, 3)),
    tf.TensorShape((n_angles,)),
    tf.TensorShape((n_angles,)),
    tf.TensorShape((n_angles,))
)

# Create the dataset using the defined output types and shapes
dataset = tf.data.Dataset.from_generator(data_generator, output_types=output_types, output_shapes=output_shapes)

# Shuffle individual samples BEFORE batching. The generator emits samples in
# nested model/source/target order, so batching first would keep consecutive
# samples glued together in every batch, and the shuffle buffer would have to
# hold `buffer_size` whole batches instead of `buffer_size` samples.
dataset = dataset.shuffle(buffer_size=buffer_size)

# Batch the shuffled samples
dataset = dataset.batch(batch_size)

# Overlap data preparation with training
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [None]:
# Function to convert one-hot encoding to angle
def one_hot_to_angle(one_hot_encoding):
    """
    Translate a one-hot angle encoding into degrees.

    The encoding is treated as an even split of the full circle into
    ``len(one_hot_encoding)`` bins; the first position holding a 1 selects
    the bin.

    Args:
        one_hot_encoding (numpy.ndarray): One-hot encoded array with 1 at the index corresponding to the angle.

    Returns:
        float or None: ``index * 360 / len(one_hot_encoding)`` for the first
        index equal to 1, or None when no entry equals 1.
    """
    hot_positions = np.where(one_hot_encoding == 1)[0]

    # Nothing is set: there is no angle to report
    if hot_positions.size == 0:
        return None

    degrees_per_bin = 360.0 / len(one_hot_encoding)
    return hot_positions[0] * degrees_per_bin

In [None]:
def display_images_with_angles(source_image, target_image, source_angle_one_hot, target_angle_one_hot, predicted_image=None, test=False):
    """
    Plot the source and target views side by side, titled with their angles.

    When ``test`` is true a third panel shows ``predicted_image``, labelled
    with the target angle.

    Args:
        source_image (numpy.ndarray): The image of the source angle.
        target_image (numpy.ndarray): The image of the target angle.
        source_angle_one_hot (numpy.ndarray): One-hot encoded source angle.
        target_angle_one_hot (numpy.ndarray): One-hot encoded target angle.
        predicted_image (numpy.ndarray, optional): Model prediction; only used when ``test`` is true.
        test (bool): Whether to add the predicted-image panel.

    Returns:
        None
    """
    source_angle = one_hot_to_angle(source_angle_one_hot)
    target_angle = one_hot_to_angle(target_angle_one_hot)

    # (label, image, angle) for each panel, left to right
    panels = [
        ('Source', source_image, source_angle),
        ('Target', target_image, target_angle),
    ]
    if test:
        panels.append(('Predicted', predicted_image, target_angle))

    # Each panel is 4 inches wide: (8, 4) for two panels, (12, 4) for three
    fig, axs = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4))
    for axis, (label, image, angle) in zip(axs, panels):
        axis.imshow(image)
        axis.set_title(f'{label} Image (Angle: {angle:.2f}°)')

    plt.tight_layout()
    plt.show()

In [None]:
def save_images(source_image, target_image, predicted_image, source_encoding, target_encoding, save_dir):
    """
    Save source, target, and predicted images as PNG files named after their angles.

    Files are written as ``source_<angle>.png``, ``target_<angle>.png`` and
    ``predicted_<angle>.png`` (the predicted file uses the target angle).
    Because the names depend only on the angles, a later call with the same
    angles overwrites the earlier files.

    Args:
        source_image (np.ndarray): Source image to be saved.
        target_image (np.ndarray): Target image to be saved.
        predicted_image (np.ndarray): Predicted image to be saved.
        source_encoding (np.ndarray): One-hot encoding for the source angle.
        target_encoding (np.ndarray): One-hot encoding for the target angle.
        save_dir (str): Directory to save the images (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)

    # Convert one-hot encoded angles to angle values
    source_angle = one_hot_to_angle(source_encoding)
    target_angle = one_hot_to_angle(target_encoding)

    outputs = (
        (f"source_{source_angle:.2f}.png", source_image),
        (f"target_{target_angle:.2f}.png", target_image),
        (f"predicted_{target_angle:.2f}.png", predicted_image),
    )

    for filename, image in outputs:
        # Images are scaled by 255 on the assumption they are in [0, 1].
        # Clip before the uint8 cast so slightly out-of-range values saturate
        # instead of wrapping around.
        pixels = np.clip(np.asarray(image) * 255, 0, 255).astype(np.uint8)
        Image.fromarray(pixels).save(os.path.join(save_dir, filename))


In [None]:
# Disabled sanity-check cell: the code below is wrapped in a bare string literal,
# so it is evaluated as a string and never executed.
# NOTE(review): if this is re-enabled, random.randint(0, batch_size) is inclusive
# of batch_size and can therefore pick an index one past the end of the batch.
'''
save_dir = os.path.join(output_folder, 'sample')

i = random.randint(0, (batch_size))

# Fetch and display a batch of data from the dataset
source_images, target_images, transformation_one_hot, source_one_hot_angles, target_one_hot_angles = next(iter(dataset))

display_images_with_angles(source_images[i], target_images[i], source_one_hot_angles[i], target_one_hot_angles[i])
#save_images(source_images[i], target_images[i], target_images[i], source_one_hot_angles[i], target_one_hot_angles[i], save_dir)
'''

In [None]:
# Define a custom layer for bilinear sampling
class BilinearSamplingLayer(Layer):
    """Warp an image with a predicted flow field via ``tfa.image.dense_image_warp``.

    The layer takes ``[original_image, predicted_flow]`` and multiplies the
    flow by ``image_size`` before warping, so the flow is expected in units of
    "fraction of the image size" rather than pixels.

    Args:
        image_size: Scale factor applied to the predicted flow.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, image_size, **kwargs):
        # Initialise the base Layer first so Keras attribute tracking is set up
        # before any attribute is assigned.
        super().__init__(**kwargs)
        self.image_size = image_size

    def call(self, tensors):
        # Unpack the input tensors: original image and predicted flow
        original_image, predicted_flow = tensors

        # Apply dense image warp with predicted flow scaled by image size
        return tfa.image.dense_image_warp(original_image, predicted_flow * self.image_size)

    def compute_output_shape(self, tensor):
        # The warped image has the same shape as the first input (the image);
        # the batch dimension is reported as unknown.
        input_shape = tensor[0]
        return None, input_shape[1], input_shape[2], input_shape[3]

    def get_config(self):
        # Include the constructor argument so the layer can be re-created from
        # its config (model cloning / Keras-format saving).
        config = super().get_config()
        config.update({'image_size': self.image_size})
        return config

In [None]:
def get_modified_decoder_layer(x_d0, x_e, current_attention_strategy, current_image_size, pred_flow=None):
    """
    Combine a decoder feature map with its encoder counterpart.

    Args:
        x_d0: Decoder features at the current resolution.
        x_e: Encoder features at the same resolution.
        current_attention_strategy: ``'u_net'`` concatenates encoder and
            decoder features; any other value (the comparison is an exact
            string match) leaves the decoder features untouched.
        current_image_size: Unused by this function.
        pred_flow: Unused by this function.

    Returns:
        tuple: ``(x_e_rearranged, x_d)`` — the encoder features that were
        merged in (``None`` when no skip connection is applied) and the
        resulting decoder features.
    """
    # (1) U-Net: skip connection by channel concatenation
    if current_attention_strategy == 'u_net':
        return x_e, Concatenate()([x_e, x_d0])

    # (0) Vanilla: no skip connection
    return None, x_d0

In [None]:
def pixel_normalizer(x):
    """Shift and scale pixel values so that 0 maps to -1 and 1 maps to 1."""
    return (x - 0.5) * 2


def pixel_normalizer_reverse(x):
    """Inverse of ``pixel_normalizer``: -1 maps to 0 and 1 maps to 1."""
    return x / 2 + 0.5


# Per-resolution feature maps recorded while the generator is being built,
# keyed by spatial size.
decoder_original_features = dict()
encoder_original_features = dict()
decoder_rearranged_features = dict()

In [None]:
def movnetv1():
        """Build the flow-predicting generator as a Keras functional model.

        Structure, as built below:
          * an encoder of (separable) convolutions that halves the resolution
            until the feature map is 2x2, recording each stage's output in the
            module-level ``encoder_original_features`` dict;
          * dense layers that fuse the flattened image code with an embedding
            of the one-hot viewpoint input;
          * a decoder of stride-2 ``Conv2DTranspose`` layers up to half the
            input resolution, followed by a final stride-2 transpose conv
            that outputs a 2-channel flow at full resolution;
          * a ``BilinearSamplingLayer`` that warps the input image with that
            flow to produce the predicted image.

        Reads the module-level ``image_size`` and ``n_angles`` and writes to
        the module-level ``encoder_original_features``,
        ``decoder_original_features`` and ``decoder_rearranged_features``.

        Returns:
            Model: inputs ``[image_input, viewpoint_input]``, output the
            warped image.
        """
        # Build Keras model. Tried to follow the original paper as much as possible.
        activation = 'relu'
        current_image_size = image_size
        image_input = Input(shape=(current_image_size, current_image_size, 3), name='image_input')
        image_input_normalized = Lambda(pixel_normalizer)(image_input)

        i = 0
        x = image_input_normalized #image_input_normalized
        x = Conv2D(8, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
        x = ReLU()(x)
        x = SeparableConv2D(16, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)
        # `i` is the exponent used for the channel count of the next stage
        i = 1
        current_image_size = int(image_size / 2)
        encoder_original_features[current_image_size] = x

        # Encoder: each pass halves the resolution and doubles the channel count
        while current_image_size > 2 :
          x = SeparableConv2D(16 * (2 ** i), kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
          x = BatchNormalization()(x)
          x = ReLU()(x)
          x = SeparableConv2D(16 * (2 ** i), kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
          x = BatchNormalization()(x)
          x = ReLU()(x)
          i = i+1
          current_image_size = int(current_image_size / 2)
          # At 8x8 resolution, add four extra 256-channel separable conv blocks
          if(current_image_size == 8):
            for repeat in range (4):
              x = SeparableConv2D(256, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
              x = BatchNormalization()(x)
              x = ReLU()(x)
              # NOTE(review): this increment has no effect; the for loop rebinds `repeat` each pass
              repeat = repeat+1
          encoder_original_features[current_image_size] = x

        x = Flatten()(x)
        # NOTE(review): the Reshape((2, 2, 1024)) below needs exactly 4096 units,
        # which this formula only yields for image_size == 256.
        hidden_layer_size = int(4096 / 256 * image_size)
        x = Dense(hidden_layer_size, activation=activation)(x)
        #x = Dense(hidden_layer_size, activation=activation)(x)

        viewpoint_input = Input(shape=(n_angles, ), name='viewpoint_input')

        # Embed the one-hot viewpoint before fusing it with the image code
        v = Dense(128, activation=activation)(viewpoint_input)
        v = Dense(256, activation=activation)(v)

        concatenated = concatenate([x, v])
        concatenated = Dense(hidden_layer_size, activation=activation)(concatenated)
        #concatenated = Dense(hidden_layer_size, activation=activation)(concatenated)

        d = Reshape((2, 2, 1024))(concatenated)
        #d = SeparableConv2D(1024, kernel_size=(3, 3), strides=(1, 1), padding='same')(d)
        d = ReLU()(d)
        # Decoder: each pass doubles the resolution, stopping at half the input size
        while current_image_size < image_size / 2 :
          current_image_size = current_image_size * 2
          # attention strategy at this layer.
          # NOTE(review): get_modified_decoder_layer only recognises 'u_net', so
          # 'unet' falls through to its vanilla branch and NO skip connection is
          # applied. Confirm whether that is intended before changing it, since
          # it alters the architecture.
          current_attention_strategy = 'unet'
          d = Conv2DTranspose(4 * (2 ** i), kernel_size=(3, 3), strides=(2, 2), padding='same')(d)
          d = ReLU()(d)
          #d = SeparableConv2D(4 * (2 ** i), kernel_size=(3, 3), strides=(1, 1), padding='same')(d)
          #d = ReLU()(d)
          i = i-1

          x_d0 = d
          x_e = encoder_original_features[current_image_size]
          x_e_rearranged, x_d = get_modified_decoder_layer(x_d0, x_e, current_attention_strategy, current_image_size)
          decoder_original_features[current_image_size] = x_d0
          decoder_rearranged_features[current_image_size] = x_e_rearranged
          d = x_d

        #d = SeparableConv2D(8, kernel_size=(3, 3), strides=(1, 1), padding='same')(d)
        #d = ReLU()(d)
        '''
        d = Conv2DTranspose(16, kernel_size=(3, 3), strides=(1, 1), padding='same')(d)
        d = ReLU()(d)
        d = Conv2DTranspose(8, kernel_size=(3, 3), strides=(1, 1), padding='same')(d)
        d = ReLU()(d)
        '''
        # final flow: 2 channels, upsampled by the last stride-2 transpose conv
        pred_flow = Conv2DTranspose(2, kernel_size=(3, 3), strides=(2, 2), padding='same')(d)

        # fetch pixels from original image
        pred_image = BilinearSamplingLayer(image_size)([image_input, pred_flow])
        return Model(inputs=[image_input, viewpoint_input], outputs=[pred_image])

In [None]:
generator = movnetv1()
#generator.summary()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [None]:
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
# Load an ImageNet-pretrained VGG16 (convolutional base only, no classifier head) for feature extraction
vgg16 = tf.keras.applications.VGG16(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))
vgg16.trainable = False # Freeze the VGG16 model

# Use the output of every VGG16 layer whose name contains 'conv' as a perceptual feature
# (the commented-out line below is an alternative that uses only two layers)
#selected_layers = [vgg16.get_layer('block3_conv3').output, vgg16.get_layer('block4_conv3').output]
selected_layers = [layer.output for layer in vgg16.layers if 'conv' in layer.name]

# Model mapping an image batch to the list of selected feature maps
# NOTE(review): callers pass images without tf.keras.applications.vgg16.preprocess_input — confirm that is intended
feature_extractor = Model(inputs=vgg16.input, outputs=selected_layers)

# Mean squared error, used to compare feature maps in the perceptual loss
mse_loss = MeanSquaredError()

def generator_loss(disc_generated_output, gen_output, target):
    """
    Compute the generator objective and its individual components.

    The total is a weighted sum of an adversarial term, mean absolute error,
    an SSIM dissimilarity term and a VGG16 perceptual term.

    Args:
        disc_generated_output: Discriminator output for the generated images.
        gen_output: Images produced by the generator.
        target: Ground-truth target images.

    Returns:
        tuple: ``(total_loss, gan_loss, mae, ssim, perceptual_loss)``, each a
        scalar tensor.
    """
    # Weights for each loss component
    gan_weight = 0.01
    mae_weight = 1
    ssim_weight = 1
    perceptual_weight = 1

    # Adversarial loss: the generator wants its output classified as real
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    # MAE Loss
    mae = tf.reduce_mean(tf.abs(target - gen_output))

    # SSIM Loss. tf.image.ssim returns one value per image, so average over the
    # batch; otherwise the total loss becomes a per-image vector and
    # GradientTape.gradient implicitly sums it, scaling the gradient with the
    # batch size.
    ssim = tf.reduce_mean(1 - tf.image.ssim(target, gen_output, max_val=1.0))

    # Perceptual Loss: MSE between VGG16 feature maps, summed over the selected layers
    # NOTE(review): images are passed to VGG16 without preprocess_input — confirm that is intended
    real_features = feature_extractor(target)
    generated_features = feature_extractor(gen_output)
    perceptual_loss = 0.0
    for real_feat, gen_feat in zip(real_features, generated_features):
        perceptual_loss += mse_loss(real_feat, gen_feat)

    # Combine the losses using the defined weights
    total_loss = gan_weight * gan_loss + mae_weight * mae + ssim_weight * ssim + perceptual_weight * perceptual_loss

    return total_loss, gan_loss, mae, ssim, perceptual_loss


In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """
    Build a stride-2 convolution block: Conv2D -> [BatchNormalization] -> LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_batchnorm: Whether to insert BatchNormalization after the convolution.

    Returns:
        tf.keras.Sequential: The assembled block.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)

    stack = [
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=kernel_init, use_bias=False),
    ]
    if apply_batchnorm:
        stack.append(tf.keras.layers.BatchNormalization())
    stack.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(stack)

In [None]:
def Discriminator(input_shape=(256, 256, 3)):
    """
    Build the conditional discriminator.

    It receives a source image and a candidate target image, concatenates them
    along the channel axis, applies six stride-2 ``downsample`` blocks and
    ends with a single-unit dense layer.

    The final layer has no activation and outputs a raw logit, because the
    losses in this file use ``BinaryCrossentropy(from_logits=True)``. A ReLU
    there would clamp the logit to >= 0, so the discriminator could never
    assign a probability below 0.5.

    Args:
        input_shape: Shape of each input image, without the batch dimension.
            Defaults to ``(256, 256, 3)``.

    Returns:
        tf.keras.Model: inputs ``[input_image, target_image]``, output one
        logit per sample.
    """
    inp = tf.keras.layers.Input(shape=input_shape, name='input_image')
    tar = tf.keras.layers.Input(shape=input_shape, name='target_image')

    # Shape comments below are for the default 256x256x3 input
    x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 256, 256, channels*2)

    down1 = downsample(16, 4, False)(x)  # (batch_size, 128, 128, 16)
    down2 = downsample(32, 4)(down1)  # (batch_size, 64, 64, 32)
    down3 = downsample(64, 4)(down2)  # (batch_size, 32, 32, 64)
    down4 = downsample(128, 4)(down3)  # (batch_size, 16, 16, 128)
    down5 = downsample(256, 4)(down4)  # (batch_size, 8, 8, 256)
    down6 = downsample(512, 4)(down5)  # (batch_size, 4, 4, 512)
    flat = Flatten()(down6)
    last = Dense(1)(flat)  # raw logit, no activation
    return tf.keras.Model(inputs=[inp, tar], outputs=last)

In [None]:
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """
    Sum of the discriminator's losses on real and generated pairs.

    Real outputs are scored against all-ones labels and generated outputs
    against all-zeros labels, both with the module-level ``loss_object``.

    Args:
        disc_real_output: Discriminator output for real target images.
        disc_generated_output: Discriminator output for generated images.

    Returns:
        The combined discriminator loss.
    """
    real_labels = tf.ones_like(disc_real_output)
    fake_labels = tf.zeros_like(disc_generated_output)

    loss_on_real = loss_object(real_labels, disc_real_output)
    loss_on_fake = loss_object(fake_labels, disc_generated_output)

    return loss_on_real + loss_on_fake

In [None]:
generator_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)

In [None]:
# Create a SummaryWriter for TensorBoard
summary_writer = tf.summary.create_file_writer(log_dir)

In [None]:
# Training loop
# Create the dataset iterator ONCE. Calling iter(dataset) inside the loop would
# build a fresh iterator every step, restarting the generator from the first
# model and refilling the shuffle buffer each time, so training would be slow
# and would only ever see the beginning of the data.
dataset_iterator = iter(dataset)

for iteration in range(num_iterations):
    (source_images_batch, target_images_batch, transformation_batch,
     source_angle_batch, target_angle_batch) = next(dataset_iterator)

    # Separate tapes so generator and discriminator gradients are taken independently
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator([source_images_batch, transformation_batch], training=True)
        real_output = discriminator([source_images_batch, target_images_batch], training=True)
        fake_output = discriminator([source_images_batch, gen_output], training=True)

        gen_total_loss, gen_gan_loss, gen_mae_loss, gen_ssim_loss, gen_perceptual_loss = generator_loss(fake_output, gen_output, target_images_batch)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Compute the average loss per batch (used for logging only)
    avg_gen_total_loss = tf.reduce_mean(gen_total_loss)
    avg_gen_gan_loss = tf.reduce_mean(gen_gan_loss)
    avg_gen_mae_loss = tf.reduce_mean(gen_mae_loss)
    avg_gen_ssim_loss = tf.reduce_mean(gen_ssim_loss)
    avg_gen_perceptual_loss = tf.reduce_mean(gen_perceptual_loss)
    avg_disc_loss = tf.reduce_mean(disc_loss)

    gradients_of_generator = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    # Log the losses using summary_writer
    with summary_writer.as_default():
        tf.summary.scalar("Generator Total Loss", avg_gen_total_loss, step=iteration+1)
        tf.summary.scalar("Generator GAN Loss", avg_gen_gan_loss, step=iteration+1)
        tf.summary.scalar("Generator MAE Loss", avg_gen_mae_loss, step=iteration+1)
        tf.summary.scalar("Generator SSIM Loss", avg_gen_ssim_loss, step=iteration+1)
        tf.summary.scalar("Generator Perceptual  Loss", avg_gen_perceptual_loss, step=iteration+1)
        tf.summary.scalar("Discriminator Loss", avg_disc_loss, step=iteration+1)

    # Print the losses
    print('============================================================')
    print(f'Iteration {iteration + 1}')
    print(f'Total Loss: {float(avg_gen_total_loss.numpy()):.4f}')
    print(f'Generator Loss: {float(avg_gen_gan_loss.numpy()):.4f}')
    print(f'MAE: {float(avg_gen_mae_loss.numpy()):.4f}')
    print(f'SSIM: {float(avg_gen_ssim_loss.numpy()):.4f}')
    print(f'Perceptual: {float(avg_gen_perceptual_loss.numpy()):.4f}')
    print(f'Discriminator Loss: {float(avg_disc_loss.numpy()):.4f}')

In [None]:
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/NVS_GAN/V1/logs

In [None]:
# Save the generator model in TensorFlow SavedModel format
generator_model_path = os.path.join(output_folder, 'generator_model')
tf.saved_model.save(generator, generator_model_path)

# Close the summary_writer
summary_writer.close()

In [None]:
# Fetch and display a batch of data from the dataset
source_images, target_images, transformation_one_hot, source_one_hot_angles, target_one_hot_angles = next(iter(dataset))

In [None]:
# Load the generator model
loaded_generator = tf.saved_model.load(generator_model_path)

In [None]:
pred_images = loaded_generator([source_images, transformation_one_hot])

In [None]:
# Pick a random sample from the batch. random.randint includes BOTH endpoints,
# so the upper bound must be batch_size - 1; using batch_size could select an
# index one past the end of the batch.
i = random.randint(0, batch_size - 1)
print(i)
display_images_with_angles(source_images[i], target_images[i], source_one_hot_angles[i], target_one_hot_angles[i], pred_images[i], test=True)
save_images(source_images[i], target_images[i], pred_images[i], source_one_hot_angles[i], target_one_hot_angles[i], image_path)