In [None]:
# default_exp core

# Fancy Callbacks

> Fancy callbacks for Keras. This was created mainly to explore the usage of nbdev.

In [None]:
#hide
from nbdev.showdoc import *

## Plot metrics after training

> Plotting metrics after training has finished is a common step, so we built a `Callback` that plots them for you automatically as soon as training completes.

In [None]:
#export
#hide
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
#export
class PlotMetrics(tf.keras.callbacks.History):
    """Keras `History` callback that plots every recorded metric when training ends.

    One subplot is drawn per training metric, meaning every history key that does
    not start with ``val_``. If a matching ``val_<name>`` entry was also recorded
    (i.e. validation data was passed to ``fit``), it is drawn on the same axes.

    Args:
        figsize: Size of the whole figure, forwarded to ``plt.subplots``.
    """

    def __init__(self, figsize=(9, 4)):
        super().__init__()
        self.figsize = figsize

    def _get_unique_metrics(self):
        """Return the recorded history keys that are not validation metrics."""
        return [name for name in self.history.keys() if not name.startswith('val_')]

    def on_train_end(self, logs=None):
        """Plot the train curve, and the validation curve when available, for each metric.

        Args:
            logs: Unused; present to match the Keras callback signature.
        """
        unique_names = self._get_unique_metrics()
        if not unique_names:
            # Nothing was recorded; plt.subplots(1, 0) would raise.
            return
        # squeeze=False always returns a 2-D array of Axes. Without it, a single
        # metric (e.g. only 'loss') yields a bare Axes object that cannot be indexed.
        fig, axes = plt.subplots(1, len(unique_names), figsize=self.figsize, squeeze=False)
        for ax, name in zip(axes[0], unique_names):
            ax.plot(self.history[name], label='Train')
            val_name = f'val_{name}'
            # Validation keys only exist when fit() received validation data, and
            # some logged values (e.g. a learning rate) never have a val_ twin.
            if val_name in self.history:
                ax.plot(self.history[val_name], label='Validation')
            ax.set_title(name)
            ax.set_xlabel('Epoch')
            ax.legend()
        fig.tight_layout()
        plt.show()