In [None]:
import numpy as np
import pandas as pd
from keras.src.callbacks import ModelCheckpoint
from keras.src.datasets import mnist
import plotly.express as px
from keras.src.layers import Dense, Flatten
from keras.src.models.cloning import Sequential
from keras.src.utils import to_categorical
import plotly.graph_objects as go
from sklearn.metrics import confusion_matrix, accuracy_score

In [None]:
(data, labels), (test_data, test_labels) = mnist.load_data()

# Scale pixel intensities from uint8 [0, 255] to float32 [0, 1]. Raw 0-255
# inputs produce large initial activations/gradients, which can slow down or
# destabilise training. The test set must get exactly the same transform.
data = data.astype("float32") / 255.0
test_data = test_data.astype("float32") / 255.0

print(f"Data shape: {data.shape}")
print(f"Test data shape: {test_data.shape}")

In [None]:
# Eyeball one training image next to its label as a sanity check on the data.
idx = 5553
sample_image, sample_label = data[idx], labels[idx]

print(f"Number: {sample_label}")
px.imshow(sample_image, color_continuous_scale='gray_r')

In [None]:
# Classifier: flatten each 28x28 image, pass it through two ReLU hidden layers,
# then a 10-way softmax (one output unit per digit).
# The hidden layers need a non-linear activation. Without one, the stacked
# Dense layers compose into a single linear map, and the network is no more
# expressive than plain softmax regression.
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(units=32, activation="relu"),
    Dense(units=16, activation="relu"),
    Dense(units=10, activation="softmax"),
])

model.summary()

In [None]:
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Write the model to disk each time validation accuracy reaches a new high.
mc = ModelCheckpoint(
    "best_model_mnist.keras",
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
    verbose=1,
)

# categorical_crossentropy expects one-hot targets, so encode the integer labels first.
labels_onehot = to_categorical(labels)
history = model.fit(
    data,
    labels_onehot,
    epochs=10,
    batch_size=32,
    validation_split=0.05,  # hold out 5% of the training samples for validation
    callbacks=[mc],
)

# Also keep the final-epoch model, separately from the best checkpoint.
model.save("model_after_training.keras")

In [None]:
def draw_history(history, metrics=("accuracy", "loss")):
    """Plot per-epoch training curves from a Keras ``History`` object.

    Shows one Plotly figure per metric. Each figure overlays the training
    series ``history.history[metric]`` and, when it was recorded, the
    validation series ``history.history["val_" + metric]``.

    Args:
        history: the object returned by ``model.fit``. Only its ``epoch`` list
            and ``history`` dict are read.
        metrics: metric names to plot, one figure each, in this order. A metric
            with neither a training nor a validation series is skipped instead
            of raising ``KeyError``.
    """
    for metric in metrics:
        # Training series first, then its validation counterpart. The latter
        # is absent when fit() ran without validation data.
        series_names = [
            name for name in (metric, f"val_{metric}") if name in history.history
        ]
        if not series_names:
            continue

        fig = go.Figure()
        for name in series_names:
            fig.add_trace(go.Scatter(x=history.epoch, y=history.history[name], name=name))
        fig.update_layout(
            title=metric.capitalize(),
            xaxis_title="Epoch",
            yaxis_title=metric,
        )
        fig.show()


# Plot accuracy and loss curves (training vs. validation) for the fit above.
draw_history(history)

In [None]:
# Evaluate the checkpoint with the best validation accuracy rather than the
# last-epoch weights. The ModelCheckpoint callback (save_best_only=True,
# monitor="val_accuracy") writes it to this path during training.
model.load_weights("best_model_mnist.keras")

# predict() returns one softmax row per test image; argmax picks the digit.
pred_probs = model.predict(test_data)
preds = np.argmax(pred_probs, axis=1)

df = pd.DataFrame({
    "true": test_labels,
    "preds": preds,
})

df.head()

In [None]:
accuracy = accuracy_score(df["true"], df["preds"])
# `.2%` scales by 100 and rounds, avoiding float noise such as 97.83000000000001%.
print(f"Accuracy: {accuracy:.2%}")

# Fix the label set to the model's 10 output classes so that matrix index i is
# always digit i, even if some digit never appears in the true or predicted labels.
# Rows are true digits and columns are predicted digits (scikit-learn convention).
digits = list(range(10))
cm = confusion_matrix(df["true"], df["preds"], labels=digits)
px.imshow(
    cm,
    text_auto=True,
    labels=dict(x="Predicted digit", y="True digit", color="Count"),
    title="Confusion matrix (test set)",
)