<a href="https://colab.research.google.com/github/filsto/GAN/blob/main/AutoEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load MNIST from MNIST_data/ with one-hot encoded labels.
# NOTE: `tensorflow.examples.tutorials` is only shipped with TensorFlow 1.x,
# which this whole notebook (placeholders, sessions) requires.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Report how many examples each split holds (one line per split).
split_reports = (
    ("- training set:\t\t{}", mnist.train),
    ("- test set:\t\t{}", mnist.test),
    ("- validation set:\t\t{}", mnist.validation),
)
print("size of:")
for template, split in split_reports:
  print(template.format(len(split.labels)))

Hyperparameters

In [None]:
# ---- training hyper-parameters ----
logs_path = "./logs/noiseRemoval"  # directory the TensorBoard FileWriter writes to
learning_rate = 0.001              # Adam step size
epochs = 10                        # number of passes over the training split
batch_size = 100                   # images per optimisation step
display_freq = 100                 # print the batch loss every `display_freq` iterations

# ---- network parameters ----
img_h = 28
img_w = 28
img_size_flat = img_w * img_h      # length of a flattened image (input and output size)
h1 = 100                           # number of units in the single hidden layer
noise_level = 0.6                  # std-dev of the Gaussian noise added to inputs

Graph

In [None]:
#weight and bias wrappers
def weight_variable(name, shape):
  """
  Create a float32 weight variable drawn from a truncated normal (stddev 0.01).

  :param name: suffix of the variable name; the variable is called 'W_' + name
  :param shape: shape of the weight tensor
  :return: the initialized weight variable returned by tf.get_variable
  """
  return tf.get_variable(
      'W_' + name,
      shape=shape,
      dtype=tf.float32,
      initializer=tf.truncated_normal_initializer(stddev=0.01),
  )

def bias_variable(name, shape):
  """
  Create a float32 bias variable initialized to zeros.

  :param name: suffix of the variable name; the variable is called 'b_' + name,
      following the same '<prefix>_' + name convention as weight_variable ('W_')
  :param shape: shape of the bias tensor
  :return: the initialized bias variable returned by tf.get_variable
  """
  # The initializer is a constant tensor, so tf.get_variable takes the shape
  # from it; passing `shape=` as well is not allowed in that case.
  zeros = tf.constant(0., shape=shape, dtype=tf.float32)
  return tf.get_variable('b_' + name, dtype=tf.float32, initializer=zeros)

def fc_layer(x, num_units, name, use_relu=True):
  """
  Create a fully connected layer computing x @ W + b, optionally followed by ReLU.

  Histogram summaries of W and b are registered inside the layer's scope.

  :param x: input from the previous layer; expected 2-D (batch, features) since
      its second dimension is used as the layer's input size
  :param num_units: number of hidden units in the fully-connected layer
  :param name: layer name, used as the variable scope and the variable-name suffix
      (must be a valid TF scope name, e.g. no spaces)
  :param use_relu: if True apply a ReLU non-linearity; if False return the
      affine output unchanged (used for the output layer)
  :return: the layer's output tensor, (batch, num_units)
  """
  with tf.variable_scope(name):
    in_dim = x.get_shape()[1]
    W = weight_variable(name, shape=[in_dim, num_units])
    tf.summary.histogram('W', W)
    b = bias_variable(name, [num_units])
    tf.summary.histogram('b', b)
    layer = tf.matmul(x, W) + b
    if use_relu:
      layer = tf.nn.relu(layer)
  # Return on both paths: the output layer is built with use_relu=False and
  # must still get its (linear) output tensor back.
  return layer


In [None]:
#create graphs
#placeholders for the clean images (x_original, the reconstruction target)
#and their noise-corrupted versions (x_noisy, the network input)
with tf.variable_scope('Input'):
  x_original = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='X_original')
  x_noisy = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='X_noisy')

# TF scope names may not contain spaces, hence 'Hidden_layer'.
fc1 = fc_layer(x_noisy, h1, 'Hidden_layer', use_relu=True)
out = fc_layer(fc1, img_size_flat, 'Output_layer', use_relu=False)

#define the loss function and optimizer
with tf.variable_scope('Train'):
  with tf.variable_scope('Loss'):
    loss = tf.reduce_mean(tf.losses.mean_squared_error(x_original, out), name='loss')
    tf.summary.scalar('loss', loss)
  with tf.variable_scope('Optimizer'):
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, name='Adam-op').minimize(loss)

#initializing the variables (created after the optimizer so Adam's slot
#variables are included)
init = tf.global_variables_initializer()

