# CycleGAN for Document Denoising

## About CycleGAN

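CycleGAN uses a cycle-consistency loss so that it can be trained without paired data. Two generators are trained together, one translating from the source domain to the target domain and one translating back, and an image that is translated there and back should match the original. As a result, CycleGAN can learn to translate between domains without a one-to-one mapping between source and target images.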

This makes many interesting tasks possible, such as photo enhancement, image colorization, and style transfer. All you need are a source dataset and a target dataset, each of which is simply a directory of images.
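The cycle-consistency idea described above can be sketched in a few lines. This NumPy snippet is an illustration, not the notebook's training code: `G` and `F` are hypothetical toy functions standing in for the two generators, and the L1 loss weighted by `LAMBDA = 10` follows the formulation in the original CycleGAN paper.

```python
import numpy as np

LAMBDA = 10.0  # weight of the cycle term (the CycleGAN paper uses 10)


def cycle_consistency_loss(real, reconstructed, weight=LAMBDA):
    """Weighted mean absolute (L1) error between an image and its round trip."""
    return weight * float(np.mean(np.abs(real - reconstructed)))


# Hypothetical toy "generators" that are exact inverses of each other:
# G maps the dirty domain X to the clean domain Y, F maps Y back to X.
def G(x):
    return x + 0.1


def F(y):
    return y - 0.1


rng = np.random.default_rng(0)
dirty = rng.uniform(0.0, 1.0, size=(8, 8, 3))
clean = rng.uniform(0.0, 1.0, size=(8, 8, 3))

# Forward cycle x -> G(x) -> F(G(x)) plus backward cycle y -> F(y) -> G(F(y))
perfect = cycle_consistency_loss(dirty, F(G(dirty))) + cycle_consistency_loss(clean, G(F(clean)))

# A lossy round trip that also darkens the image no longer reconstructs its input
broken = cycle_consistency_loss(dirty, 0.9 * F(G(dirty)))

print(f"perfect inverses: {perfect:.6f}, lossy round trip: {broken:.4f}")
```

In training, this term is added to each generator's adversarial loss. Without it, `G` could map every dirty page to any plausible clean page rather than to a clean version of the same page.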

![CycleGAN Image 1](https://miro.medium.com/max/1400/1*-7JKDTvulO6o4t4RRU5MJQ.png)
Fig.1 Conversion of original dirty input to its translated clean output
![CycleGAN Image 2](https://miro.medium.com/max/1400/1*0C34D2bEHmiyTbNzH8o5nQ.png)
Fig.2 Conversion of original clean input to its translated dirty output

## Install dependencies

Install the [tensorflow_examples](https://github.com/tensorflow/examples) package, which provides the Pix2Pix generator and discriminator reused below.

```python
!pip install git+https://github.com/tensorflow/examples.git
```

## Import libraries and data

```python
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix

import os
import cv2
import numpy as np
from PIL import Image

AUTOTUNE = tf.data.AUTOTUNE
```

```python
# Google Colab only: mount Google Drive and change into the project folder
from google.colab import drive
drive.mount('/content/drive')
os.chdir("/content/drive/MyDrive/ColabNotebooks/DocDenoise")
!ls
```

```python
# Check GPU details
!nvidia-smi
```

```python
# Input images are read from data/to_process/, denoised images are written to data/processed/
path = 'data/'
to_process_path = 'to_process/'
processed_path = 'processed/'
to_process_img = sorted(os.listdir(path + to_process_path))
```
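`os.listdir` returns every entry in the folder, including hidden files such as `.DS_Store` and subfolders. `cv2.imread` returns `None` instead of raising for anything it cannot decode, which leads to an error further down that does not name the offending file. A minimal sketch of a stricter listing, assuming the scans use common image extensions; `list_images` and `IMAGE_EXTS` are hypothetical names, not part of the original notebook:

```python
import os
import tempfile

# Assumed set of extensions to accept; adjust to match your scans
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"}


def list_images(folder):
    """Return sorted file names in `folder` whose extension looks like an image."""
    names = []
    for name in os.listdir(folder):
        full = os.path.join(folder, name)
        ext = os.path.splitext(name)[1].lower()
        # Skip directories, hidden files, and non-image extensions
        if os.path.isfile(full) and ext in IMAGE_EXTS and not name.startswith("."):
            names.append(name)
    return sorted(names)


# Usage with a throwaway folder
with tempfile.TemporaryDirectory() as folder:
    for name in ["b.PNG", "a.jpg", ".DS_Store", "notes.txt"]:
        open(os.path.join(folder, name), "w").close()
    os.mkdir(os.path.join(folder, "subdir.png"))
    found = list_images(folder)

print(found)  # ['a.jpg', 'b.PNG']
```

With this helper, `to_process_img = list_images(path + to_process_path)` would replace the plain `sorted(os.listdir(...))` call.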

## Data preparation

The next step is to define a function that loads and normalizes each image, and then to collect all images into a single NumPy array. Because the dataset is small, we do not need to process it in batches.

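```python
IMG_WIDTH = 3072
IMG_HEIGHT = 4096

def process_image(path):
    """Load an image as RGB, resize it, and scale pixel values to [0, 1]."""
    img = cv2.imread(path)  # OpenCV returns BGR, or None if the file can't be decoded
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # PIL, used when saving, expects RGB
    img = np.asarray(img, dtype="float32")
    img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))  # cv2 takes (width, height)
    img = img / 255.0

    return img  # shape: (IMG_HEIGHT, IMG_WIDTH, 3)
```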
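Why 3072 × 4096? The `pix2pix.unet_generator` U-Net downsamples with eight stride-2 convolutions and concatenates a skip connection at each scale, so height and width should be multiples of 2⁸ = 256. Both 3072 = 12 × 256 and 4096 = 16 × 256 qualify. If you change the target size, a small helper such as the hypothetical `round_to_multiple` below (not part of the notebook) picks the nearest valid size; for example, an A4 page scanned at 300 dpi (2480 × 3508) maps to 2560 × 3584.

```python
def round_to_multiple(size, multiple=256):
    """Round `size` to the nearest positive multiple of `multiple` (ties round up)."""
    return max(multiple, ((size + multiple // 2) // multiple) * multiple)


for side in (3072, 4096, 2480, 3508, 100):
    print(side, "->", round_to_multiple(side))
```

Rounding to the nearest multiple keeps the aspect ratio close to the original page, at the cost of a slight stretch or squeeze along each axis.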

```python
# Preprocess all images and stack them into one array of shape (N, IMG_HEIGHT, IMG_WIDTH, 3)
chinese_invoice = []

for f in to_process_img:
    chinese_invoice.append(process_image(path + to_process_path + f))

chinese_invoice = np.asarray(chinese_invoice)
```
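Whether skipping batching is safe depends on memory. Every image is held as float32 at full resolution, so a quick back-of-the-envelope calculation using only the sizes above gives the cost per page (the page counts in the loop are illustrative):

```python
IMG_WIDTH = 3072
IMG_HEIGHT = 4096
CHANNELS = 3
BYTES_PER_FLOAT32 = 4

bytes_per_image = IMG_WIDTH * IMG_HEIGHT * CHANNELS * BYTES_PER_FLOAT32
mib_per_image = bytes_per_image / 2**20

print(bytes_per_image)  # 150994944
print(mib_per_image)    # 144.0

# Memory held by `chinese_invoice` for a few hypothetical page counts
for n_pages in (10, 50, 100):
    print(n_pages, "pages ->", n_pages * mib_per_image / 1024, "GiB")
```

Note that `np.asarray` copies the list into a new array while the list still exists, so peak memory during that call is roughly twice these figures.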

## Import and reuse the Pix2Pix models

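```python
OUTPUT_CHANNELS = 3

# Generators: generator_g translates dirty -> clean, generator_f translates clean -> dirty
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

# Discriminators: discriminator_x judges dirty images, discriminator_y judges clean images.
# target=False because, unlike pix2pix, CycleGAN's discriminator sees only the image, not a paired input.
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
```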
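This notebook only restores trained weights, but for context, here is how the two discriminators are typically trained against the generators. The sketch mirrors the binary cross-entropy (from logits) formulation used in the TensorFlow CycleGAN tutorial, written in NumPy so it runs without a GPU; the original CycleGAN paper uses a least-squares variant instead. The function names and logit values are illustrative assumptions. The `pix2pix` discriminator outputs a grid of patch logits rather than a single score, so the losses are averaged over that grid.

```python
import numpy as np


def bce_with_logits(labels, logits):
    """Numerically stable mean binary cross-entropy computed from raw logits."""
    return float(np.mean(np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))))


def discriminator_loss(real_logits, fake_logits):
    # Real patches should be scored 1 and generated patches 0; halved as in the TF tutorial
    real_loss = bce_with_logits(np.ones_like(real_logits), real_logits)
    fake_loss = bce_with_logits(np.zeros_like(fake_logits), fake_logits)
    return 0.5 * (real_loss + fake_loss)


def generator_adversarial_loss(fake_logits):
    # The generator wants its outputs to be scored as real (label 1)
    return bce_with_logits(np.ones_like(fake_logits), fake_logits)


# Hypothetical patch logits (the discriminator emits a grid, not a scalar)
confident_real = np.full((1, 30, 30, 1), 5.0)
confident_fake = np.full((1, 30, 30, 1), -5.0)
undecided = np.zeros((1, 30, 30, 1))

print(discriminator_loss(confident_real, confident_fake))  # small: D separates real from fake
print(discriminator_loss(undecided, undecided))            # log(2), about 0.693
print(generator_adversarial_loss(confident_fake))          # large: G's outputs are being caught
```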

Initialize the optimizers for all the generators and the discriminators.

```python
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
```

## Checkpoints

```python
checkpoint_path = "./checkpoints/cycleGAN"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)

# If a checkpoint exists, restore the latest one; otherwise the models keep random weights.
if ckpt_manager.latest_checkpoint:
    # expect_partial() silences warnings about optimizer state that inference never uses
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    print('Latest checkpoint restored!')
else:
    print(f'No checkpoint found in {checkpoint_path}; the generators are untrained.')
```

## Denoise and save images

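```python
def tensor_to_image(tensor):
    """Convert a generator output with values in [0, 1] to a PIL image."""
    tensor = np.asarray(tensor) * 255
    # The U-Net ends in tanh, so outputs can be negative; clip before casting to uint8
    tensor = np.clip(tensor, 0, 255).astype(np.uint8)
    if tensor.ndim > 3:
        assert tensor.shape[0] == 1  # expect a single-image batch
        tensor = tensor[0]
    return Image.fromarray(tensor)
```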
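`unet_generator` ends in a `tanh` activation, so raw outputs lie in [-1, 1]. How they should be mapped back to 8-bit pixels depends on how images were normalized during training, which the notebook does not show. The sketch below, using a hypothetical helper `to_uint8`, handles the two common conventions: inputs scaled to [0, 1], as in `process_image`, or to [-1, 1], as in the TensorFlow CycleGAN tutorial. It clips before casting, because casting negative floats straight to `uint8` is not guaranteed to saturate at 0.

```python
import numpy as np


def to_uint8(output, input_range="0_1"):
    """Map generator output to uint8 pixels.

    input_range: "0_1" if training images were scaled to [0, 1],
                 "-1_1" if they were scaled to [-1, 1].
    """
    output = np.asarray(output, dtype=np.float32)
    if input_range == "0_1":
        pixels = output * 255.0
    elif input_range == "-1_1":
        pixels = (output + 1.0) * 127.5
    else:
        raise ValueError(f"unknown input_range: {input_range!r}")
    # Round, then clip into the valid byte range before casting
    return np.clip(np.round(pixels), 0, 255).astype(np.uint8)


tanh_like = np.array([-1.0, -0.2, 0.0, 0.5, 1.0])
print(to_uint8(tanh_like, "0_1"))   # [  0   0   0 128 255]
print(to_uint8(tanh_like, "-1_1"))  # [  0 102 128 191 255]
```

If the denoised pages look washed out or too dark, a mismatch between these two conventions is a likely cause.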

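```python
# Make sure the output folder exists
os.makedirs(path + processed_path, exist_ok=True)

for filename, image in zip(to_process_img, chinese_invoice):
    # Add a batch dimension and run the dirty -> clean generator in inference mode
    prediction = generator_g(image[np.newaxis, ...], training=False)
    im = tensor_to_image(prediction)
    im.save(path + processed_path + filename)
```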

## Next steps
- Training the model on a larger dataset
- Tuning hyperparameters to improve denoising quality
- Fine-tuning the models on other datasets to support additional tasks (e.g., watermark removal and motion deblurring)