# In [None]:
# Step a: Import necessary packages
from tensorflow import keras
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Flatten, Dense
from keras.optimizers import SGD
import matplotlib.pyplot as plt

# Step b: Load the training and testing data (CIFAR10)
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(f"Shape of X_train: {x_train.shape}")
print(f"Shape of y_train: {y_train.shape}")
print(f"Shape of X_test: {x_test.shape}")
print(f"Shape of y_test: {y_test.shape}")

# Display an image from the dataset together with its label.
# A single index is used for both so the printed label always describes
# the image that was just shown.
sample_index = 60
plt.imshow(x_train[sample_index])
plt.show()
print(f"Label: {y_train[sample_index]}")

# Display multiple images in a 3x3 grid
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(x_train[i])
plt.show()

# Scale pixel intensities from the raw [0, 255] range to [0, 1].
# Feeding unscaled 8-bit pixel values into a dense network trained with SGD
# tends to produce very large activations/gradients and unstable training.
# This is done after the previews; imshow accepts either representation.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Step c: Define the network architecture using Keras
# A plain fully-connected classifier: the 32x32x3 image is flattened to a
# 3072-long vector, passed through three ReLU hidden layers of decreasing
# width, and mapped to 10 class probabilities by a softmax layer.
model = Sequential()
model.add(Flatten(input_shape=(32, 32, 3)))
for hidden_units in (512, 256, 128):
    model.add(Dense(hidden_units, activation='relu'))
model.add(Dense(10, activation='softmax'))

# Step d: Compile the model using SGD optimizer
# The sparse loss variant is used because it takes integer class ids
# directly, so the labels are not one-hot encoded first.
sgd_optimizer = SGD(learning_rate=0.01, momentum=0.9)
model.compile(
    optimizer=sgd_optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train the model
# NOTE: the test split doubles as the validation set here, so the per-epoch
# validation metrics are computed on (x_test, y_test).
history = model.fit(
    x_train,
    y_train,
    epochs=10,
    batch_size=32,
    validation_data=(x_test, y_test),
)

# Step e: Evaluate the network
# evaluate() returns the loss followed by each compiled metric; only
# 'accuracy' was compiled in, so the result unpacks into exactly two values.
evaluation = model.evaluate(x_test, y_test)
test_loss, test_accuracy = evaluation
print(f"Test Loss: {test_loss}")
print(f"Test Accuracy: {test_accuracy}")

# Step f: Plot the training loss and accuracy
# Both panels share the same layout, so they are described as data and drawn
# by one loop. Each entry is:
# (training key, validation key, training label, validation label, title, y label)
curves = history.history
panels = [
    ('accuracy', 'val_accuracy',
     'Training Accuracy', 'Validation Accuracy',
     'Model Accuracy', 'Accuracy'),
    ('loss', 'val_loss',
     'Training Loss', 'Validation Loss',
     'Model Loss', 'Loss'),
]

plt.figure(figsize=(12, 4))

# Left panel: accuracy per epoch; right panel: loss per epoch.
for position, panel in enumerate(panels, start=1):
    train_key, val_key, train_label, val_label, title, y_label = panel
    plt.subplot(1, 2, position)
    plt.plot(curves[train_key], label=train_label)
    plt.plot(curves[val_key], label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.show()
