In [1]:
%load_ext tensorboard

In [35]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import io, itertools
from datetime import datetime
from tensorflow import keras,summary
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras import layers,models,callbacks
from sklearn.metrics import confusion_matrix

In [7]:
(train_img,train_labels),(test_img,test_labels) = fashion_mnist.load_data()

In [9]:
train_img,test_img = 255.0 / train_img, 255.0 / test_img

  train_img,test_img = 255.0 / train_img, 255.0 / test_img


In [21]:
class_name = ['T-shirt/top','Trousers','Pullover','Dress','Coat','Sandels','Shirt','Sneakers','Bag','Ankle boot']

In [23]:
print('shape :',train_img[0].shape)
print('label :',train_labels[0], '->' , class_name[train_labels[0]])

shape : (28, 28)
label : 9 -> Ankle boot


In [13]:
model = models.Sequential([
    layers.Flatten(input_shape = (28,28)),
    layers.Dense(32, activation = 'relu'),
    layers.Dense(10, activation = 'softmax')
])

In [15]:
model.compile(
    optimizer = 'adam',
    loss = 'sparse_categorical_crossentropy',
    metrics = ['accuracy']
)

In [103]:
def plot_to_img(figure):
    buf = io.BytesIO()
    plt.savefig(buf, format = 'png')
    plt.close(figure)
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels = 4)
    image = tf.expand_dims(image, 0)
    return image

In [111]:
def plot_confusion_matrix(cm, class_name):
    figure = plt.figure(figsize = (8,8))
    plt.imshow(cm, interpolation = 'nearest', cmap = plt.cm.Blues)

    plt.title('confusion matrix')
    plt.colorbar()

    tick_marks = np.arange(len(class_name))
    plt.xticks(tick_marks, class_name, rotation = 45)
    plt.yticks(tick_marks, class_name)

    labels = np.around(cm.astype('float') / cm.sum(axis = 1)[:, np.newaxis],decimals = 2)

    threshold = cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]),range(cm.shape[1])):
        color = 'white' if cm[i,j] > threshold else 'black'
        plt.text (j,i , labels[i,j],horizontalalignment = 'center', color = color)

    plt.tight_layout()
    plt.ylabel('True labels')
    plt.xlabel('Predicted labels')
    return(figure)

In [33]:
import shutil
shutil.rmtree('./logs/', ignore_errors=True)

In [43]:
logdir = 'logs/image/' + datetime.now().strftime('%Y%m%d - %H%M%S')
tensorboard_callback = callbacks.TensorBoard(log_dir = logdir)
file_writer_cm = summary.create_file_writer(logdir +'/cm')

In [71]:
def log_confusion_matrix(epochs,logs):
    test_pred_raw = model.predict(test_img)
    test_pred = np.argmax(test_pred_raw,axis = 1)

    cm = confusion_matrix(test_pred, test_labels)
    figure = plot_confusion_matrix(cm, class_name = class_name)
    cm_image = plot_to_img(figure)

    with file_writer_cm.as_default():
        summary.image('epoch_confusion_matric',cm_image,step = epochs)
cm_callback = callbacks.LambdaCallback(on_epoch_end = log_confusion_matrix)

In [113]:
%tensorboard --logdir logs/image --port=6007
model.fit(train_img,
         train_labels,
         epochs = 5,
         verbose = 0,
         callbacks = [tensorboard_callback,cm_callback],
         validation_data = (test_img,test_labels))

Reusing TensorBoard on port 6007 (pid 10000), started 0:17:37 ago. (Use '!kill 10000' to kill it.)

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step


  labels = np.around(cm.astype('float') / cm.sum(axis = 1)[:, np.newaxis],decimals = 2)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step


  labels = np.around(cm.astype('float') / cm.sum(axis = 1)[:, np.newaxis],decimals = 2)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step


  labels = np.around(cm.astype('float') / cm.sum(axis = 1)[:, np.newaxis],decimals = 2)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step


  labels = np.around(cm.astype('float') / cm.sum(axis = 1)[:, np.newaxis],decimals = 2)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step


  labels = np.around(cm.astype('float') / cm.sum(axis = 1)[:, np.newaxis],decimals = 2)


<keras.src.callbacks.history.History at 0x1e2a0e713a0>