In [None]:
# Install TensorFlow on the fly if it is missing from the current environment.
try:
    import tensorflow
except ImportError:
    # Catch only ImportError. A bare `except:` would also swallow
    # KeyboardInterrupt, or a genuine error raised while importing an
    # installed-but-broken TensorFlow, and then blindly try to reinstall it.
    import importlib
    import subprocess
    import sys

    # `sys.executable -m pip` installs into the interpreter running this code.
    # A shell-level `pip` can resolve to a different environment's pip.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "tensorflow"])
    # Make the freshly installed package visible to the imports that follow.
    importlib.invalidate_caches()
# Import necessary libraries 
import numpy as np 
import matplotlib.pyplot as plt 
from tensorflow.keras.datasets import mnist 
from tensorflow.keras.models import Model 
from tensorflow.keras.layers import Input, Dense, Flatten, Reshape 
from tensorflow.keras.optimizers import Adam 

def _to_unit_float_images(images):
    """Return *images* as float32 scaled by 1/255 with a trailing channel axis.

    Each image is reshaped to 28x28x1, so the result has shape
    (len(images), 28, 28, 1). The 1/255 scaling presumes 8-bit pixel values
    (0-255); confirm against the dataset loader.
    """
    scaled = images.astype('float32') / 255.0
    return np.reshape(scaled, (len(scaled), 28, 28, 1))


# Fetch the MNIST splits. The labels are discarded because the autoencoder
# learns to reproduce its input rather than to classify it.
(x_train, _), (x_test, _) = mnist.load_data()

# Normalize both splits and add the single-channel axis the model expects.
x_train, x_test = (_to_unit_float_images(split) for split in (x_train, x_test))

# Shape of a single input image: 28x28 pixels, one channel.
input_shape = (28, 28, 1)
# Width of the bottleneck layer, i.e. the length of each compressed code.
encoding_dim = 64
# Number of values in one flattened image (28 * 28 * 1 = 784). It is derived
# from input_shape so the decoder output always matches the encoder input.
n_pixels = int(np.prod(input_shape))

# Encoder: flatten each image and compress it into a ReLU-activated code.
input_img = Input(shape=input_shape)
x = Flatten()(input_img)
encoded = Dense(encoding_dim, activation='relu')(x)

# Decoder: expand the code back to one value per pixel, then restore the
# image shape. Sigmoid keeps outputs in (0, 1), matching the normalized pixel
# range the inputs were scaled to and the binary cross-entropy loss below.
decoded = Dense(n_pixels, activation='sigmoid')(encoded)
decoded = Reshape(input_shape)(decoded)

# Full autoencoder: maps an input image to its reconstruction.
autoencoder = Model(input_img, decoded)

autoencoder.compile(optimizer=Adam(), loss='binary_crossentropy')

# Report layer-by-layer output shapes and parameter counts before training.
autoencoder.summary()

# Train the network to reproduce its own input: the images serve as both the
# inputs and the targets. The test split doubles as validation data, so the
# reconstruction loss on unseen digits is reported after every epoch.
autoencoder.fit(
    x=x_train,
    y=x_train,
    batch_size=256,
    epochs=50,
    shuffle=True,
    validation_data=(x_test, x_test),
)

# Reconstruct every test image with the trained autoencoder.
decoded_imgs = autoencoder.predict(x_test)

# Number of digits to display.
n = 10

# A 2 x n grid: originals on the top row, reconstructions underneath.
# squeeze=False keeps `axes` two-dimensional even when n == 1.
_, axes = plt.subplots(2, n, figsize=(20, 4), squeeze=False)

for i, (orig_ax, recon_ax) in enumerate(zip(axes[0], axes[1])):
    # squeeze() drops the trailing channel axis: (28, 28, 1) -> (28, 28),
    # which imshow needs for a grayscale image.
    orig_ax.imshow(x_test[i].squeeze(), cmap='gray')
    orig_ax.set_title("Original")
    orig_ax.axis('off')

    recon_ax.imshow(decoded_imgs[i].squeeze(), cmap='gray')
    recon_ax.set_title("Reconstructed")
    recon_ax.axis('off')

# Render the figure.
plt.show()



Epoch 1/50
