In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D

In [None]:
!pip install mediapy

In [None]:
# Load the MNIST dataset and normalize the pixel values.
# Labels are discarded (`_`): the autoencoder is trained image-to-image.
(x_train, _), (x_test, _) = mnist.load_data()

# Assumes (N, 28, 28) uint8 images as documented for keras.datasets.mnist.
# Scale to [0, 1] in float32 (plain `/ 255` yields float64, doubling memory) and
# append an explicit channel axis so the arrays match Input(shape=(28, 28, 1))
# instead of relying on Keras to reshape them implicitly.
x_train = (x_train.astype("float32") / 255.0)[..., np.newaxis]
x_test = (x_test.astype("float32") / 255.0)[..., np.newaxis]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Add Gaussian noise to the images to build (noisy input, clean target) pairs.

noise_factor = 0.4  # std-dev of the additive noise (before clipping), in [0, 1] pixel units
SEED = 42           # fixed so the noisy datasets are identical on every Restart & Run All
rng = np.random.default_rng(SEED)


def add_gaussian_noise(images, noise_factor, rng):
    """Return a noisy copy of ``images`` clipped back to the valid [0, 1] range.

    Parameters
    ----------
    images : np.ndarray
        Floating-point pixel array, assumed already scaled to [0, 1].
    noise_factor : float
        Multiplier applied to standard-normal noise (i.e. the noise std-dev).
    rng : np.random.Generator
        Source of randomness; pass a seeded generator for reproducibility.

    Returns
    -------
    np.ndarray
        Array with the same shape and dtype as ``images``.
    """
    # Match the input dtype so float32 data is not silently promoted to float64.
    noise = rng.standard_normal(images.shape).astype(images.dtype, copy=False)
    return np.clip(images + noise_factor * noise, 0.0, 1.0)


x_train_noisy = add_gaussian_noise(x_train, noise_factor, rng)
x_test_noisy = add_gaussian_noise(x_test, noise_factor, rng)


In [None]:
# Define the convolutional denoising autoencoder.
# Encoder: two conv + 2x max-pool stages reduce 28x28x1 to a 7x7x64 code.
# Decoder: mirrors it with conv + 2x upsampling back to 28x28, then a single-channel
# sigmoid conv so outputs lie in [0, 1] like the normalized targets.

input_img = Input(shape=(28, 28, 1))

# Encoder
x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)          # 28x28 -> 14x14
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)    # 14x14 -> 7x7

# Decoder
x = Conv2D(64, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)                          # 7x7 -> 14x14
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)                          # 14x14 -> 28x28
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
# Per-pixel binary cross-entropy suits sigmoid outputs and targets in [0, 1].
# "accuracy" is not used: with this loss Keras resolves it to binary accuracy, which
# tests thresholded predictions for exact equality with *continuous* pixel targets,
# so it is not meaningful for reconstruction. Mean absolute error is interpretable.
autoencoder.compile(optimizer='adam', loss='binary_crossentropy', metrics=["mae"])


In [None]:
# Train the autoencoder to map noisy images back to their clean originals.
EPOCHS = 20
BATCH_SIZE = 64

# Keep the History object so loss curves can be plotted later without relying on
# `_` / `Out[N]` (which break on re-run); this also stops its repr being echoed.
history = autoencoder.fit(
    x_train_noisy, x_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_test_noisy, x_test),
)


In [None]:
# Predict: denoise the noisy test set and inspect one random example.
import random
import mediapy  # NOTE(review): not used in the visible cells
import PIL      # NOTE(review): not used in the visible cells
import matplotlib.pyplot as plt


# Feed the *noisy* images: the model was trained to map noisy -> clean, and the
# figure below pairs each prediction with the noisy input it came from.
prediction = autoencoder.predict(x_test_noisy)

# randrange excludes its upper bound; randint(0, len(x_test)) is inclusive and
# could return len(x_test), raising IndexError.
r = random.randrange(len(x_test))

# Show noisy input, model output and clean target side by side.
# np.squeeze drops a trailing size-1 channel axis (if present) so imshow gets 2-D data.
panels = [
    ("Noisy input", x_test_noisy[r]),
    ("Denoised output", prediction[r]),
    ("Clean target", x_test[r]),
]
fig, axes = plt.subplots(1, len(panels), figsize=(12, 4))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(np.squeeze(image), cmap='Greys')
    ax.set_title(title)
    ax.axis('off')
fig.suptitle(f"Test sample #{r}")
plt.show()