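### Train a Simple Neural Network on MNIST

This notebook demonstrates how to train a basic fully connected neural network (a multilayer perceptron, MLP) on the MNIST handwritten-digit dataset using TensorFlow/Keras. It loads and preprocesses the data, builds and trains the model, evaluates it on the test set, and inspects its errors with a confusion matrix and a grid of sample predictions.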


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

In [None]:
# Load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
print("Train:", X_train.shape, y_train.shape, "| Test:", X_test.shape, y_test.shape)

# Train and test images must share a shape, otherwise the flattened features won't line up
assert X_train.shape[1:] == X_test.shape[1:], "train/test image shapes differ"

# Flatten each 28x28 image into one feature vector (784 values). The length is derived
# from the data rather than hard-coded. Pixel values are scaled by 1/255 into [0, 1].
X_train = X_train.reshape(len(X_train), -1).astype("float32") / 255.0
X_test = X_test.reshape(len(X_test), -1).astype("float32") / 255.0

# Use one-hot encoded labels for training
num_classes = 10
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

# Sanity-check the preprocessing before any training happens
assert X_train.min() >= 0.0 and X_train.max() <= 1.0, "pixels not scaled into [0, 1]"
assert y_train.shape == (len(X_train), num_classes)
assert y_test.shape == (len(X_test), num_classes)

In [None]:
# Build the neural network model: a small fully connected network (MLP). It maps each
# flattened image vector to one softmax probability per digit class.
# NOTE(review): the assignment's exact instructions are not visible here. The layer
# sizes and hyperparameters below are reasonable defaults; adjust them to match.
HIDDEN_UNITS = 128
DROPOUT_RATE = 0.2
EPOCHS = 10
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.1
SEED = 42

tf.random.set_seed(SEED)  # repeatable weight initialization across Restart & Run All

model = models.Sequential([
    layers.Input(shape=(X_train.shape[1],)),
    layers.Dense(HIDDEN_UNITS, activation="relu"),
    layers.Dropout(DROPOUT_RATE),
    layers.Dense(num_classes, activation="softmax"),
])
model.summary()

# Compile and train the model.
# The labels are one-hot encoded, so categorical (not sparse) cross-entropy is the
# matching loss. Part of the training data is held out to monitor overfitting.
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
history = model.fit(
    X_train,
    y_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    verbose=2,
)


In [None]:
# Evaluate the accuracy of the model on the held-out test set.
# return_dict=True yields {metric_name: value}, e.g. "loss" plus "accuracy" when the
# model was compiled with metrics=["accuracy"], without assuming a fixed metric order.
test_results = model.evaluate(X_test, y_test, verbose=0, return_dict=True)
for metric_name, metric_value in test_results.items():
    print(f"Test {metric_name}: {metric_value:.4f}")


In [None]:
# Plot Loss function vs. Epochs.
# After fit(), Keras keeps the History callback of the most recent training run on
# model.history. Its .history dict maps each logged metric to one value per epoch.
loss_log = model.history.history
epoch_numbers = range(1, len(loss_log["loss"]) + 1)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epoch_numbers, loss_log["loss"], marker="o", label="Training loss")
if "val_loss" in loss_log:  # only present when fit() was given validation data
    ax.plot(epoch_numbers, loss_log["val_loss"], marker="o", label="Validation loss")
ax.set(title="Loss vs. Epochs", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()

### Confusion Matrix

In [None]:
# Convert one-hot encoded labels back to integers
y_test_labels = np.argmax(y_test, axis=1)
class_ids = np.arange(y_test.shape[1])  # one id per one-hot column

# Predict class probabilities on test set and keep the most likely class
y_pred_probs = model.predict(X_test, verbose=0)
y_pred_labels = np.argmax(y_pred_probs, axis=1)

# Compute confusion matrix: cm[i, j] counts test images of true class i predicted as j.
# Passing `labels` keeps the matrix at len(class_ids) x len(class_ids) even if some class
# never occurs in either array, so it always lines up with the tick labels below.
cm = confusion_matrix(y_test_labels, y_pred_labels, labels=class_ids)

fig, ax = plt.subplots(figsize=(8, 6))
image = ax.imshow(cm, interpolation='nearest', cmap='Blues')
fig.colorbar(image, ax=ax)
ax.set(
    title="Confusion Matrix",
    xticks=class_ids,
    yticks=class_ids,
    xlabel="Predicted Label",
    ylabel="True Label",
)

# Annotate cells with their counts; white text on the darker (high-count) cells for contrast
thresh = cm.max() / 2
for i, j in np.ndindex(cm.shape):
    ax.text(
        j, i, format(cm[i, j], "d"),
        ha="center", va="center",
        color="white" if cm[i, j] > thresh else "black",
    )

fig.tight_layout()
plt.show()

### Visualize Some Predictions

In [None]:
# Visualize predictions: the first 18 test images in a 3 x 6 grid, each panel
# titled with its true digit and the model's predicted digit.
fig, axes = plt.subplots(3, 6, figsize=(10, 7))
for idx, panel in enumerate(axes.ravel()):
    digit_image = X_test[idx].reshape(28, 28)
    panel.imshow(digit_image, cmap='gray')
    true_digit, pred_digit = y_test_labels[idx], y_pred_labels[idx]
    panel.set_title(f"True: {true_digit}\nPred: {pred_digit}")
    panel.axis('off')

fig.suptitle("TensorFlow MLP Predictions on MNIST Digits")
plt.show()