In [None]:
# Mount Google Drive so the dataset archive under MyDrive is readable by the unzip cell below.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


In [None]:
!unzip /content/gdrive/MyDrive/raw-890.zip

In [None]:
import tensorflow.compat.v1 as tf
from tensorflow.python.framework import ops
import numpy as np
import os #, shutil
from tensorflow.keras.layers import *
import glob
import random
import cv2
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
import matplotlib as mpl
import pickle


import fusionmodule as fuseMod
import contrastmodule as contMod
import illuminancemodule as illumMod
from PIL import Image, ImageStat

# The notebook uses TF1-style graph APIs (tf.placeholder, tf.Session, reinitializable
# tf.data iterators), so TF2 eager behavior is switched off; call this before any graph is built.
tf.disable_v2_behavior()

## **Hyperparameters**

In [None]:
# Training hyperparameters consumed by the cells below.
#   n_epochs      - number of passes over the training split
#   batch_size    - images per optimizer step (and per validation batch)
#   learning_rate - Adam step size
#   weight_decay  - L2 factor handed to every Conv2D kernel_regularizer in Aod_net
n_epochs, batch_size = 10, 8
learning_rate = weight_decay = 1e-4

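## K-Estimation Module

AOD-Net folds the atmospheric scattering model into a single learned map $K(x)$: the network estimates $K$ from the hazy image $I$, and the dehazed image is $J(x) = K(x)\,I(x) - K(x) + b$ with $b = 1$, clipped to $[0, 1]$.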

In [None]:
def Aod_net(X):
  """Estimate a per-pixel K map with five small convolutions and apply the dehazing formula.

  Layers (all 3 filters, stride 1, SAME padding, ReLU), with kernel sizes
  1, 3, 5, 7, 3, are wired as:
    k1 = conv1(X)
    k2 = conv3(k1)
    k3 = conv5(concat[k1, k2])
    k4 = conv7(concat[k2, k3])
    K  = conv3(concat[k1, k2, k3, k4])
  The result is ReLU(K * X - K + 1.0) with max_value=1.0, i.e. clipped to [0, 1].

  Args:
    X: input image tensor; assumed NHWC float scaled to [0, 1] — the training
      cell feeds a (None, 480, 640, 3) float32 placeholder.

  Returns:
    Dehazed tensor of the same shape as X (K has 3 channels; presumably X does too).
  """
  def conv(kernel_size, inputs):
    # Every layer shares filter count, stride, init and L2 penalty; only the kernel size varies.
    # NOTE(review): the training loss is plain MSE and does not visibly add these
    # regularization penalties, so weight_decay may have no effect — confirm.
    layer = Conv2D(3, kernel_size, 1, padding="SAME", activation="relu",
                   kernel_initializer=tf.initializers.random_normal(stddev=0.02),
                   kernel_regularizer=tf.keras.regularizers.l2(weight_decay))
    return layer(inputs)

  # Layers are created in the same order as before, so auto-generated variable
  # names (and therefore checkpoint compatibility) are unchanged.
  k1 = conv(1, X)
  k2 = conv(3, k1)
  k3 = conv(5, tf.concat([k1, k2], axis=-1))
  k4 = conv(7, tf.concat([k2, k3], axis=-1))
  K = conv(3, tf.concat([k1, k2, k3, k4], axis=-1))

  # J = K * I - K + b with b = 1, clipped to [0, 1].
  unclipped = tf.math.multiply(K, X) - K + 1.0
  return ReLU(max_value=1.0)(unclipped)

## **Data Loading & Pre-processing**

In [None]:
def setup_data_paths(orig_images_path, hazy_images_path, seed=None):
  """Pair hazy images with same-named references and split them 90/10 by reference.

  The split is decided per reference image, so all hazy images that share a
  reference land in the same split.

  Args:
    orig_images_path: directory of reference (target) ``*.png`` images.
    hazy_images_path: directory of hazy ``*.png`` images; each is paired with
      the reference that has the same file name.
    seed: optional seed for the shuffle. ``None`` uses the global ``random``
      module state (the previous behavior).

  Returns:
    (train_data, val_data): lists of ``[hazy_path, reference_path]`` pairs.
    Hazy images without a same-named reference are skipped instead of raising KeyError.
  """
  # Sort first so a given seed gives the same split regardless of the
  # filesystem's directory-listing order.
  orig_image_paths = sorted(glob.glob(orig_images_path + "/*.png"))
  n = len(orig_image_paths)
  if seed is None:
    random.shuffle(orig_image_paths)
  else:
    random.Random(seed).shuffle(orig_image_paths)

  cut = int(0.90 * n)
  split_dict = {key: 'train' for key in orig_image_paths[:cut]}
  split_dict.update({key: 'val' for key in orig_image_paths[cut:]})

  train_data = []
  val_data = []
  for path in sorted(glob.glob(hazy_images_path + "/*.png")):
    label = os.path.basename(path)
    orig_path = orig_images_path + "/" + label
    split = split_dict.get(orig_path)
    if split is None:
      # No reference image with this name: unusable for supervised training.
      continue
    if split == 'train':
      train_data.append([path, orig_path])
    else:
      val_data.append([path, orig_path])

  return train_data, val_data


In [None]:
def load_image(X, size=(480, 640)):
  """Read an image file and return it resized, as floats scaled to [0, 1].

  Args:
    X: scalar string tensor holding a file path (as produced by the
      tf.data pipelines in ``create_datasets``).
    size: ``(height, width)`` of the output; the default matches the
      (None, 480, 640, 3) model placeholders.

  Returns:
    A ``(height, width, 3)`` float tensor with values in [0, 1].
  """
  raw = tf.io.read_file(X)
  # decode_image selects the decoder from the file header (the training data are
  # PNGs, which were previously routed through decode_jpeg). expand_animations=False
  # keeps the result 3-D with a static rank, which tf.image.resize needs in graph mode.
  image = tf.io.decode_image(raw, channels=3, expand_animations=False)
  image = tf.image.resize(image, size)
  return image / 255.0

def showImage(x):
  """Render one image in its own matplotlib figure.

  Args:
    x: image array presumably scaled to [0, 1]; it is multiplied by 255 and
      cast to int32 before plotting.
  """
  scaled_pixels = np.asarray(x * 255, dtype=np.int32)
  plt.figure()
  plt.imshow(scaled_pixels)
  plt.show()

In [None]:
def create_datasets(train_data, val_data, batch_size):
  """Build train/val input pipelines that share one reinitializable iterator.

  Args:
    train_data: list of ``[hazy_path, reference_path]`` pairs for training.
    val_data: list of ``[hazy_path, reference_path]`` pairs for validation.
    batch_size: number of pairs per batch.

  Returns:
    (train_init_op, val_init_op, iterator): run one init op to point ``iterator``
    at that split; ``iterator.get_next()`` then yields ``(hazy_batch, reference_batch)``.
    Both pipelines repeat indefinitely, so callers decide how many batches to draw.
  """
  def build_split(pairs):
    # Decode each side separately and zip them so hazy/reference stay aligned;
    # shuffling happens only after zipping, which keeps the pairs intact.
    hazy = tf.data.Dataset.from_tensor_slices([pair[0] for pair in pairs]).map(lambda path: load_image(path))
    reference = tf.data.Dataset.from_tensor_slices([pair[1] for pair in pairs]).map(lambda path: load_image(path))
    return tf.data.Dataset.zip((hazy, reference)).shuffle(100).repeat().batch(batch_size)

  train_ds = build_split(train_data)
  val_ds = build_split(val_data)

  # A structure-only iterator lets the same get_next() op serve both splits.
  iterator = tf.data.Iterator.from_structure(tf.data.get_output_types(train_ds),
                                             tf.data.get_output_shapes(train_ds))
  return iterator.make_initializer(train_ds), iterator.make_initializer(val_ds), iterator


## **Training**

In [None]:
# Seed every RNG that affects this run: numpy, Python's `random` (setup_data_paths
# shuffles the train/val split with it) and TF's graph-level seed (weight init).
# The TF seed is set after the reset because it is attached to the default graph.
SEED = 9999
np.random.seed(SEED)
random.seed(SEED)
ops.reset_default_graph()
tf.set_random_seed(SEED)

train_data, val_data = setup_data_paths(orig_images_path="/content/reference-890", hazy_images_path="/content/raw-890")
train_init_op, val_init_op, iterator = create_datasets(train_data, val_data, batch_size)
next_element = iterator.get_next()

# Hazy input (X) and reference target (Y); fed per batch from `next_element`.
X = tf.placeholder(shape=(None, 480, 640, 3), dtype=tf.float32)
Y = tf.placeholder(shape=(None, 480, 640, 3), dtype=tf.float32)
dehazed_X = Aod_net(X)

# Pixel-wise MSE. NOTE(review): the Conv2D L2 regularizers in Aod_net are not
# added to this loss, so weight_decay may have no effect — confirm.
loss = tf.reduce_mean(tf.square(dehazed_X - Y))
optimizer = tf.train.AdamOptimizer(learning_rate)
trainable_variables = tf.trainable_variables()
gradients_and_vars = optimizer.compute_gradients(loss, trainable_variables)
# Clip each gradient tensor to L2 norm 0.1 and apply the clipped gradients
# (variables without a gradient are skipped; clip_by_norm cannot take None).
clipped_gradients = [(tf.clip_by_norm(gradient, 0.1), var)
                     for gradient, var in gradients_and_vars if gradient is not None]
# `optimizer` is rebound to the training op; the training cell runs it under this name.
optimizer = optimizer.apply_gradients(clipped_gradients)


Instructions for updating:
Use `tf.compat.v1.data.get_output_types(iterator)`.
Instructions for updating:
Use `tf.compat.v1.data.get_output_shapes(iterator)`.
Instructions for updating:
Use `tf.compat.v1.data.get_output_classes(iterator)`.


In [None]:
# Train for n_epochs, validate after every epoch, and checkpoint every epoch.
# The graph was built in the previous cell; TF places its ops on a GPU when one is
# available, so no tf.device scope is used here (a scope around an already-built
# graph would only pin the few ops created inside it, such as the initializer).
saver = tf.train.Saver()
load_path = None  # NOTE(review): not used by any visible cell

checkpoint_dir = '/content/models'
# tf.train.Saver.save does not create a missing parent directory.
os.makedirs(checkpoint_dir, exist_ok=True)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(n_epochs):

    # Training pass: len(train_data) // batch_size optimizer steps.
    sess.run(train_init_op)
    batches = len(train_data) // batch_size
    epoch_loss = 0.0
    for batch in range(batches):
      batch_data = sess.run(next_element)
      batch_loss, _ = sess.run([loss, optimizer],
                               feed_dict={X: batch_data[0], Y: batch_data[1]})
      epoch_loss += batch_loss / float(batches)
      if batch % 1000 == 0:
        print("Training loss at batch %d: %f\n" % (batch, batch_loss))
    train_loss = epoch_loss

    # Validation pass (loss only, no optimizer step) with occasional visual samples.
    sess.run(val_init_op)
    batches = len(val_data) // batch_size
    epoch_loss = 0.0
    for batch in range(batches):
      batch_data = sess.run(next_element)
      batch_loss = sess.run(loss, feed_dict={X: batch_data[0], Y: batch_data[1]})
      epoch_loss += batch_loss / float(batches)
      if batch % 100 == 0:
        print("Validation loss at batch %d: %f\n" % (batch, batch_loss))
        # Show hazy input, reference and model output for the first few images of the batch.
        for j in range(-2 + batch_size // 2):
          x = batch_data[0][j].reshape((1,) + batch_data[0][j].shape)
          y = batch_data[1][j].reshape((1,) + batch_data[1][j].shape)
          dehazed_x = sess.run(dehazed_X, feed_dict={X: x, Y: y})
          print("Image Number: %d\n" % (j))
          showImage(x[0])
          showImage(y[0])
          showImage(dehazed_x[0])
    val_loss = epoch_loss

    # NOTE: ".h5" is only part of the checkpoint prefix; the files are TF checkpoints, not HDF5.
    saver.save(sess, os.path.join(checkpoint_dir, 'model_checkpoint_' + str(epoch) + '.h5'))

    print("Epoch %d\nTraining loss: %f\nValidation loss: %f\n\n" % (epoch, train_loss, val_loss))


In [None]:
# Sanity-check the validation pipeline: pull 10 batches and, when SHOW_PREVIEW is
# True, display the first (up to) 4 hazy images of each. Reuses the existing
# `next_element` op instead of adding another get_next() to the graph.
SHOW_PREVIEW = False

with tf.Session() as sess:
  sess.run(val_init_op)

  for _ in range(10):
    batch_data = sess.run(next_element)
    # Guard against batches with fewer than 4 images (batch_size < 4).
    for j in range(min(4, len(batch_data[0]))):
      x = batch_data[0][j].reshape((1,) + batch_data[0][j].shape)
      if SHOW_PREVIEW:
        # showImage expects a single (H, W, 3) image, not the (1, H, W, 3) batch.
        showImage(x[0])

## **Evaluation**

In [None]:
# Rebuild the graph from scratch for evaluation/inference. It is constructed the same
# way as the training graph, so layer/variable names presumably line up with the saved
# checkpoints — re-check if the model code changes.
# Python's `random` is seeded because setup_data_paths shuffles with it; this makes the
# split deterministic (it equals the training split only if training seeded `random` identically).
np.random.seed(9999)
random.seed(9999)
tf.reset_default_graph()
train_data, val_data = setup_data_paths(orig_images_path="/content/reference-890",
                                        hazy_images_path="/content/raw-890")
train_init_op, val_init_op, iterator = create_datasets(train_data, val_data, batch_size)
next_element = iterator.get_next()

X = tf.placeholder(shape=(None, 480, 640, 3), dtype=tf.float32)
Y = tf.placeholder(shape=(None, 480, 640, 3), dtype=tf.float32)
dehazed_X = Aod_net(X)

loss = tf.reduce_mean(tf.square(dehazed_X - Y))
# The optimizer ops are not needed for inference; they are rebuilt so this graph
# has the same variables (including Adam's) as the training graph.
optimizer = tf.train.AdamOptimizer(learning_rate)
trainable_variables = tf.trainable_variables()
gradients_and_vars = optimizer.compute_gradients(loss, trainable_variables)
clipped_gradients = [(tf.clip_by_norm(gradient, 0.1), var)
                     for gradient, var in gradients_and_vars if gradient is not None]
optimizer = optimizer.apply_gradients(clipped_gradients)


## **Final Results**

In [None]:
# Record the final checkpoint's data-shard path in img_model.pkl.
# NOTE: only this path *string* is pickled, not the model weights. tf.train.Saver.restore
# needs the checkpoint prefix (this path without ".data-00000-of-00001") plus the
# checkpoint files themselves.
final_checkpoint_shard = '/content/models/model_checkpoint_%d.h5.data-00000-of-00001' % (n_epochs - 1)
with open("img_model.pkl", mode="wb") as pickle_out:
  pickle.dump(final_checkpoint_shard, pickle_out)

In [None]:
# Dehaze every image in the challenge set with the last-epoch checkpoint, show each
# before/after pair, and write the dehazed image to test_output_folder.
saver = tf.train.Saver()


test_input_folder = "/content/challenging-60"
test_output_folder = "/content/dehazed_test_images"
os.makedirs(test_output_folder, exist_ok=True)

file_types = ['jpeg', 'jpg', 'png']
# Prefix written by the training loop after its final epoch.
final_checkpoint = '/content/models/model_checkpoint_%d.h5' % (n_epochs - 1)

with tf.Session() as sess:
  saver.restore(sess, final_checkpoint)
  # Match extensions case-insensitively (".JPG" too), skip hidden files like glob
  # does, and process in a stable order.
  test_image_paths = sorted(
      os.path.join(test_input_folder, name)
      for name in os.listdir(test_input_folder)
      if not name.startswith('.')
      and os.path.splitext(name)[1][1:].lower() in file_types)

  for path in test_image_paths:
    image_label = os.path.basename(path)
    # convert('RGB') normalises grayscale/RGBA/palette files to the 3 channels the
    # placeholder expects; the with-block closes the underlying file.
    with Image.open(path) as img:
      image = img.convert('RGB').resize((640, 480))
    image = (np.asarray(image, dtype=np.float32) / 255.0)[np.newaxis]
    # dehazed_X depends only on X, so Y does not need to be fed.
    dehazed_image = sess.run(dehazed_X, feed_dict={X: image})

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 10))
    axes[0].imshow(image[0])
    axes[1].imshow(dehazed_image[0])
    fig.tight_layout()
    # Render now and release the figure so figures don't pile up across the loop.
    plt.show()
    plt.close(fig)

    # Model output is already clipped to [0, 1] by Aod_net, so the uint8 cast cannot overflow.
    dehazed_image = np.asarray(dehazed_image[0] * 255, dtype=np.uint8)
    mpl.image.imsave(os.path.join(test_output_folder, 'dehazed_' + image_label), dehazed_image)

