# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GaussianNoise
from tensorflow.keras.optimizers import Adam

# Load MNIST and flatten each 28x28 image into a 784-vector scaled to [0, 1].
# float32 halves memory versus NumPy's default float64 (each 60000x784 training
# array drops from ~376 MB to ~188 MB) and matches Keras's default float32
# compute dtype, so fit() does not have to cast.
(X_train, _), (X_test, _) = mnist.load_data()
X_train = X_train.reshape(-1, 784).astype(np.float32) / 255.0
X_test = X_test.reshape(-1, 784).astype(np.float32) / 255.0

# Corrupt the images with additive Gaussian noise (std = noise_factor), then clip
# back into the valid pixel range. A seeded Generator makes the corrupted
# datasets reproducible across runs. Note this seeds only the input corruption;
# Keras weight initialisation and the GaussianNoise layer use TensorFlow's RNG.
noise_factor = 0.5
noise_seed = 42
rng = np.random.default_rng(noise_seed)


def _add_clipped_noise(images, factor, generator):
    """Return a copy of *images* with N(0, factor**2) noise added, clipped to [0, 1].

    The noise is drawn in the same floating dtype as *images* (float32 or
    float64), so the result keeps that dtype.
    """
    noisy = images + factor * generator.standard_normal(images.shape, dtype=images.dtype)
    # Clip in place to avoid allocating another full-size array.
    np.clip(noisy, 0.0, 1.0, out=noisy)
    return noisy


X_train_noisy = _add_clipped_noise(X_train, noise_factor, rng)
X_test_noisy = _add_clipped_noise(X_test, noise_factor, rng)

# ---- Denoising autoencoder: 784 -> 128 -> 64 -> 32 -> 64 -> 128 -> 784 ----
input_img = Input(shape=(784,))  # one flattened 28x28 image per sample

# Encoder. Keras applies GaussianNoise only in training mode, so this extra
# stddev-0.2 corruption regularises fit() and is a no-op in predict().
encoded = GaussianNoise(stddev=0.2)(input_img)
encoded = Dense(units=128, activation='relu')(encoded)
encoded = Dense(units=64, activation='relu')(encoded)
encoded_output = Dense(units=32, activation='relu')(encoded)  # 32-d bottleneck code

# Decoder mirrors the encoder. The sigmoid keeps every output pixel in (0, 1),
# the same range as the clean [0, 1] targets.
decoded = Dense(units=64, activation='relu')(encoded_output)
decoded = Dense(units=128, activation='relu')(decoded)
decoded_output = Dense(units=784, activation='sigmoid')(decoded)

# Wire the input to its reconstruction and optimise pixel-wise mean squared error.
autoencoder = Model(inputs=input_img, outputs=decoded_output)
autoencoder.compile(loss='mse', optimizer=Adam())

# Fit the network to map corrupted inputs back to their clean originals.
# No callbacks are attached, so the noisy/clean test pairs are only reported
# as validation loss. NOTE(review): the test set doubles as validation data.
autoencoder.fit(
    x=X_train_noisy,
    y=X_train,
    batch_size=256,
    epochs=10,
    shuffle=True,
    validation_data=(X_test_noisy, X_test),
)

# Run the trained model over the noisy test images to get reconstructions.
denoised_images = autoencoder.predict(X_test_noisy)

# Show the first n test samples as three rows: noisy input, clean original,
# and denoised reconstruction.
n = 10  # number of test images (columns) to display
rows = [X_test_noisy, X_test, denoised_images]

# squeeze=False keeps `axes` 2-D (rows x n) even when n == 1.
fig, axes = plt.subplots(len(rows), n, figsize=(20, 6), squeeze=False)
for row_axes, images in zip(axes, rows):
    for ax, image in zip(row_axes, images[:n]):
        # Fix the grey scale to the [0, 1] pixel range. Without vmin/vmax,
        # imshow rescales every image to its own min/max, which exaggerates the
        # contrast of reconstructions and makes panels incomparable.
        ax.imshow(image.reshape(28, 28), cmap='gray', vmin=0.0, vmax=1.0)
        ax.axis('off')

fig.suptitle("Top: Noisy Images | Middle: Original Images | Bottom: Denoised Images")
plt.show()
