In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Flatten, Reshape
from tensorflow.keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import cifar10

In [None]:
# Load CIFAR-10. The class labels are discarded: an autoencoder's target is its own input.
(x_train, _), (x_test, _) = cifar10.load_data()

# Rescale 8-bit pixel values from [0, 255] to [0, 1] (float32), the same range the
# decoder's sigmoid output produces.
x_train, x_test = [split.astype('float32') / 255.0 for split in (x_train, x_test)]

In [None]:
# The encoder is a stack of Dense layers, so each 32x32x3 RGB image is unrolled
# into a single 3072-element feature vector (one row per image).
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)

In [None]:
# Model / training configuration.
# Seed every RNG (Python `random`, NumPy, TensorFlow) before any layer is built, so that
# weight initialisation and the per-epoch shuffling in `fit` repeat across
# Restart & Run All. NOTE: some GPU kernels can still be non-deterministic unless
# tf.config.experimental.enable_op_determinism() is also called.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Size of the bottleneck code: 3072 input features -> 128, a 24x compression.
encoding_dim = 128

In [None]:
# Build encoder: one fully connected ReLU layer compressing each flattened image into
# an `encoding_dim`-dimensional code. The input width is read from the flattened
# training data (3072 for CIFAR-10) rather than hard-coded, so it cannot drift out
# of sync with the flattening step.
input_img = Input(shape=(x_train.shape[1],))
encoded = Dense(encoding_dim, activation='relu')(input_img)

In [None]:
# Build decoder: map the code back to one value per input feature. The output width
# mirrors the flattened input width (read from the data, not hard-coded), and the
# sigmoid keeps every output in [0, 1] like the normalised pixels, as required by
# the binary_crossentropy loss used at compile time.
decoded = Dense(x_train.shape[1], activation='sigmoid')(encoded)

In [None]:
# End-to-end autoencoder: flattened image in, reconstructed flattened image out.
autoencoder = Model(inputs=input_img, outputs=decoded)

In [None]:
# Compile the model.
# binary_crossentropy treats each normalised pixel in [0, 1] as a target probability,
# matching the decoder's sigmoid output.
# NOTE(review): with this loss, Keras typically resolves 'accuracy' to binary accuracy
# (predictions thresholded at 0.5, then compared for exact equality with the targets).
# For continuous pixel targets that says little about reconstruction quality. Confirm
# for the installed Keras version. It is kept only because the accuracy-curve cell plots
# it; 'mse' is tracked alongside as an interpretable per-pixel reconstruction error.
autoencoder.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy', 'mse'])

In [None]:
# Train the autoencoder to reproduce its input: the images serve as both inputs and targets.
# The test split is passed only as validation data, so its loss is reported each epoch,
# but the model is not trained on it.
history = autoencoder.fit(
    x=x_train,
    y=x_train,
    batch_size=256,
    epochs=20,
    shuffle=True,
    validation_data=(x_test, x_test),
)

In [None]:
# Standalone encoder reusing the trained layers: images -> `encoding_dim`-dimensional codes.
encoder = Model(inputs=input_img, outputs=encoded)

# Latent codes for the test images, then their full reconstructions.
encoded_imgs = encoder.predict(x_test)
decoded_imgs = autoencoder.predict(x_test)

In [None]:
# Visual check of reconstruction quality: test images (top row) vs. their
# autoencoder reconstructions (bottom row).
n = min(10, len(x_test))  # Number of images to display (capped at the test-set size)
image_shape = (32, 32, 3)  # CIFAR-10 height x width x RGB; undoes the earlier flattening

# squeeze=False keeps `axes` 2-D even when n == 1.
fig, axes = plt.subplots(2, n, figsize=(20, 4), squeeze=False)
for row, (images, row_label) in enumerate(((x_test, 'Original'),
                                           (decoded_imgs, 'Reconstructed'))):
    for i in range(n):
        axes[row, i].imshow(images[i].reshape(image_shape))
        axes[row, i].axis('off')
    # Label each row once, on its leftmost panel.
    axes[row, 0].set_title(row_label, loc='left')
fig.suptitle('CIFAR-10 test images: original vs. reconstructed')
plt.show()

In [None]:
# Per-epoch training vs. validation reconstruction loss (binary cross-entropy).
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(history.history['loss'], label='Train Loss', color='blue')
ax.plot(history.history['val_loss'], label='Validation Loss', color='red')
ax.set(xlabel='Epochs', ylabel='Loss', title='Loss Curve')
ax.legend()
plt.show()

In [None]:
# The metric requested as 'accuracy' may be logged under the short name or as
# 'binary_accuracy'; prefer the short key and fall back to the explicit one.
train_acc_key = 'accuracy' if 'accuracy' in history.history else 'binary_accuracy'
val_acc_key = 'val_accuracy' if 'val_accuracy' in history.history else 'val_binary_accuracy'

# NOTE(review): for continuous pixel targets this thresholded "accuracy" is a weak
# signal of reconstruction quality; treat the loss curve as the primary diagnostic.
fig, ax = plt.subplots(figsize=(8, 5))
for key, label, color in ((train_acc_key, 'Train Accuracy', 'blue'),
                          (val_acc_key, 'Validation Accuracy', 'red')):
    ax.plot(history.history[key], label=label, color=color)
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
ax.legend()
ax.set_title('Accuracy Curve')
plt.show()