In [None]:
from __future__ import annotations

In [None]:
# `IPython.display` is the public import location; importing `display` from
# `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Inject CSS that stretches elements with the `.container` class to full width.
display(HTML('<style>.container { width:100% !important; }</style>'))

In [None]:
import os
import datetime
import numpy as np
import sklearn.metrics
import tensorflow as tf
import sklearn.datasets
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import seaborn as sns
# A single call configures style and rc overrides together. `sns.set` applies a
# style of its own, so a separate `sns.set_style('dark')` issued before it is
# simply overwritten; passing `style='dark'` here states the intent directly.
sns.set(style='dark', rc={'figure.figsize': (10, 10), 'axes.grid': False})

In [None]:
# Fixed seed so that the randomly sampled image grids are identical after
# every "Restart Kernel -> Run All".
SEED = 42
rng = np.random.default_rng(SEED)

In [None]:
class Classifier:
    """Adapter that turns a score-producing model into a label predictor.

    The wrapped ``model`` must expose ``predict(X)`` returning per-class
    scores along the last axis; this wrapper's ``predict`` reduces those
    scores to the index of the highest-scoring class.
    """

    # Tag marking this object as a classifier for tools that inspect it.
    _estimator_type = 'classifier'

    def __init__(self, model):
        """Keep a reference to the wrapped model."""
        self.model = model

    def predict(self, X):
        """Return, for each sample in ``X``, the argmax over the last axis of the model's scores."""
        class_scores = self.model.predict(X)
        return np.argmax(class_scores, axis=-1)

In [None]:
def plot_grid(X, ncols=2, nrows=2, figsize=(10, 10), hide_axes=True, image_shape=(28, 28)):
    """Display a grid of randomly chosen images from ``X``.

    Parameters
    ----------
    X : array
        Images indexed along the first axis; each selected row is reshaped
        to ``image_shape`` before being drawn.
    ncols, nrows : int
        Grid dimensions.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    hide_axes : bool
        When true, the x and y axes of every panel are hidden.
    image_shape : tuple
        Height and width of a single image (defaults to 28x28).

    Samples are drawn with replacement from the module-level ``rng``.
    """
    # squeeze=False keeps `ax` two-dimensional even for a single row or
    # column; without it a 1xN / Nx1 / 1x1 grid cannot be indexed as ax[i, j].
    fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=False)
    idx = rng.choice(X.shape[0], size=(nrows, ncols))
    imgs = X[idx].reshape(nrows, ncols, *image_shape)
    for (i, j), panel in np.ndenumerate(ax):
        panel.imshow(imgs[i, j], cmap='gray')
        if hide_axes:
            panel.get_xaxis().set_visible(False)
            panel.get_yaxis().set_visible(False)

In [None]:
def plot_grid_and_label(X, y, model, ncols=2, nrows=2, figsize=(10, 10), hide_axes=True, flatten_image=True, label_description=None):
    """Display random images from ``X`` titled with true label and model prediction.

    Parameters
    ----------
    X : array
        Images indexed along the first axis; each entry must be drawable by
        ``imshow`` (e.g. a 2-D array).
    y : array
        Labels aligned with ``X`` along the first axis.
    model : object
        Must expose ``predict(inputs)`` returning per-class scores along the
        last axis. It receives the sampled images exactly as stored in ``X``,
        so any scaling the model expects has to be applied by the caller.
    ncols, nrows : int
        Grid dimensions.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    hide_axes : bool
        When true, the x and y axes of every panel are hidden.
    flatten_image : bool
        When true, each image is flattened to a vector before prediction.
    label_description : mapping or sequence, optional
        Lookup applied to labels and predictions before they are printed in
        the panel title; when ``None`` the raw values are shown.

    Samples are drawn with replacement from the module-level ``rng``.
    """
    # squeeze=False keeps `ax` two-dimensional even for a single row or
    # column; without it a 1xN / Nx1 / 1x1 grid cannot be indexed as ax[i, j].
    fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=False)
    idx = rng.choice(X.shape[0], size=(nrows, ncols))
    imgs, labels = X[idx], y[idx]

    # Collapse the (nrows, ncols) grid axes into one batch axis for the model.
    inputs = imgs.reshape(nrows * ncols, *imgs.shape[2:])
    if flatten_image:
        inputs = inputs.reshape(nrows * ncols, -1)

    def map_label(value):
        return value if label_description is None else label_description[value]

    # argmax is unaffected by a softmax, so the scores are used directly.
    preds = np.argmax(model.predict(inputs), axis=-1).reshape(nrows, ncols)
    for (i, j), panel in np.ndenumerate(ax):
        panel.imshow(imgs[i, j], cmap='gray')
        panel.set_title(f'label: {map_label(labels[i, j])}, prediction: {map_label(preds[i, j])}')
        if hide_axes:
            panel.get_xaxis().set_visible(False)
            panel.get_yaxis().set_visible(False)

