<a href="https://colab.research.google.com/github/PozzaMarco/VCS_Pix2Pix_Implementation/blob/main/VCS_create_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## This notebook provides all the functionality needed to create a masked dataset for the pix2pix GAN.


# Imports

In [None]:
%%capture
!pip install tensorflow_addons

In [None]:
from tensorflow.keras import layers
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
from tensorflow import keras
import tensorflow as tf

from matplotlib import image as image_loader
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm_notebook

import numpy as np
import random
import cv2
import os

import torch
import torchvision

import shutil
import time

from PIL import Image, ImageOps

# Seed
Setting seeds for reproducibility.

In [None]:
# One fixed seed, handed to Keras' global seeding helper so that repeated
# runs of the notebook are reproducible.
SEED = 42
keras.utils.set_random_seed(SEED)

# Functions
Utility functions used throughout the project.

In [None]:
def delete_dataset(folder_num = 0):
  """Delete one of the generated dataset folders.

  Args:
    folder_num: 1 removes '/content/dataset', 2 removes
      '/content/joint_dataset'. Any other value deletes nothing and only
      prints the list of valid options.
  """
  import shutil

  # Guard clause: anything other than the two known options just shows the menu.
  if folder_num not in (1, 2):
    print("1: Dataset \n 2: Joint_dataset")
    return

  target = '/content/dataset' if folder_num == 1 else '/content/joint_dataset'
  shutil.rmtree(target)

def extract_label(encoded_label):
  """Turn an encoded label into a plain Python string.

  Args:
    encoded_label: any object exposing ``.numpy()``; presumably an eager
      tensor wrapping a byte string (confirm against the caller).

  Returns:
    The label text. Byte strings are decoded as UTF-8 (undecodable bytes are
    replaced rather than raising); any other value is converted with ``str``.
  """
  raw = encoded_label.numpy()

  if isinstance(raw, bytes):
    # Decode the bytes directly. Slicing the "b'...'" repr breaks as soon as
    # the label contains an apostrophe (the repr then uses double quotes) or
    # non-ASCII bytes (the repr shows escape sequences).
    return raw.decode("utf-8", errors="replace")

  return str(raw)

def create_dataset_folders(path):
  """Make sure ``path`` exists and contains ``train`` and ``val`` sub-folders.

  The call is idempotent: folders that already exist are left untouched.

  Args:
    path: root folder of the dataset; missing parent folders are created too.
  """
  for split in ("train", "val"):
    # makedirs creates `path` itself when needed, and exist_ok avoids the
    # separate "does it exist?" check (and the race that comes with it).
    os.makedirs(os.path.join(path, split), exist_ok=True)

def get_patches(image, image_dim, patch_size, mask_proprotions):
  """Resize an image to a square and split it into patches.

  Args:
    image: image array accepted by ``cv2.resize``.
    image_dim: side length, in pixels, of the square the image is resized to.
    patch_size: side length, in pixels, of each square patch.
    mask_proprotions: not used by this function; kept so existing callers
      keep working.

  Returns:
    A tuple ``(patches, resized_image, patch_layer)``: the output of the
    ``Patches`` layer for the single-image batch, the resized image, and the
    layer instance itself (needed later to rebuild an image from patches).
  """
  resized_image = cv2.resize(image, dsize=(image_dim, image_dim), interpolation=cv2.INTER_CUBIC)

  # The Patches layer is called on a batch, so add a leading batch axis of 1.
  batch_resized_image = np.expand_dims(resized_image, axis=0)

  patch_layer = Patches(patch_size=patch_size)
  patches = patch_layer(images=batch_resized_image)

  return patches, resized_image, patch_layer

def crop_resize_image(image, size):
  """Crop and resize a PIL image to a ``size`` x ``size`` square.

  Args:
    image: the PIL image to process.
    size: side length, in pixels, of the square output.

  Returns:
    A new PIL image of the requested size, produced by ``ImageOps.fit``.
  """
  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter under
  # its current name and is also available in older Pillow releases.
  return ImageOps.fit(image, (size, size), Image.LANCZOS)