#add 5 images from original, noisy and reconstructed samples to summaries;
#tf.summary.image expects [batch, height, width, channels]
tf.summary.image('original', tf.reshape(x_original, (-1, img_h, img_w, 1)), max_outputs=5)
tf.summary.image('noisy', tf.reshape(x_noisy, (-1, img_h, img_w, 1)), max_outputs=5)
tf.summary.image('reconstructed', tf.reshape(out, (-1, img_h, img_w, 1)), max_outputs=5)

#merge all the summaries
merged = tf.summary.merge_all()

Train

In [None]:
#launch the graph (session)
sess = tf.InteractiveSession()
sess.run(init)

train_writer = tf.summary.FileWriter(logs_path, sess.graph)
# Number of full batches per epoch.
num_tr_iter = mnist.train.num_examples // batch_size
global_step = 0

for epoch in range(epochs):
  print('Training epoch: {}'.format(epoch + 1))
  for iteration in range(num_tr_iter):
    batch_x, _ = mnist.train.next_batch(batch_size)
    # Corrupt the clean batch with Gaussian noise of std-dev `noise_level`.
    batch_x_noisy = batch_x + noise_level * np.random.normal(loc=0.0, scale=1.0, size=batch_x.shape)
    global_step += 1

    #run optimization op (backprop) and collect the merged summaries
    feed_dict_batch = {x_original: batch_x, x_noisy: batch_x_noisy}
    _, summary_tr = sess.run([optimizer, merged], feed_dict=feed_dict_batch)
    train_writer.add_summary(summary_tr, global_step)

    if iteration % display_freq == 0:
      #calculate and display the batch reconstruction loss
      loss_batch = sess.run(loss, feed_dict=feed_dict_batch)
      print('iter {0:3d}\t Reconstruction loss={1:.3f}'.format(iteration, loss_batch))

  #run validation after every epoch, with freshly sampled noise
  x_valid_original = mnist.validation.images
  x_valid_noisy = x_valid_original + noise_level * np.random.normal(loc=0.0, scale=1.0, size=x_valid_original.shape)

  feed_dict_valid = {x_original: x_valid_original, x_noisy: x_valid_noisy}
  loss_valid = sess.run(loss, feed_dict=feed_dict_valid)
  print('______________________________________')
  print('Epoch: {0}, validation loss: {1:.3f}'.format(epoch + 1, loss_valid))
  print('______________________________________')

# Push buffered summaries to disk so TensorBoard sees the complete run.
train_writer.flush()

Test

In [None]:
def plot_images(original_images, noisy_images, reconstructed_images):
  """
  Plot original, noisy and reconstructed images side by side, one row per sample.

  :param original_images: original images to be plotted, (?, img_h*img_w)
  :param noisy_images: noisy images to be plotted, (?, img_h*img_w)
  :param reconstructed_images: reconstructed images to be plotted, (?, img_h*img_w)
  """
  num_images = original_images.shape[0]
  # squeeze=False keeps `axes` 2-D (num_images x 3) even when num_images == 1.
  fig, axes = plt.subplots(num_images, 3, figsize=(9, 9), squeeze=False)
  fig.subplots_adjust(hspace=.1, wspace=0)

  # Assumes square images: recover the side length from the flattened size.
  img_h = img_w = int(round(np.sqrt(original_images.shape[-1])))
  for i, ax in enumerate(axes):
    #plot image
    ax[0].imshow(original_images[i].reshape((img_h, img_w)), cmap='gray')
    ax[1].imshow(noisy_images[i].reshape((img_h, img_w)), cmap='gray')
    ax[2].imshow(reconstructed_images[i].reshape((img_h, img_w)), cmap='gray')

    #remove ticks from the plot
    for sub_ax in ax:
      sub_ax.set_xticks([])
      sub_ax.set_yticks([])

  for ax, col in zip(axes[0], ['Original Image', 'Noisy Image', 'reconstructed Image']):
    ax.set_title(col)

  # Lay out and display once, after every title has been set.
  fig.tight_layout()
  plt.show()

In [None]:
#test the network after training

#make noisy versions of the first `test_samples` test images
test_samples = 5
x_test = mnist.test.images[:test_samples]
x_test_noisy = x_test + noise_level * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape)

#reconstruct clean images from the noisy ones and compute the loss against the
#originals in a single forward pass
x_reconstruct, loss_test = sess.run([out, loss], feed_dict={x_original: x_test, x_noisy: x_test_noisy})
print('------------------------------------------------')
print('test loss of original image compared to reconstructed image :{0:.3f}'.format(loss_test))
print('------------------------------------------------')

#plot original image, noisy image, and reconstructed image
plot_images(x_test, x_test_noisy, x_reconstruct)

In [None]:
# Flush and close the TensorBoard writer, then close the session after being
# done with testing.
train_writer.close()
sess.close()