# Lab 6: CNN on MNIST

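A simplified, cleaned-up version of the code from your lab manual: a small convolutional neural network (CNN) trained to classify MNIST handwritten digits.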

In [None]:
# Lab 6: CNN for MNIST classification
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt

# Constants shared by preprocessing, the model, training and plotting.
NUM_CLASSES = 10
IMG_SHAPE = (28, 28, 1)  # 28x28 images with a single channel axis appended
EPOCHS = 10
BATCH_SIZE = 128
NUM_SHOWN = 5  # number of test images displayed with their predictions

# Load and preprocess MNIST: scale pixels to [0, 1], add the channel axis that
# Conv2D expects, and one-hot encode labels for categorical_crossentropy.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.astype('float32') / 255.0).reshape((-1, *IMG_SHAPE))
x_test = (x_test.astype('float32') / 255.0).reshape((-1, *IMG_SHAPE))
y_train = to_categorical(y_train, NUM_CLASSES)
y_test = to_categorical(y_test, NUM_CLASSES)

# Two conv -> max-pool -> dropout stages, then a dense classifier head.
# An explicit Input layer is used instead of passing `input_shape=` to the
# first Conv2D, which newer Keras versions deprecate.
model = Sequential([
    Input(shape=IMG_SHAPE),
    Conv2D(32, kernel_size=(3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Conv2D(64, kernel_size=(3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(NUM_CLASSES, activation='softmax'),
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Hold out the last 10% of the training set for validation curves.
history = model.fit(
    x_train, y_train,
    validation_split=0.1,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=2,
)

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f'Test accuracy: {test_acc:.4f}')

# Plot training vs. validation curves: accuracy on the left, loss on the right.
plt.figure(figsize=(12, 4))
panels = [('accuracy', 'Accuracy'), ('loss', 'Loss')]
for position, (metric, y_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(history.history[metric])
    plt.plot(history.history[f'val_{metric}'])
    plt.title(f'Model {metric}')
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend(['Train', 'Val'])
plt.tight_layout()
plt.show()

# Show a few test images with predicted and true labels side by side, so
# misclassifications are visible at a glance.
preds = model.predict(x_test[:NUM_SHOWN])
pred_labels = preds.argmax(axis=1)               # vectorized: one label per row
true_labels = y_test[:NUM_SHOWN].argmax(axis=1)  # undo the one-hot encoding
plt.figure(figsize=(8, 3))
samples = zip(x_test[:NUM_SHOWN], pred_labels, true_labels)
for position, (image, predicted, actual) in enumerate(samples, start=1):
    plt.subplot(1, NUM_SHOWN, position)
    plt.imshow(image.reshape(IMG_SHAPE[:2]), cmap='gray')  # drop channel axis for imshow
    plt.axis('off')
    plt.title(f'Pred:{predicted}\nTrue:{actual}')
plt.tight_layout()
plt.show()

In [None]:
# End of Lab 6