def create_masked_image(image, patch_size, mask_extension):
  """Build a masked copy of ``image`` by blanking a random subset of patches.

  Args:
    image: image array; it is resized to 250x250 before being patched.
    patch_size: side length, in pixels, of each square patch.
    mask_extension: proportion of patches to mask, forwarded to PatchEncoder.

  Returns:
    A tuple ``(masked_image, resized_image)``: the masked image after a
    ``cv2.COLOR_BGR2RGB`` conversion, and the resized, unmasked image.
  """
  image_dim = 250
  projection_dim = 128

  patches, resized_image, patch_layer = get_patches(image, image_dim, patch_size, mask_extension)

  encoder = PatchEncoder(patch_size, projection_dim, mask_extension)
  # The encoder returns five values; only the last one (the indices of the
  # patches that stay visible) is needed to assemble the masked picture.
  *_, unmask_indices = encoder(patches=patches)

  visible_patches, _ = encoder.generate_masked_image(patches, unmask_indices)
  masked_img = patch_layer.reconstruct_from_patch(visible_patches).numpy()

  # The rebuilt image is smaller than the target when the patches do not tile
  # it exactly; pad it back to the full size in that case.
  if masked_img.shape != (image_dim, image_dim, 3):
    masked_img = pad_img(masked_img)

  return cv2.cvtColor(masked_img, cv2.COLOR_BGR2RGB), resized_image

def pad_img(input_image, target_size = 250):
  """Centre an image on a white square canvas.

  Args:
    input_image: array of shape (height, width, channels).
    target_size: side length, in pixels, of the square output canvas.

  Returns:
    A uint8 array of shape (target_size, target_size, channels) filled with
    255 (white) and holding ``input_image`` in its centre.

  Raises:
    ValueError: if ``input_image`` is larger than the canvas in either
      dimension.
  """
  old_image_height, old_image_width, channels = input_image.shape

  if old_image_height > target_size or old_image_width > target_size:
    raise ValueError(
        f"Image of size {old_image_height}x{old_image_width} does not fit in a "
        f"{target_size}x{target_size} canvas"
    )

  # White canvas; a scalar fill value works for any number of channels.
  result = np.full((target_size, target_size, channels), 255, dtype=np.uint8)

  # Offsets that centre the image (any odd leftover pixel goes to the
  # bottom/right edge).
  x_center = (target_size - old_image_width) // 2
  y_center = (target_size - old_image_height) // 2

  result[y_center:y_center+old_image_height,
         x_center:x_center+old_image_width] = input_image

  return result

def save_images(image, masked_image, filename):
  """Write an image and its masked counterpart to disk.

  Args:
    image: image array written as-is to ``filename``.
    masked_image: masked image array; it goes through a
      ``cv2.COLOR_RGB2BGR`` conversion before being written.
    filename: destination of the unmasked image. The masked image is saved
      next to it under the same name with "_masked.jpg" appended.
  """
  cv2.imwrite(filename, image)

  masked_for_disk = cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR)
  cv2.imwrite(filename+"_masked.jpg", masked_for_disk)

def create_masked_dataset(path, save_path, patch_size, mask_extension, test = False):
  """Create the original/masked image pairs for one dataset split.

  Every file in ``path`` is cropped and resized, masked, and written to
  ``save_path/<split>/`` together with its masked version, where ``<split>``
  is the last component of ``path`` (e.g. "train" or "val").

  Args:
    path: folder holding the source images of one split.
    save_path: root folder of the generated dataset; its "train" and "val"
      sub-folders are created when missing.
    patch_size: side length, in pixels, of each square patch.
    mask_extension: proportion of patches to mask.
    test: when true, only the first 5 images (in sorted order) are processed.
  """
  new_size = 200
  list_files = sorted(os.listdir(path))
  create_dataset_folders(save_path)

  # Use the last path component as the split name, so the function works for
  # any folder depth (and with or without a trailing slash).
  extension = os.path.basename(os.path.normpath(path))
  print(f"Processing {extension} images")

  if test:
    # Slicing copes with folders holding fewer than 5 files.
    list_files = list_files[:5]

  for img_name in tqdm_notebook(list_files):
    img_path = os.path.join(path, img_name)

    # Close the file handle as soon as the pixels are loaded, and force three
    # channels so grayscale or RGBA inputs do not break the 3-channel
    # patching code downstream.
    with Image.open(img_path) as source:
      preprocessed_image = crop_resize_image(source.convert("RGB"), new_size)

    # PIL yields RGB data; swapping the channel order gives the layout that
    # cv2.imwrite expects when save_images writes the file.
    image = cv2.cvtColor(np.array(preprocessed_image), cv2.COLOR_BGR2RGB)
    masked_image, resized_image = create_masked_image(image, patch_size, mask_extension)

    filename = os.path.join(save_path, extension, img_name)
    save_images(resized_image, masked_image, filename)
    
