This is a modified version of the Pix2Pix GAN TensorFlow tutorial notebook from the link below:

https://www.tensorflow.org/tutorials/generative/pix2pix

# 1) Import TensorFlow and other libraries

## 1.1) For file-handling operations (creating folders, finding images etc.)

In [None]:
# For File handling
import os

# To list all files in a folder given a path
from glob import glob

#Other system operations
import sys

## 1.2) Tkinter uses interactive windows to select folder paths

In [None]:
import tkinter
from tkinter import filedialog


#To discard a blank tkinter window that opens when the library is imported
tkinter.Tk().withdraw()

## 1.3) Deep Learning and basic machine learning libraries

In [None]:
# Deep learning library - Tensorflow
import tensorflow as tf

#Numpy for mathematical operations
import numpy as np

#Scikit learn for train test split
from sklearn.model_selection import train_test_split

## 1.4) Libraries for image processing operations

In [None]:
#matplotlib pyplot to plot/show images
from matplotlib import pyplot as plt

# Opencv for all image processing operations
import cv2

# 2) Dataset Path Selection

## 2.1) Dialog box opens asking for the path of the dataset:

Path should contain two folders (Case-sensitive):

### Images (H&E images without CD3+ data - Input Image)


### Masks (H&E images with CD3+ data - Reference/Ground Truth Image)

(These names can be changed in the first cell in step 3.1)

In [None]:
path=filedialog.askdirectory(title='path for whole dataset')

## 2.2) Dataset Identification

The following identifiers, along with the time of training, help tell apart the weights saved from different training runs

(These names can be changed)

In [None]:
#Date of dataset creation
date="Oct15_Same"

#Variation of dataset
mask="Dataset1"

# 3) Dataset Loading

## 3.1) Loading and splitting dataset into training, test and validation

In [None]:
# A python function to load data

# Uses scikit-learn train-test split function

# E.g: train_test_split(train_x, test_size=test_size, random_state=42), Random state kept constant for the randomisation
# to be consistent whenever the same dataset is used

def load_data(path, split=0.02): # Can change split percentage
    
    images = sorted(glob(os.path.join(path, "Images/*")))  #Imports H&E Images with no CD3+ data (Input image)
    masks = sorted(glob(os.path.join(path, "Masks/*"))) # Imports H&E images with CD3+ data (Reference/Ground Truth Image)

    total_size = len(images) # All images
    valid_size = int(split * total_size) # Validation size
    test_size = int(split * total_size) # Testing size

    # The same random_state is used for images and masks so that each input stays paired with its reference
    train_x, valid_x = train_test_split(images, test_size=valid_size, random_state=42) # Train-validation split (Input Image)
    train_y, valid_y = train_test_split(masks, test_size=valid_size, random_state=42) # Train-validation split (Reference/Ground Truth Image)

    train_x, test_x = train_test_split(train_x, test_size=test_size, random_state=42) # Train-test split (Input Image)
    train_y, test_y = train_test_split(train_y, test_size=test_size, random_state=42) # Train-test split (Reference/Ground Truth Image)
    return (train_x, train_y), (valid_x, valid_y), (test_x, test_y) #Function return
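
As a rough illustration of the size arithmetic in `load_data`, the sketch below reproduces only the `int(split * total_size)` calculation. The helper name `split_sizes` and the example dataset sizes are assumptions made for this illustration and are not part of the notebook.

```python
def split_sizes(total_size, split=0.02):
    """Return (train, validation, test) sizes as computed in load_data."""
    valid_size = int(split * total_size)  # validation size
    test_size = int(split * total_size)   # testing size
    train_size = total_size - valid_size - test_size
    return train_size, valid_size, test_size


sizes_large = split_sizes(1000)  # hypothetical dataset of 1000 image pairs
sizes_small = split_sizes(40)    # hypothetical dataset of 40 image pairs
print(sizes_large)
print(sizes_small)
```

With the default split of 0.02, fewer than 50 image pairs give validation and test sizes of 0, so a larger `split` would probably be needed for small datasets.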

### 3.1.1) Function Execution

`load_data` takes the path obtained from step 2.1

In [None]:
(train_x, train_y), (valid_x, valid_y), (test_x, test_y)=load_data(path) 

## 3.2) Reading images

Images (input and reference) are read as unsigned 8-bit integers, i.e., each pixel value lies between 0 and 255 (both values inclusive)

If needed, the images are resized to a size of 256x256

The images are then converted to 32-bit floating-point tensors and mapped to the range -1 to 1.

* The range -1 to 1 is used because it matches the output range of the tanh activation function used by the generator U-Net Convolutional Neural Network (CNN)


* The training dataset loading function is kept separate from the testing and validation one so that data augmentation can be added later and batch sizes can be set independently

https://www.tensorflow.org/api_docs/python/tf/image

https://www.tensorflow.org/api_docs/python/tf/io/decode_png

https://www.tensorflow.org/api_docs/python/tf/cast
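
The mapping between pixel ranges can be sketched with NumPy alone. This is a minimal illustration of the arithmetic, not the notebook's loading code; the helper names `to_model_range` and `to_display_range` and the sample values are assumptions.

```python
import numpy as np


def to_model_range(pixels_uint8):
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return pixels_uint8.astype(np.float32) / 127.5 - 1.0


def to_display_range(pixels_float):
    """Map values in [-1, 1] back to [0, 1], e.g. for plotting."""
    return pixels_float * 0.5 + 0.5


sample = np.array([0, 51, 255], dtype=np.uint8)  # hypothetical pixel values
scaled = to_model_range(sample)
restored = to_display_range(scaled)
print(scaled)
print(restored)
```

The inverse mapping (`* 0.5 + 0.5`) is the same one used later in the notebook before images are plotted or saved.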

### 3.2.1)  Load_train dataset

In [None]:

# All tensorflow image processing operations are used, as the given (input and reference/ground truth) images 
# are converted to tensors


# a- input image

# b- reference/ground truth image


def load_train(a,b):
    
    #Reading input image from path using tensorflow's input-output module
    
    image1 = tf.io.read_file(a)
    
    #Decoding input image from path to unsigned integer 8 array using tensorflow's image module
    
    input_image = tf.image.decode_png(image1)
    
    #Resize input image to 256x256 if needed (Training dataset is expected to already be in 256x256 image patches)
   
    input_image=tf.image.resize(input_image, [256, 256],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    
    #Using tensorflow to cast input image data type from unsigned integer to floating point
    
    #Casting -> Converting a given variable from one data type (how information is represented) to another. No extra function
    
    input_image=tf.cast(input_image,tf.float32)
    
    # Mapping input image from 0 to 255 values to -1 to 1 values 
    
    input_image = (input_image / 127.5) - 1
    
    #Reading reference/ground-truth image from path using tensorflow's input-output module
    
    image2 = tf.io.read_file(b)
    
    #Decoding reference/ground-truth image from path to unsigned integer 8 array using tensorflow's image module
    
    real_image = tf.image.decode_png(image2)
    
     #Resize reference/ground-truth image to 256x256 if needed (Training dataset is expected to already be in 256x256 
                                                                                                            #image patches)

    real_image=tf.image.resize(real_image, [256, 256],
                      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    
    #Using tensorflow to cast reference/ground-truth image data type from unsigned integer to floating point
    
    #Casting -> Converting a given variable from one data type (how information is represented) to another. No extra function

    real_image=tf.cast(real_image,tf.float32)
    
    # Mapping reference/ground-truth image from 0 to 255 values to -1 to 1 values 

    real_image = (real_image / 127.5) - 1
    
    #Function return
    
    return input_image, real_image


### 3.2.2) Validation/Test Dataset Loading

* For image inference, i.e., when using an external dataset for validation purposes, there is no need to read a reference image

In that case, set the `condtn` parameter of `load_valid` to `False` so that only the input image is returned
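
`load_valid` zero-pads images smaller than 256x256 with `tf.image.pad_to_bounding_box` instead of resizing them. The NumPy sketch below mimics that behaviour for offsets of 0, so the original pixels stay in the top-left corner. The helper name `pad_to_square` and the patch size are assumptions for illustration only.

```python
import numpy as np


def pad_to_square(image, target=256):
    """Zero-pad an (H, W, C) image at the bottom and right up to target x target."""
    height, width = image.shape[:2]
    return np.pad(
        image,
        ((0, target - height), (0, target - width), (0, 0)),
        mode="constant",
    )


patch = np.ones((200, 180, 3), dtype=np.uint8)  # hypothetical undersized patch
padded = pad_to_square(patch)
print(padded.shape)
```

Because padding happens before normalisation in `load_valid`, the padded zeros end up as -1 (black) after the mapping to the -1 to 1 range.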

In [None]:
def load_valid(a,b=0,condtn=True):
    
    #Reading input image from path using tensorflow's input-output module
    
    image1 = tf.io.read_file(a)
    
    #Decoding input image from path to unsigned integer 8 array using tensorflow's image module
    
    input_image = tf.image.decode_png(image1,channels=3)
    
    #Padding (adding) empty regions of input image less than 256x256 with zeros
    
    if(input_image.shape!=([256,256,3])):
        input_image=tf.image.pad_to_bounding_box(input_image,0,0,256,256)

    #Using tensorflow to cast input image data type from unsigned integer to floating point
    
    #Casting -> Converting a given variable from one data type (how information is represented) to another. No extra function
        
        
    input_image=tf.cast(input_image,tf.float32)
    
    input_image = (input_image / 127.5) - 1
    
    if(condtn==True): # Same Function can be used for H&E images only - for inference purposes (no ground truth data present)
        
        #Reading reference/ground-truth image from path using tensorflow's input-output library
        
        image2 = tf.io.read_file(b)
        
        #Decoding reference/ground-truth image from path to unsigned integer 8 array using tensorflow's image library
        
        real_image = tf.image.decode_png(image2,channels=3) # 3 channels, consistent with the input image and the shape check below
        
        #Padding (adding) empty regions of reference/ground-truth image less than 256x256 with zeros
    
        if(real_image.shape!=([256,256,3])):
            real_image=tf.image.pad_to_bounding_box(real_image,0,0,256,256)
            
        #Using tensorflow to cast reference/ground-truth image data type from unsigned integer to floating point
    
        #Casting -> Converting a given variable from one data type (how information is represented) to another. 
        #No extra function
         

        real_image=tf.cast(real_image,tf.float32)
    
        # Mapping reference/ground-truth image from 0 to 255 values to -1 to 1 values 
        
        real_image = (real_image / 127.5) - 1

        return input_image, real_image # Return type 1-> Input and reference/ground-truth image
    else:
        return input_image # Return type 2-> Input image only

## 3.3) Tensorflow Dataset Creation with tensor slices

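Datasets are created from the lists of file paths using the "tensor slices" method (`tf.data.Dataset.from_tensor_slices`)

https://www.tensorflow.org/guide/tensor_slicing

The loading functions above are applied with the "map" function and the results are grouped with the "batch" function. Images are read and decoded only as each batch is requested, so the whole decoded dataset never has to sit in GPU (Graphics Processing Unit) memory at the same time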


https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch


https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map


The dataset is also shuffled using the "shuffle" function with a buffer size

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle
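
To show why the buffer size matters, here is a plain-Python sketch of buffer-based shuffling in the style described in the `tf.data` documentation: the buffer is filled, a random element is emitted, and its slot is refilled with the next item. It is an illustration written for this note, not TensorFlow's implementation, and `buffered_shuffle` is a hypothetical name.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in a shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]        # emit a random buffered element
        buffer[idx] = item       # refill the slot with the next item
    rng.shuffle(buffer)          # drain what is left in random order
    yield from buffer


shuffled = list(buffered_shuffle(range(10), buffer_size=3))
print(shuffled)
```

With a buffer smaller than the dataset, an element can only move a limited distance from its original position, so a buffer at least as large as the dataset is needed for a fully uniform shuffle.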


In [None]:
BUFFER_SIZE = 400 #For shuffling a dataset


BATCH_SIZE = 350 #25 20

In [None]:
# Training Dataset Creation

train_dataset=tf.data.Dataset.from_tensor_slices((train_x,train_y))
train_dataset=train_dataset.shuffle(BUFFER_SIZE) #The Dataset is shuffled with a set buffer size
train_dataset=train_dataset.map(load_train)
train_dataset=train_dataset.batch(BATCH_SIZE) 

# Validation Dataset Creation

valid_dataset=tf.data.Dataset.from_tensor_slices((valid_x,valid_y))
valid_dataset=valid_dataset.map(load_valid)
valid_dataset=valid_dataset.batch(1) # Batch size is 1

# Testing Dataset Creation

test_dataset=tf.data.Dataset.from_tensor_slices((test_x,test_y))
test_dataset=test_dataset.map(load_valid)
test_dataset=test_dataset.batch(1) #Batch Size is 1

# 4) U-Net Generator Creation

  * The architecture of generator is a modified U-Net.
  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout(applied to the first 3 blocks) -> ReLU)
  * There are skip connections between the encoder and decoder (as in U-Net)
  
https://www.tensorflow.org/api_docs/python/tf/keras/layers/ReLU


In [None]:
IMAGE_SIZE=256
OUTPUT_CHANNELS = 3

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())

    result.add(tf.keras.layers.LeakyReLU())

    return result

In [None]:
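# Sample input image for the shape checks (first training example)
inp, _ = load_train(train_x[0], train_y[0])
down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(inp, 0))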

In [None]:
def upsample(filters, size, apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))

  result.add(tf.keras.layers.BatchNormalization())

  if apply_dropout:
      result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result

In [None]:
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)

In [None]:
def Generator():
  size=4
  inputs = tf.keras.layers.Input(shape=[256,256,3])

  down_stack = [
    downsample(64, size, apply_batchnorm=False), # (bs, 128, 128, 64)
    downsample(128,size), # (bs, 64, 64, 128)
    downsample(256, size), # (bs, 32, 32, 256)
    downsample(512, size), # (bs, 16, 16, 512)
    downsample(512, size), # (bs, 8, 8, 512)
    downsample(512, size), # (bs, 4, 4, 512)
    downsample(512, size), # (bs, 2, 2, 512)
    downsample(512, size), # (bs, 1, 1, 512)
  ]

  up_stack = [
    upsample(512, size, apply_dropout=True), # (bs, 2, 2, 1024)
    upsample(512, size, apply_dropout=True), # (bs, 4, 4, 1024)
    upsample(512, size, apply_dropout=True), # (bs, 8, 8, 1024)
    upsample(512, size), # (bs, 16, 16, 1024)
    upsample(256, size), # (bs, 32, 32, 512)
    upsample(128, size), # (bs, 64, 64, 256)
    upsample(64, size), # (bs, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, size,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') # (bs, 256, 256, 3)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
generator = Generator()
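
The shape comments in `Generator` can be checked with simple arithmetic. The sketch below uses no TensorFlow; it only tracks how the spatial size and channel count change, using the filter counts from `down_stack` and `up_stack`. The variable names are chosen for this illustration.

```python
encoder_filters = [64, 128, 256, 512, 512, 512, 512, 512]
decoder_filters = [512, 512, 512, 512, 256, 128, 64]

size = 256
encoder_shapes = []
for filters in encoder_filters:
    size //= 2  # each stride-2 convolution halves height and width
    encoder_shapes.append((size, size, filters))

# Skip connections: every encoder output except the last, in reverse order
skips = list(reversed(encoder_shapes[:-1]))

decoder_shapes = []
for filters, (skip_size, _, skip_channels) in zip(decoder_filters, skips):
    size *= 2  # each stride-2 transposed convolution doubles height and width
    assert size == skip_size  # concatenation needs matching spatial sizes
    decoder_shapes.append((size, size, filters + skip_channels))

print(encoder_shapes[-1])
print(decoder_shapes)
```

The final transposed convolution (`last`) then doubles 128x128 to 256x256 with `OUTPUT_CHANNELS` channels.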


* **Generator loss**
  * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.
  * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.
  * This allows the generated image to become structurally similar to the target image.
  * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004).
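
A small worked example of the formula, written with NumPy rather than TensorFlow. The sigmoid cross entropy here is a re-implementation for illustration, and all input values are hypothetical.

```python
import numpy as np

LAMBDA = 100


def sigmoid_cross_entropy(labels, logits):
    """Numerically stable sigmoid cross entropy on logits, averaged over all elements."""
    per_element = (
        np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    )
    return per_element.mean()


# Hypothetical values: an undecided discriminator (logits of 0) and a
# generated image that differs from the target by 0.2 at every pixel
disc_generated_output = np.zeros((1, 30, 30, 1))
gen_output = np.full((1, 256, 256, 3), 0.3)
target = np.full((1, 256, 256, 3), 0.5)

gan_loss = sigmoid_cross_entropy(np.ones_like(disc_generated_output), disc_generated_output)
l1_loss = np.mean(np.abs(target - gen_output))
total_gen_loss = gan_loss + LAMBDA * l1_loss
print(gan_loss, l1_loss, total_gen_loss)
```

In this example the GAN term is about 0.69 (ln 2) while the weighted L1 term is about 20, which shows how strongly LAMBDA = 100 emphasises pixel-level similarity.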

The training procedure for the generator is shown below:

In [None]:
LAMBDA = 100

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

  # mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))


  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss, gan_loss, l1_loss

![Generator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gen.png?raw=1)


## 4.1) Discriminator Design
  * The Discriminator is a PatchGAN.
  * Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU)
  * The shape of the output after the last layer is (batch_size, 30, 30, 1)
  * Each element of the 30x30 output classifies a 70x70 patch of the input image (such an architecture is called a PatchGAN).
  * Discriminator receives 2 inputs.
    * Input image and the target image, which it should classify as real.
    * Input image and the generated image (output of generator), which it should classify as fake.
    * We concatenate these 2 inputs together in the code (`tf.keras.layers.concatenate([inp, tar])`)
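
The 30x30 output size and the 70x70 receptive field can both be derived from the layer settings. The sketch below is plain arithmetic under the assumption of three stride-2 4x4 convolutions with 'same' padding followed by two stride-1 4x4 convolutions, each preceded by one pixel of zero padding per side.

```python
import math

# Output size: follow the spatial dimension through the layers
size = 256
for _ in range(3):
    size = math.ceil(size / 2)  # stride-2 convolution with 'same' padding
for _ in range(2):
    size = (size + 2) - 4 + 1   # pad 1 pixel per side, then 4x4 stride-1 convolution

# Receptive field: walk backwards from one output element
layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]  # (kernel, stride) per convolution
receptive_field = 1
for kernel, stride in reversed(layers):
    receptive_field = (receptive_field - 1) * stride + kernel

print(size, receptive_field)
```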

In [None]:
def Discriminator():
  size=4
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)

  down1 = downsample(64, size, False)(x) # (bs, 128, 128, 64)
  down2 = downsample(128, size)(down1) # (bs, 64, 64, 128)
  down3 = downsample(256, size)(down2) # (bs, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(512, size, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

  last = tf.keras.layers.Conv2D(1, size, strides=1,
                                kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)

In [None]:
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

In [None]:
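# Sample input and generator output used to visualise the discriminator response
inp, _ = load_train(train_x[0], train_y[0])
gen_output = generator(inp[tf.newaxis,...], training=False)
disc_out = discriminator([inp[tf.newaxis,...], gen_output], training=False)
plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()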

**Discriminator loss**
  * The discriminator loss function takes 2 inputs: **real images, generated images**
  * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones (since these are the real images)**
  * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**
  * Then the total_loss is the sum of real_loss and the generated_loss
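
For reference values, the discriminator loss can be sketched in NumPy. With a label of 1 the sigmoid cross entropy reduces to softplus(-logit), and with a label of 0 to softplus(logit). The function name and the logit values below are assumptions for illustration.

```python
import numpy as np


def discriminator_loss_np(disc_real_output, disc_generated_output):
    """Sum of the 'real' and 'generated' sigmoid cross entropy terms on logits."""
    real_loss = np.mean(np.logaddexp(0.0, -disc_real_output))           # labels of ones
    generated_loss = np.mean(np.logaddexp(0.0, disc_generated_output))  # labels of zeros
    return real_loss + generated_loss


# Hypothetical logits on a 30x30 patch map
undecided = discriminator_loss_np(np.zeros((1, 30, 30, 1)), np.zeros((1, 30, 30, 1)))
confident = discriminator_loss_np(np.full((1, 30, 30, 1), 5.0), np.full((1, 30, 30, 1), -5.0))
print(undecided, confident)
```

An undecided discriminator (all logits 0) gives 2 ln 2, roughly 1.39, which can serve as a reference point when reading the `disc_loss` curve.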


In [None]:
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

    total_disc_loss = real_loss + generated_loss
    #print(total_disc_loss)
    return total_disc_loss

The training procedure for the discriminator is shown below.

To learn more about the architecture and the hyperparameters, you can refer to the [paper](https://arxiv.org/abs/1611.07004).

![Discriminator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/dis.png?raw=1)


## 4.2) Define the Optimizers and Checkpoint-saver


In [None]:
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
checkpoint_dir = './training_checkpoints/'+date+"/"+mask
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_prefix, max_to_keep=5)

# 5) Generate Images

Write a function to plot some images during training.

* We pass images from the test dataset to the generator.
* The generator will then translate the input image into the output.


Note: The `training=True` is intentional here since
we want the batch statistics while running the model
on the test dataset. If we use training=False, we will get
the accumulated statistics learned from the training dataset
(which we don't want)
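
The difference between the two modes can be sketched with NumPy. This simplified normalisation leaves out the learned scale and offset of a real batch normalisation layer, and all numbers are hypothetical.

```python
import numpy as np


def normalise(x, mean, variance, eps=1e-3):
    """Core of batch normalisation, without the learned scale and offset."""
    return (x - mean) / np.sqrt(variance + eps)


test_activations = np.array([10.0, 12.0, 14.0])  # hypothetical activations at test time
moving_mean, moving_variance = 0.0, 1.0          # hypothetical accumulated statistics

# training=True: statistics come from the current batch
with_batch_stats = normalise(test_activations, test_activations.mean(), test_activations.var())
# training=False: statistics come from the accumulated moving averages
with_moving_stats = normalise(test_activations, moving_mean, moving_variance)
print(with_batch_stats)
print(with_moving_stats)
```

When the test data are distributed differently from the accumulated statistics, only the batch-statistics version yields activations centred on zero.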

In [None]:
def generate_images(model, test_input, tar):
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15,15))

    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input Image', 'Ref for style_transfer', 'Predicted Image']

    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

# 6) Training

* For each example input generate an output.
* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.
* Next, we calculate the generator and the discriminator loss.
* Then, we calculate the gradients of the loss with respect to both the generator and the discriminator variables and apply them with the optimizers.
* Finally, the losses (along with the SSIM and PSNR metrics) are logged to TensorBoard.

In [None]:
EPOCHS = 150

In [None]:
# Saving Directories with different dates
import datetime
log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
@tf.function
def train_step(input_image, target, epoch):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)
    
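    # PSNR and SSIM are computed on images rescaled from [-1, 1] to [0, 1], matching max_val=1.0
    gen_01 = gen_output * 0.5 + 0.5
    target_01 = target * 0.5 + 0.5
    ps=tf.reduce_mean(tf.image.psnr(gen_01,target_01, max_val=1.0))

    ss=tf.reduce_mean(tf.image.ssim(gen_01,target_01, max_val=1.0))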

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))
  
  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
    
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
    
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
    tf.summary.scalar('disc_loss', disc_loss, step=epoch)
    #Additional Metrics maybe used as an alternative to L1-Loss
    tf.summary.scalar('ssim', ss, step=epoch) 
    tf.summary.scalar('psnr', ps, step=epoch) 
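
PSNR is defined relative to the dynamic range of the pixel values (`max_val`), so the value passed should match the range of the images being compared. The NumPy sketch below shows the standard formula for images scaled to 0 to 1; the helper name `psnr_np` and the test images are assumptions for illustration.

```python
import numpy as np


def psnr_np(a, b, max_val=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val^2 / MSE)."""
    mse = np.mean((a - b) ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)


reference = np.zeros((256, 256, 3))  # hypothetical image in [0, 1]
noisy = reference + 0.1              # uniform error of 0.1 at every pixel
value = psnr_np(reference, noisy)
print(value)
```

A uniform error of 0.1 on a 0 to 1 scale corresponds to about 20 dB.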
 

The actual training loop:

* Iterates over the number of epochs.
* On each epoch it clears the display, and runs `generate_images` to show its progress.
* On each epoch it iterates over the training dataset, printing a '.' for each batch.
* It saves a checkpoint at the end of every epoch (the checkpoint manager keeps only the 5 most recent).

In [None]:
import time
from IPython import display # Needed for display.clear_output

def fit(train_ds, epochs, test_ds):
    for epoch in range(epochs):
        start = time.time()

        display.clear_output(wait=True)

        for example_input, example_target in test_ds.take(3):
            generate_images(generator, example_input, example_target)
        
        print("Epoch: ", epoch)

    # Train
        for n, (input_image, target) in train_ds.enumerate():
            print('.', end='')
   
            train_step(input_image, target, epoch)

    # Save a checkpoint at the end of every epoch (the manager keeps the 5 most recent)
                
        manager.save()

        print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))
    #checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_prefix))

This training loop saves logs you can easily view in TensorBoard to monitor the training progress. Working locally you would launch a separate tensorboard process. In a notebook, if you want to monitor with TensorBoard it's easiest to launch the viewer before starting the training.

To launch the viewer paste the following into a code-cell:

In [None]:
#docs_infra: no_execute
#!kill 1787
%load_ext tensorboard
%tensorboard --logdir {log_dir}  
#%reload_ext tensorboard

Now run the training loop:

In [None]:
fit(train_dataset, EPOCHS, valid_dataset)

## Restore the latest checkpoint and test

In [None]:
!ls {checkpoint_dir}

In [None]:
# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_prefix))

## Testing trained model on datasets.

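Tkinter is used to make this step interactive. A dialog asks for the folder in which the results are saved; the input images, ground-truth images and predictions are written to its `img`, `gt` and `pred` subfolders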

In [None]:
p_train=filedialog.askdirectory(title='path for train_results')
P_train_gt=os.path.join(p_train,"gt")
P_train_pred=os.path.join(p_train,"pred")
P_train_img=os.path.join(p_train,"img")
os.makedirs(P_train_img,exist_ok=True)
os.makedirs(P_train_gt,exist_ok=True)
os.makedirs(P_train_pred,exist_ok=True)
i=0
for x,y in train_dataset:
    for a,b in zip(x,y):
        v=generator(np.expand_dims(a,axis=0))
        img1=v[0]*0.5+0.5
        img2=b*0.5+0.5
        img3=a*0.5+0.5
        plt.imsave(P_train_pred+"/"+f"{i:06d}"+".png",img1.numpy())
        plt.imsave(P_train_gt+"/"+f"{i:06d}"+".png",img2.numpy())
        plt.imsave(P_train_img+"/"+f"{i:06d}"+".png",img3.numpy())
        print(i, end= " ")
        i=i+1