In [None]:
# Fetch MNIST as plain arrays (as_frame=False) rather than a DataFrame.
X, y = sklearn.datasets.fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
# Cast labels to integers: the sparse categorical cross-entropy loss used
# later expects integer class ids.
y = y.astype(int)
# Fixed random_state keeps the train/test split identical across runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

In [None]:
# Preview a random sample of training images using plot_grid's default grid.
plot_grid(X_train)

In [None]:
def _scale_pixels(pixels):
    """Divide pixel values by 255 and cast the result to float32."""
    return tf.cast(pixels / 255, tf.float32)


# Training pipeline: scaled features zipped with labels, in batches of 32.
train_data = tf.data.Dataset.from_tensor_slices(X_train).map(_scale_pixels)
train_labels = tf.data.Dataset.from_tensor_slices(y_train)
train_dataset = tf.data.Dataset.zip((train_data, train_labels)).batch(32)

# Validation pipeline: same preprocessing applied to the held-out split.
val_data = tf.data.Dataset.from_tensor_slices(X_test).map(_scale_pixels)
val_labels = tf.data.Dataset.from_tensor_slices(y_test)
val_dataset = tf.data.Dataset.zip((val_data, val_labels)).batch(32)

In [None]:
# One input unit per feature column of X.
inputs = tf.keras.layers.Input((X.shape[1], ), name='input')
# Baseline hidden layer in place of the former `...` placeholder, which fed
# `Ellipsis` into the output layer and made this cell raise. Width and
# activation are a starting point to tune, not a tuned choice.
x = tf.keras.layers.Dense(128, activation='relu', name='hidden')(inputs)
# Ten-way softmax output, one unit per class.
x = tf.keras.layers.Dense(10, activation='softmax', name='output')(x)
model = tf.keras.Model(inputs=inputs, outputs=x, name='mnist')

In [None]:
# Print an overview of the model's layers.
model.summary()

In [None]:
# Render the model's layer graph as a diagram.
# NOTE(review): plot_model presumably needs pydot and Graphviz installed - confirm for the target environment.
tf.keras.utils.plot_model(model)

In [None]:
# Each run gets its own timestamped directory under logs/mnist/.
run_stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
logdir = os.path.join('logs/mnist/', run_stamp)

In [None]:
callbacks = [
    # Checkpoint to a dedicated file inside the run directory. Pointing the
    # checkpoint at the bare log directory would mix the saved model with the
    # TensorBoard event files, and Keras 3 rejects a checkpoint path that has
    # no model-file extension.
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(logdir, 'best_model.keras'),
        save_best_only=True,
        monitor='val_loss',
    ),
    # TensorBoard logs go to the run directory itself.
    tf.keras.callbacks.TensorBoard(logdir, update_freq=10)
]

In [None]:
# `metrics` is passed as a list, the documented form; a bare string is
# rejected by Keras 3.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Train for 5 epochs on train_dataset with val_dataset as validation data;
# the callbacks handle checkpointing and TensorBoard logging.
model.fit(train_dataset, validation_data=val_dataset, epochs=5, callbacks=callbacks)

In [None]:
# Evaluate on the held-out split, dividing pixels by 255 exactly as the
# training pipeline does.
X_test_scaled = X_test / 255
model.evaluate(X_test_scaled, y_test)

In [None]:
# View each flat test row as a 28x28 image so it can be drawn with imshow.
images = np.reshape(X_test, (-1, 28, 28))

In [None]:
# Divide by 255 before plotting: the model was trained on pixels scaled this
# way, so raw values would give it inputs unlike anything it was fitted on.
# The gray colormap rescales each image for display, so the pictures look the same.
plot_grid_and_label(images / 255, y_test, model)

In [None]:
# `sklearn.metrics.plot_confusion_matrix` was deprecated in scikit-learn 1.0
# and removed in 1.2, so the matrix is computed explicitly and drawn with
# ConfusionMatrixDisplay.
y_pred = Classifier(model).predict(X_test / 255)
# Use every class id seen in either the truth or the predictions so that the
# matrix rows/columns and the tick labels always line up.
class_ids = np.unique(np.concatenate([y_test, y_pred]))
conf_matrix = sklearn.metrics.confusion_matrix(y_test, y_pred, labels=class_ids)
sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=class_ids).plot()