def join_images(path_to_dir):
  """Merge every original/masked pair into one side-by-side image.

  For each file ``<name>`` in ``path_to_dir`` that has a companion
  ``<name>_masked.jpg``, the two are concatenated horizontally (original on
  the left, masked on the right) and written to
  ``/content/joint_dataset/data/<split>/<name>``, where ``<split>`` is the
  last component of ``path_to_dir``.

  Args:
    path_to_dir: folder holding the files written by ``save_images``.
  """
  masked_suffix = "_masked.jpg"
  list_files = sorted(os.listdir(path_to_dir))
  available = set(list_files)

  # Last path component ("train"/"val"), independent of the folder depth.
  extension = os.path.basename(os.path.normpath(path_to_dir))
  save_path = "/content/joint_dataset/data/"
  create_dataset_folders(save_path)

  save_path += extension+"/"
  # create_dataset_folders only guarantees "train" and "val"; make sure the
  # actual split folder exists, otherwise cv2.imwrite cannot write into it.
  os.makedirs(save_path, exist_ok=True)
  print(save_path)

  # Pair files by name rather than by position in the sorted listing, so a
  # missing or extra file cannot shift every following pair.
  originals = [name for name in list_files if not name.endswith(masked_suffix)]

  for img_file in tqdm_notebook(originals):
    masked_img_file = img_file + masked_suffix
    if masked_img_file not in available:
      print(f"Skipping {img_file}: no masked counterpart found")
      continue

    with Image.open(os.path.join(path_to_dir, img_file)) as source:
      image = np.array(source)
    with Image.open(os.path.join(path_to_dir, masked_img_file)) as source:
      masked_image = np.array(source)

    full = np.concatenate((image, masked_image), axis = 1)

    cv2.imwrite(save_path+img_file, cv2.cvtColor(full, cv2.COLOR_RGB2BGR))

# Classes
Implementation of the classes behind the PatchEncoder, which makes it possible to create masked patches with different sizes and extensions.

In [None]:
class Patches(layers.Layer):
    """Keras layer that cuts images into non-overlapping square patches."""

    def __init__(self, patch_size, **kwargs):
        """Store the patch side length and prepare the flattening layer."""
        super().__init__(**kwargs)
        self.patch_size = patch_size

        # Three colour channels are assumed, so one flattened patch holds
        # patch_size * patch_size * 3 values.
        flat_patch_length = patch_size * patch_size * 3
        self.resize = layers.Reshape((-1, flat_patch_length))

    def call(self, images):
        """Return the patches of ``images`` as (batch, num_patches, patch_area)."""
        # Window and stride are identical, so patches never overlap.
        window = [1, self.patch_size, self.patch_size, 1]
        extracted = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        return self.resize(extracted)

    def show_patched_image(self, images, patches):
        """Plot one randomly chosen image of the batch next to its patches.

        Returns the chosen batch index so it can be checked by the caller.
        """
        idx = np.random.choice(patches.shape[0])
        print(f"Index selected: {idx}.")

        # First figure: the full image.
        plt.figure(figsize=(4, 4))
        plt.imshow(keras.utils.array_to_img(images[idx]))
        plt.axis("off")
        plt.show()

        # Second figure: the patches laid out on a square grid.
        grid_side = int(np.sqrt(patches.shape[1]))
        patch_shape = (self.patch_size, self.patch_size, 3)

        plt.figure(figsize=(4, 4))
        for position, flat_patch in enumerate(patches[idx], start=1):
            plt.subplot(grid_side, grid_side, position)
            patch_img = tf.reshape(flat_patch, patch_shape)
            plt.imshow(keras.utils.img_to_array(patch_img).astype(np.uint8))
            plt.axis("off")
        plt.show()

        return idx

    # taken from https://stackoverflow.com/a/58082878/10319735
    def reconstruct_from_patch(self, patch):
        """Reassemble the patches of a *single* image into that image."""
        num_patches = patch.shape[0]
        patches_per_side = int(np.sqrt(num_patches))

        tiles = tf.reshape(patch, (num_patches, self.patch_size, self.patch_size, 3))

        # Split the tiles into rows, glue each row's tiles side by side, then
        # stack the rows on top of each other.
        row_images = [
            tf.concat(tf.unstack(row_tiles), axis=1)
            for row_tiles in tf.split(tiles, patches_per_side, axis=0)
        ]
        return tf.concat(row_images, axis=0)

