In [8]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Convert the images to grayscale
x_train_bw = tf.image.rgb_to_grayscale(x_train)
x_test_bw = tf.image.rgb_to_grayscale(x_test)

# Resize the color images to match the grayscale images
x_train_resized = tf.image.resize(x_train, size=(x_train_bw.shape[1], x_train_bw.shape[2]))
x_test_resized = tf.image.resize(x_test, size=(x_test_bw.shape[1], x_test_bw.shape[2]))

# Define the model architecture
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='sigmoid', padding='same', input_shape=(32, 32, 1)),
    tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='sigmoid', padding='same'),
    tf.keras.layers.Conv2DTranspose(64, kernel_size=(3, 3), activation='sigmoid', padding='same'),
    tf.keras.layers.Conv2DTranspose(32, kernel_size=(3, 3), activation='sigmoid', padding='same'),
    tf.keras.layers.Conv2DTranspose(3, kernel_size=(3, 3), activation='sigmoid', padding='same'),
])

# Compile the model
model.compile(optimizer='adam', loss='mean_squared_error')

# Train the model
history = model.fit(x_train_bw, x_train_resized, epochs=50, batch_size=64, validation_data=(x_test_bw, x_test_resized))

# Plot the training and validation loss
train_loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(train_loss) + 1)

plt.plot(epochs, train_loss, label='Training Loss')
plt.plot(epochs, val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Use the trained model to generate predictions on the test set
decoded_imgs = model.predict(x_test_bw)

# Display a few test images and their reconstructions
n = 5
plt.figure(figsize=(10, 4))

for i in range(n):
    # Display original image
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test_bw[i].numpy().reshape(32, 32), cmap='gray')
    plt.title('Original')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Display reconstructed image
    ax = plt.subplot(2, n, i + n + 1)
    plt.imshow(decoded_imgs[i].reshape(32, 32, 3))
    plt.title('Reconstructed')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.show()


Epoch 1/50
Epoch 2/50
180/782 [=====>........................] - ETA: 8:20 - loss: 18425.9766

KeyboardInterrupt: ignored