In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Import MNIST data
# Only the images are kept: the labels are discarded because the
# autoencoder learns to reproduce its input rather than classify it.
mnist = tf.keras.datasets.mnist
(x_train, _), (x_test, _) = mnist.load_data()

# Scale uint8 pixels to float32 in [0, 1] and append a trailing channel
# axis so each split becomes (N, 28, 28, 1) for the Conv2D encoder.
x_train, x_test = (
    np.expand_dims(split.astype('float32') / 255., axis=-1)
    for split in (x_train, x_test)
)

In [None]:
# Fix TensorFlow's global seed so the random corruption below is identical
# on every Restart & Run All (NOTE(review): later TF random ops are also
# affected; confirm whether model weight init should be seeded separately).
tf.random.set_seed(42)

# Standard deviation of the additive Gaussian noise: tf.random.normal draws
# from N(0, 1), which is then scaled by this factor.
noise_factor = 0.2

x_train_noise = x_train + noise_factor * tf.random.normal(shape=x_train.shape)
x_test_noise = x_test + noise_factor * tf.random.normal(shape=x_test.shape)

# Clip back into [0, 1], the range of the normalized clean images, so the
# noisy inputs stay valid pixel intensities.
x_train_noise = tf.clip_by_value(x_train_noise, clip_value_min=0., clip_value_max=1.)
x_test_noise = tf.clip_by_value(x_test_noise, clip_value_min=0., clip_value_max=1.)


In [None]:
# Preview the first 25 corrupted training digits in a 5x5 grid, axes hidden.
plt.figure(figsize=(10, 10))
for cell in range(1, 26):
    plt.subplot(5, 5, cell)
    shown = plt.imshow(tf.squeeze(x_train_noise[cell - 1]), cmap='gray')
    cell_axes = shown.axes
    cell_axes.get_xaxis().set_visible(False)
    cell_axes.get_yaxis().set_visible(False)

In [None]:
class Denoise(tf.keras.Model):
    """Convolutional denoising autoencoder for 28x28 single-channel images.

    Encoder: two 3x3 ReLU convolutions (24 then 48 filters), each followed
    by 2x2 max-pooling, shrinking the spatial size 28 -> 14 -> 7.
    Decoder: two stride-2 transposed convolutions (48 then 24 filters)
    growing it back 7 -> 14 -> 28, then a 1-filter sigmoid convolution so
    every output pixel lies in [0, 1].
    """

    def __init__(self):
        super().__init__()
        kl = tf.keras.layers

        encoder_stack = [
            kl.Input(shape=(28, 28, 1)),
            kl.Conv2D(24, (3, 3), activation='relu', padding='same'),
            kl.MaxPooling2D((2, 2), padding="same"),
            kl.Conv2D(48, (3, 3), activation='relu', padding='same'),
            kl.MaxPooling2D((2, 2), padding="same"),
        ]
        decoder_stack = [
            kl.Conv2DTranspose(48, kernel_size=(3, 3), strides=2, activation='relu', padding='same'),
            kl.Conv2DTranspose(24, kernel_size=(3, 3), strides=2, activation='relu', padding='same'),
            kl.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same'),
        ]
        self.encoder = tf.keras.Sequential(encoder_stack)
        self.decoder = tf.keras.Sequential(decoder_stack)

    def call(self, x):
        """Compress ``x`` with the encoder and reconstruct it with the decoder."""
        return self.decoder(self.encoder(x))

autoencoder = Denoise()


In [None]:
# Targets are pixels in [0, 1] and the decoder ends in a sigmoid, so a
# per-pixel binary cross-entropy serves as the reconstruction loss.
autoencoder.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer='adam',
)

In [None]:
# Learn the mapping noisy image -> clean image; the noisy/clean test pair is
# used only to report validation loss after each epoch.
autoencoder.fit(
    x=x_train_noise,
    y=x_train,
    batch_size=512,
    epochs=1,
    shuffle=True,
    validation_data=(x_test_noise, x_test),
)


In [None]:
# Run the *noisy* test images through the model. It was trained to map
# x_train_noise -> x_train, so feeding clean x_test would not measure
# denoising at all, and would not match the "original + noise" inputs
# shown next to these reconstructions in the comparison plot.
encoded_imgs = autoencoder.encoder(x_test_noise).numpy()
decoded_imgs = autoencoder.decoder(encoded_imgs).numpy()


In [None]:
n = 10

# One row per image source, drawn top to bottom: the noisy model input, the
# model's reconstruction, and the clean ground truth. Driving the layout
# from this table replaces three copy-pasted subplot sections.
panel_rows = [
    ("original + noise", x_test_noise),
    ("reconstructed", decoded_imgs),
    ("original", x_test),
]

plt.figure(figsize=(30, 6))
for row, (title, images) in enumerate(panel_rows):
    for i in range(n):
        # subplot indices are 1-based and row-major: row r, column i.
        ax = plt.subplot(len(panel_rows), n, row * n + i + 1)
        plt.title(title)
        # Explicit cmap instead of plt.gray(), which also mutates the
        # global default colormap for every later figure.
        plt.imshow(tf.squeeze(images[i]), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()