class PatchEncoder(layers.Layer):
    """Keras layer that embeds image patches and randomly masks some of them.

    Besides the embeddings produced by ``call``, the layer offers
    ``generate_masked_image`` to build a copy of an image's patches in which
    every masked patch is zeroed out.
    """

    def __init__(
        self,
        patch_size,
        projection_dim,
        mask_proportion,
        downstream=False,
        **kwargs,
    ):
        """Store the configuration and create the mask token.

        Args:
            patch_size: side length, in pixels, of each square patch.
            projection_dim: size of the patch and position embeddings.
            mask_proportion: fraction of the patches to mask.
            downstream: when True, ``call`` returns the embeddings of all
                patches and performs no masking.
            **kwargs: forwarded to ``layers.Layer``.
        """
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.mask_proportion = mask_proportion
        self.downstream = downstream

        # This is a trainable mask token initialized randomly from a normal
        # distribution. It has the length of one flattened 3-channel patch.
        self.mask_token = tf.Variable(
            tf.random.normal([1, patch_size * patch_size * 3]), trainable=True
        )

    def build(self, input_shape):
        """Create the sub-layers once the patch tensor shape is known.

        Args:
            input_shape: unpacked as (batch, num_patches, patch_area).
        """
        (_, self.num_patches, self.patch_area) = input_shape

        # Create the projection layer for the patches.
        self.projection = layers.Dense(units=self.projection_dim)

        # Create the positional embedding layer.
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches, output_dim=self.projection_dim
        )

        # Number of patches that will be masked (the product is truncated
        # towards zero by int()).
        self.num_mask = int(self.mask_proportion * self.num_patches)

    def call(self, patches):
        """Embed the patches and, unless ``downstream`` is set, mask some.

        Args:
            patches: patch tensor; its first dimension is used as the batch
                size.

        Returns:
            With ``downstream`` set: the patch embeddings
            (B, num_patches, projection_dim). Otherwise the tuple
            ``(unmasked_embeddings, masked_embeddings, unmasked_positions,
            mask_indices, unmask_indices)``.
        """
        # Get the positional embeddings.
        batch_size = tf.shape(patches)[0]
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        pos_embeddings = self.position_embedding(positions[tf.newaxis, ...])
        pos_embeddings = tf.tile(
            pos_embeddings, [batch_size, 1, 1]
        )  # (B, num_patches, projection_dim)

        # Embed the patches.
        patch_embeddings = (
            self.projection(patches) + pos_embeddings
        )  # (B, num_patches, projection_dim)

        if self.downstream:
            return patch_embeddings
        else:
            mask_indices, unmask_indices = self.get_random_indices(batch_size)
            # The encoder input is the unmasked patch embeddings. Here we gather
            # all the patches that should be unmasked.
            unmasked_embeddings = tf.gather(
                patch_embeddings, unmask_indices, axis=1, batch_dims=1
            )  # (B, unmask_numbers, projection_dim)

            # Get the unmasked and masked position embeddings. We will need them
            # for the decoder.
            unmasked_positions = tf.gather(
                pos_embeddings, unmask_indices, axis=1, batch_dims=1
            )  # (B, unmask_numbers, projection_dim)
            masked_positions = tf.gather(
                pos_embeddings, mask_indices, axis=1, batch_dims=1
            )  # (B, mask_numbers, projection_dim)

            # Repeat the mask token number of mask times.
            # Mask tokens replace the masks of the image.
            mask_tokens = tf.repeat(self.mask_token, repeats=self.num_mask, axis=0)
            mask_tokens = tf.repeat(
                mask_tokens[tf.newaxis, ...], repeats=batch_size, axis=0
            )

            # Get the masked embeddings for the tokens.
            masked_embeddings = self.projection(mask_tokens) + masked_positions
            return (
                unmasked_embeddings,  # Input to the encoder.
                masked_embeddings,  # First part of input to the decoder.
                unmasked_positions,  # Added to the encoder outputs.
                mask_indices,  # The indices that were masked.
                unmask_indices,  # The indices that were unmasked.
            )

    def get_random_indices(self, batch_size):
        """Randomly split the patch indices into masked and unmasked sets.

        Args:
            batch_size: number of rows of indices to draw.

        Returns:
            ``(mask_indices, unmask_indices)``: per row, the first
            ``num_mask`` shuffled indices and the remaining ones.
        """
        # Create random indices from a uniform distribution and then split
        # it into mask and unmask indices. Arg-sorting uniform noise yields a
        # random ordering of the patch indices for every row.
        rand_indices = tf.argsort(
            tf.random.uniform(shape=(batch_size, self.num_patches)), axis=-1
        )
        mask_indices = rand_indices[:, : self.num_mask]
        unmask_indices = rand_indices[:, self.num_mask :]
        return mask_indices, unmask_indices

    def generate_masked_image(self, patches, unmask_indices):
        """Return the patches of one batch element with masked ones zeroed.

        Args:
            patches: patch tensor indexed by batch element.
            unmask_indices: per batch element, the indices of the patches
                that stay visible.

        Returns:
            ``(new_patch, idx)``: an array shaped like the chosen element's
            patches in which only the unmasked patches are filled in, and the
            index of the batch element that was chosen at random.
        """
        # Choose a random patch and its corresponding unmask index.
        idx = np.random.choice(patches.shape[0])
        patch = patches[idx]
        unmask_index = unmask_indices[idx]

        # Build a numpy array of same shape as patch.
        new_patch = np.zeros_like(patch)

        # Iterate over the unmasked indices and plug the unmasked patches into
        # new_patch; masked positions keep their zeros.
        # NOTE(review): `count` is assigned but never read.
        count = 0
        for i in range(unmask_index.shape[0]):
            new_patch[unmask_index[i]] = patch[unmask_index[i]]
        return new_patch, idx


