# CNN visualisation with TensorBoard on Fashion-MNIST
Written by [Mattia Chiari](mailto:m.chiari017@unibs.it)


## Sources 
* https://www.kaggle.com/code/rutvikdeshpande/fashion-mnist-cnn-beginner-98/notebook
* https://en.wikipedia.org/wiki/LeNet
* https://www.kaggle.com/code/gpreda/cnn-with-tensorflow-keras-for-fashion-mnist/notebook

## Imports

In [None]:
import os
import datetime
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
import numpy as np
import matplotlib.pyplot as plt
import io
from sklearn import metrics

In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic used below is available.
%load_ext tensorboard


## Dataset: Fashion-MNIST

More information about the dataset is available in the [Fashion-MNIST repository](https://github.com/zalandoresearch/fashion-mnist).

### Constants

In [None]:
# Fashion-MNIST label names, keyed by integer class id (0-9).
# The extra key -1 maps to an empty string; it is used as a "no caption"
# label when plotting images that have no class (e.g. feature maps, kernels).
class_names = {-1: ''}
class_names.update(enumerate([
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]))

### Custom functions

In [None]:
def image_grid(x: list, y: list, figures: int = 36, cols: int = 6):
    """
    Plot a grid of images, each captioned with its class name.

    Args:
        x (list): list (or array) of 2-D images.
        y (list): list of integer labels, used as keys into the module-level
            ``class_names`` mapping (``-1`` gives an empty caption).
        figures (int, optional): number of images to plot; clamped to
            ``len(x)`` so it never indexes past the inputs. Defaults to 36.
        cols (int, optional): number of columns in the grid. Defaults to 6.

    Raises:
        ValueError: if x and y have different lengths, or cols is not positive.

    Returns:
        matplotlib.figure.Figure: a figure with a grid of images
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    if cols <= 0:
        raise ValueError("cols must be a positive integer")

    # Never try to draw more images than were supplied.
    figures = min(figures, len(x))
    # matplotlib requires integer grid dimensions: integer ceiling division.
    rows = -(-figures // cols)

    figure = plt.figure(figsize=(12, 12))
    for i in range(figures):
        # Draw on this figure explicitly rather than pyplot's "current" one.
        ax = figure.add_subplot(rows, cols, i + 1)
        # int() so numpy integer labels (e.g. uint8) are plain dict keys.
        ax.set_xlabel(class_names[int(y[i])])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)
        ax.imshow(x[i], cmap=plt.cm.coolwarm)

    return figure

### Code

In [None]:
# Fetch Fashion-MNIST through Keras and unpack the (images, labels) pairs
# of the training and test splits.
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

In [None]:
# Scale pixel values into [0, 1] for training (assumes 8-bit 0-255 input);
# the unscaled arrays are kept for plotting.
x_train_scaled = x_train / 255.0
x_test_scaled = x_test / 255.0

In [None]:
# Preview the first 36 training images (image_grid's default) with their class names.
figure_train = image_grid(x=x_train, y=y_train)

## Run TensorBoard

### Custom functions

In [None]:
def create_model():
    """
    Create a small functional CNN for 28x28 single-channel images.

    The network stacks three Conv2D -> Dropout -> MaxPool2D stages whose
    convolutions are named ``conv1``, ``conv2`` and ``conv3`` (``train_model``
    looks these names up to build its visualisation submodels), followed by a
    64-unit dense layer and a 10-way softmax classifier.

    Returns:
        tf.keras.Model: the (uncompiled) model.
    """
    input_layer = tf.keras.layers.Input(shape=(28, 28, 1), name='input')

    # Stage 1: 28x28x1 -> conv 26x26x16 -> pool 13x13x16
    conv1_layer = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', name='conv1')(input_layer)
    dropout1_layer = tf.keras.layers.Dropout(0.2)(conv1_layer)
    maxpool1_layer = tf.keras.layers.MaxPool2D((2, 2))(dropout1_layer)

    # Stage 2 (previously missing, leaving maxpool2_layer undefined):
    # 13x13x16 -> conv 11x11x32 -> pool 5x5x32
    conv2_layer = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', name='conv2')(maxpool1_layer)
    dropout2_layer = tf.keras.layers.Dropout(0.2)(conv2_layer)
    maxpool2_layer = tf.keras.layers.MaxPool2D((2, 2))(dropout2_layer)

    # Stage 3: 5x5x32 -> conv 3x3x64 -> pool 1x1x64
    conv3_layer = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', name='conv3')(maxpool2_layer)
    dropout3_layer = tf.keras.layers.Dropout(0.2)(conv3_layer)
    maxpool3_layer = tf.keras.layers.MaxPool2D((2, 2))(dropout3_layer)

    # Classifier head
    flatten_layer = tf.keras.layers.Flatten()(maxpool3_layer)
    dense_layer = tf.keras.layers.Dense(64, activation='relu')(flatten_layer)
    output_layer = tf.keras.layers.Dense(10, activation='softmax')(dense_layer)

    model = tf.keras.models.Model(inputs=[input_layer], outputs=[output_layer])
    return model

In [None]:
def plot_to_image(figure: 'matplotlib.figure.Figure'):
    """
    Render a matplotlib figure to PNG and decode it as an image tensor.

    The figure is always closed afterwards, so repeated calls (e.g. once per
    epoch) do not accumulate open figures.

    Args:
        figure (matplotlib.figure.Figure): the figure to render.

    Returns:
        tf.Tensor: uint8 RGBA image with a leading batch axis,
        shape (1, height, width, 4), as accepted by ``tf.summary.image``.
    """
    try:
        with io.BytesIO() as buf:
            # Save *this* figure, not whatever pyplot considers current.
            figure.savefig(buf, format='png')
            png_bytes = buf.getvalue()
    finally:
        plt.close(figure)

    image = tf.image.decode_png(png_bytes, channels=4)
    # tf.summary.image expects a batch of images: add the batch axis.
    return tf.expand_dims(image, 0)

In [None]:
def create_log_images(submodels: list, input_img: np.ndarray, file_writer_img: tf.summary.SummaryWriter):
    """
    Create an ``on_epoch_end`` callback that logs intermediate feature maps.

    At the end of each epoch, ``input_img`` is fed through every submodel, and
    every channel of the resulting activation is drawn in a grid and written to
    TensorBoard under the tag "Conv layer {n} output". The input image itself
    never changes, so it is written only once, under "Input" at step 0.

    Args:
        submodels (list): models mapping an image batch to a feature map;
            channels are read from axis 3, so outputs are assumed to be
            (batch, height, width, channels).
        input_img (np.ndarray): a single image with 28*28 elements; it is
            reshaped to (1, 28, 28, 1) before prediction.
        file_writer_img (tf.summary.SummaryWriter): writer the images go to.

    Returns:
        Callable[[int, dict], None]: a function suitable for
        ``tf.keras.callbacks.LambdaCallback(on_epoch_end=...)``.
    """
    # The input is constant across epochs: reshape it once.
    batch = input_img.reshape(1, 28, 28, 1)
    input_logged = False

    def log_intermediate_images(epoch: int, logs: dict):
        """
        Log the input image (first call only) and each submodel's feature maps.

        Args:
            epoch (int): current epoch, used as the summary step
            logs (dict): Keras epoch logs (unused)
        """
        nonlocal input_logged
        if not input_logged:
            input_grid = plot_to_image(image_grid([input_img], [-1], 1, 1))
            with file_writer_img.as_default():
                tf.summary.image("Input", input_grid, step=0)
            input_logged = True

        for layer_idx, submodel in enumerate(submodels, start=1):
            # verbose=0 stops predict() from printing a progress bar per call.
            activation = submodel.predict(batch, verbose=0)
            n_channels = activation.shape[3]
            channels = [activation[0, :, :, k] for k in range(n_channels)]
            labels = np.full(n_channels, -1, dtype=int)  # -1 -> empty caption
            grid = image_grid(channels, labels, n_channels, int(np.sqrt(n_channels)))
            grid = plot_to_image(grid)
            with file_writer_img.as_default():
                tf.summary.image(f"Conv layer {layer_idx} output", grid, step=epoch)

    return log_intermediate_images

In [None]:
def create_log_kernels(weights_list: list, file_writer_kernel: tf.summary.SummaryWriter):
    """
    Create an ``on_epoch_end`` callback that logs convolution kernels.

    For each entry of ``weights_list``, the slice for input channel 0 of every
    output filter is drawn in a grid and written to TensorBoard under the tag
    "Kernel layer {n} output". The entries are converted to arrays inside the
    callback, so passing the layers' weight variables (rather than snapshots)
    shows how the kernels change from epoch to epoch.

    Args:
        weights_list (list): 4-D kernels indexed as
            (height, width, in_channels, out_channels) — the layout the
            ``[:, :, 0, k]`` slice below assumes.
        file_writer_kernel (tf.summary.SummaryWriter): writer the images go to.

    Returns:
        Callable[[int, dict], None]: a function suitable for
        ``tf.keras.callbacks.LambdaCallback(on_epoch_end=...)``.
    """

    def log_kernels(epoch: int, logs: dict):
        """
        Log the current kernels of every layer in ``weights_list``.

        Args:
            epoch (int): current epoch, used as the summary step
            logs (dict): Keras epoch logs (unused)
        """
        for layer_idx, weights in enumerate(weights_list, start=1):
            kernel = np.asanyarray(weights)
            n_filters = kernel.shape[3]
            # Only the first input channel is shown for multi-channel layers.
            imgs = [kernel[:, :, 0, k] for k in range(n_filters)]
            labels = np.full(n_filters, -1, dtype=int)  # -1 -> empty caption
            grid = image_grid(imgs, labels, n_filters, int(np.sqrt(n_filters)))
            grid = plot_to_image(grid)
            with file_writer_kernel.as_default():
                tf.summary.image(f"Kernel layer {layer_idx} output", grid, step=epoch)

    return log_kernels

In [None]:
def train_model(x_train: np.ndarray,
                y_train: np.ndarray,
                x_test: np.ndarray,
                y_test: np.ndarray,
                log_folder: str = None,
                epochs: int = 5):
    """
    Build, compile and train the functional CNN from ``create_model``, with TensorBoard logging.

    Besides the standard TensorBoard callback (weight histograms every epoch,
    profiling of batches 250-500), two LambdaCallbacks run at the end of every
    epoch. One logs the feature maps of ``conv1``/``conv2``/``conv3`` for
    ``x_test[0]``; the other logs those layers' kernels.

    Args:
        x_train (np.ndarray): training data
        y_train (np.ndarray): training labels (integer class ids)
        x_test (np.ndarray): test data, also used as validation data
        y_test (np.ndarray): test labels (integer class ids)
        log_folder (str, optional): directory to save logs. Defaults to
            ``logs/<YYYYmmdd-HHMMSS>`` using the current local time.
        epochs (int, optional): number of epochs to train. Defaults to 5.

    Returns:
        tuple: ``(model, submodels)`` — the trained model and the list of
        submodels mapping the model input to each conv layer's output.
    """
    # Create and compile the model
    model = create_model()
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Define log dir
    if log_folder is None:
        logdir = os.path.join('logs', datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    else:
        logdir = log_folder

    # Writer for the custom image summaries (TensorBoard callback uses its own)
    file_writer = tf.summary.create_file_writer(logdir)

    conv_layers = [model.get_layer(name) for name in ('conv1', 'conv2', 'conv3')]
    model_input = conv_layers[0].input

    # Submodels share the model's layers, so their outputs follow training.
    submodels = [tf.keras.models.Model(inputs=[model_input], outputs=[layer.output])
                 for layer in conv_layers]

    # Kernel variables (not copies), read again by the callback every epoch
    weights_list = [layer.weights[0] for layer in conv_layers]

    callbacks = [
        tf.keras.callbacks.TensorBoard(logdir,
                                       histogram_freq=1,
                                       profile_batch='250,500'),
        tf.keras.callbacks.LambdaCallback(
            on_epoch_end=create_log_images(submodels, x_test[0], file_writer)),
        tf.keras.callbacks.LambdaCallback(
            on_epoch_end=create_log_kernels(weights_list, file_writer)),
    ]

    # Train the model; always flush and release the custom summary writer.
    try:
        model.fit(x=x_train,
                  y=y_train,
                  epochs=epochs,
                  validation_data=(x_test, y_test),
                  callbacks=callbacks)
    finally:
        file_writer.close()
    return model, submodels

### Code

In [None]:
# Start TensorBoard inline, watching logs/ (train_model writes to logs/<timestamp> by default).
%tensorboard --logdir=logs

In [None]:
# Train on the scaled images for 5 epochs; logs go to logs/<timestamp>.
model, submodels = train_model(x_train=x_train_scaled,
                               y_train=y_train,
                               x_test=x_test_scaled,
                               y_test=y_test,
                               epochs=5)

In [None]:
# Print the layer-by-layer architecture of the trained model.
model.summary()