In [None]:
# Import EarlyStopping from the public `tensorflow.keras` API like every other
# Keras name below. `tensorflow.python.keras` is TensorFlow's private/legacy
# copy of Keras; its callbacks may be missing or incompatible with `tf.keras`
# models depending on the installed TensorFlow version.
from tensorflow.keras.callbacks import EarlyStopping
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

# Seed Python, NumPy and TensorFlow RNGs in one call so weight initialisation,
# dropout masks and epoch shuffling are repeatable on Restart & Run All.
# NOTE: bit-exact GPU runs additionally need
# tf.config.experimental.enable_op_determinism().
SEED = 42
tf.keras.utils.set_random_seed(SEED)

NUM_CLASSES = 10  # Fashion-MNIST label count; must match the model's output layer

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Data preprocessing: add a trailing single-channel axis (28x28x1, as Conv2D
# expects) and scale pixel values from 0-255 into [0, 1] as float32.
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0

# One-hot encode the integer labels for the categorical_crossentropy loss.
y_train = to_categorical(y_train, NUM_CLASSES)
y_test = to_categorical(y_test, NUM_CLASSES)

# Cheap invariants: fail fast if the arrays do not line up after preprocessing.
assert x_train.shape[1:] == (28, 28, 1) and x_test.shape[1:] == (28, 28, 1)
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)
assert y_train.shape[1] == NUM_CLASSES and y_test.shape[1] == NUM_CLASSES

# Create the model: three Conv2D(3x3, ReLU) -> MaxPooling2D(2x2) -> Dropout(0.25)
# stages with 32, 64 and 128 filters, followed by a Flatten -> Dense(128, ReLU)
# -> Dropout(0.5) -> Dense(10, softmax) classification head.
cnn_layers = []
for n_filters in (32, 64, 128):
    if cnn_layers:
        conv = Conv2D(n_filters, (3, 3), activation='relu')
    else:
        # Only the very first layer declares the input: 28x28 single-channel images.
        conv = Conv2D(n_filters, (3, 3), activation='relu', input_shape=(28, 28, 1))
    cnn_layers.extend([conv, MaxPooling2D((2, 2)), Dropout(0.25)])
cnn_layers.extend([
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax'),
])
model = Sequential(cnn_layers)

# Training hyperparameters.
MAX_EPOCHS = 50          # upper bound; EarlyStopping may end training earlier
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.2   # fraction of the training arrays held out for validation
PATIENCE = 5             # epochs without val_loss improvement before stopping

model.compile(
    optimizer=Adam(),
    loss='categorical_crossentropy',  # labels are one-hot encoded
    metrics=['accuracy'],
)

# Stop when validation loss plateaus and ask Keras to roll back to the weights
# from the best val_loss epoch.
# NOTE(review): in some Keras versions the rollback only happens when training
# actually stops early, not when all MAX_EPOCHS run — confirm for the
# installed version.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=PATIENCE,
    restore_best_weights=True,
)
history = model.fit(
    x_train,
    y_train,
    epochs=MAX_EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=[early_stopping],
)

# Final score on the test split, which is used neither for training nor for
# early stopping. evaluate() returns the loss first, then the compiled metrics
# (here only accuracy).
test_results = model.evaluate(x=x_test, y=y_test)
test_loss, test_acc = test_results
print('Test accuracy:', test_acc)


In [None]:
import pickle

# Persist the per-epoch metric curves recorded by model.fit (e.g. 'accuracy',
# 'val_accuracy') so they can be re-plotted later without retraining.
HISTORY_PATH = 'history_cnn.pkl'
with open(HISTORY_PATH, mode='wb') as history_file:
    pickle.dump(history.history, history_file)

In [None]:
# Persist the trained convolutional network in the native Keras format; the
# format is selected by the `.keras` file extension.
model.save(filepath='model_cnn.keras')

In [None]:
import matplotlib.pyplot as plt

# Learning curves: training vs. validation accuracy for each completed epoch.
# A growing gap between the two lines is a sign of overfitting.
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
# Number epochs from 1 so the x-axis matches Keras' "Epoch 1/N" training logs
# (plotting the raw lists would start the axis at 0).
epoch_numbers = range(1, len(train_acc) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_numbers, train_acc, label='accuracy')
ax.plot(epoch_numbers, val_acc, label='val_accuracy')
ax.set(
    title='CNN training vs. validation accuracy',
    xlabel='Epoch',
    ylabel='Accuracy',
    ylim=(0, 1),
)
ax.legend(loc='lower right')
plt.show()