# Create masked images
This cell loads the training and validation sets and uses two hyperparameters to set the patch size and the mask extension (the fraction of patches that are masked).
Set the variable "test" to True to process only 5 images, which is a quick way to check that the patch size and the extension look as intended.

In [None]:
# Source folders with the original images of each split, and the folder that
# receives the original/masked pairs.
path_train = "/content/drive/MyDrive/VCS_datasets/extended_dataset/cub200/data/train"
path_val = "/content/drive/MyDrive/VCS_datasets/extended_dataset/cub200/data/val"
save_path = "/content/dataset/data"

# Hyperparameters: patch side length in pixels and fraction of patches to mask.
patch_size = 35
mask_extension = 0.15
# Set to True to process only 5 images per split as a quick visual check.
test = False

create_masked_dataset(path_train, save_path, patch_size, mask_extension, test)
create_masked_dataset(path_val, save_path, patch_size, mask_extension, test)

# Compact images
For each pair of images (original and patched) a new image of double width is created, with the original image on the left and the patched image on the right.
Each pair of images therefore becomes a single image with the two placed side by side.

In [None]:
# Folders holding the original/masked pairs of each split.
path_masked_train = "/content/dataset/data/train"
path_masked_val = "/content/dataset/data/val"

# Merge the pairs of the training split first, then those of the validation split.
for masked_dir in (path_masked_train, path_masked_val):
  join_images(masked_dir)

# Make Zip for save
Make a zip file of the newly created images for easy download.

In [None]:
#Make dataset to zip to download it
!zip -r /content/caltech_birds_masked_35_15_joint.zip /content/joint_dataset/

# Delete datasets
After downloading the images, reset the workspace by deleting the folders containing the datasets.
This allows a fresh restart of the whole masked-image generation process.

## Delete dataset

In [None]:
# Option 1 removes '/content/dataset', the folder with the original/masked pairs.
delete_dataset(1)

## Delete joint dataset

In [None]:
# Option 2 removes '/content/joint_dataset', the folder with the side-by-side images.
delete_dataset(2)