In [1]:
%tensorflow_version 2.x 

TensorFlow 2.x selected.


In [2]:
from tensorflow import keras
from tqdm import tqdm_notebook as tqdm
from IPython.display import display
import matplotlib.pyplot as plt

class ProgressCallback(keras.callbacks.Callback):
  def __init__(self, m, epochs, batchSize, valSplit, leavePlots=True, plot=False):
    self.plot = plot
    self.leavePlots = leavePlots
    self.epochs = epochs
    self.trainSize = int(m * (1 - valSplit))
    self.valSize = int(m * valSplit)
    self.batchSize = batchSize
    self.completed = 0
    self.trainErrors, self.valErrors, self.trainAcc, self.valAcc = [],[],[],[]
    if self.plot:
      self.fig, (self.axLoss, self.axAcc) = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))
      self.ide = display(self.axLoss.figure, display_id=True)

  def on_train_begin(self, logs=None):
    self.initialProgbar = tqdm(total=self.epochs, desc = "Epochs Completed... ", position = 1)

  def on_train_batch_end(self, batch, logs=None):
    self.progbar.update(1)

  def on_epoch_begin(self, epoch, logs=None):
    self.progbar = tqdm(total = self.trainSize/self.batchSize, position = 0, leave = self.leavePlots)
    self.progbar.set_description("Epoch {}, Training... ".format(epoch))

  def on_epoch_end(self, epoch, logs=None):
    self.initialProgbar.update(1)
    self.progbar.close()
    print("Loss:", logs['loss'], ", Accuracy:", logs['accuracy'], ", Validation Loss:", logs['val_loss'], ", Validation Accuracy", logs['val_accuracy'])
    self.trainErrors.append(logs['loss'])
    self.valErrors.append(logs['val_loss'])
    self.trainAcc.append(logs['accuracy'])
    self.valAcc.append(logs['val_accuracy'])

    if self.plot:
      self.axLoss.cla()
      self.axLoss.plot(list(range(len(self.trainErrors))), self.trainErrors, label="Train Loss", color='blue')
      self.axLoss.plot(list(range(len(self.valErrors))), self.valErrors, label="Val Loss", color='red')
      self.axLoss.legend()

      self.axAcc.cla()
      self.axAcc.plot(list(range(len(self.trainAcc))), self.trainAcc, label="Train Accuracy", color='blue')
      self.axAcc.plot(list(range(len(self.valAcc))), self.valAcc, label="Val Accuracy", color='red')
      self.axAcc.legend()

      self.axAcc.set_ylim(ymin=0, ymax=1)
      self.axAcc.set_xlim(xmin=0, xmax=self.epochs)
      self.axLoss.set_xlim(xmin=0, xmax=self.epochs)
      self.axLoss.set_ylim(ymin=0)

      self.ide.update(self.axLoss.figure)


DIM = 28
EPOCHS = 10

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

x_train = x_train.reshape(x_train.shape[0], DIM, DIM, 1).astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], DIM, DIM, 1).astype('float32') / 255

# Very basic model to speed up training
def modelKeras(input_shape):
  inp = keras.layers.Input(input_shape)
  image = keras.layers.Flatten()(inp)
  image = keras.layers.Dense(10, activation = 'softmax')(image)
  model = keras.Model(inputs = inp, outputs = image)
  return model

kerasModel = modelKeras((28, 28, 1))
kerasModel.summary()

cb = ProgressCallback(x_train.shape[0], epochs=EPOCHS, valSplit = 0.2, batchSize=128)

kerasModel.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
kerasModel.fit(x_train, y_train, epochs = EPOCHS, validation_split = 0.2, batch_size = 128, shuffle=True, callbacks=[cb], verbose=0)

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                7850      
Total params: 7,850
Trainable params: 7,850
Non-trainable params: 0
_________________________________________________________________


HBox(children=(IntProgress(value=0, description='Epochs Completed... ', max=10, style=ProgressStyle(descriptio…

HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.7779719337622325 , Accuracy: 0.7418542 , Validation Loss: 0.5713321083386739 , Validation Accuracy 0.8120833


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.5333842458724976 , Accuracy: 0.8216042 , Validation Loss: 0.5069219506581625 , Validation Accuracy 0.83016664


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.48723925415674846 , Accuracy: 0.8358333 , Validation Loss: 0.4815882132848104 , Validation Accuracy 0.83491665


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.4637924505074819 , Accuracy: 0.8432083 , Validation Loss: 0.47030982200304666 , Validation Accuracy 0.83575


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.4477901275952657 , Accuracy: 0.84845835 , Validation Loss: 0.45376576042175293 , Validation Accuracy 0.84583336


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.43733351929982506 , Accuracy: 0.8506875 , Validation Loss: 0.4419459256331126 , Validation Accuracy 0.84725


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.4278545788923899 , Accuracy: 0.8535 , Validation Loss: 0.4387272061506907 , Validation Accuracy 0.85158336


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.4214695064226786 , Accuracy: 0.85479164 , Validation Loss: 0.4385284388065338 , Validation Accuracy 0.84683335


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.4168753369649251 , Accuracy: 0.8576667 , Validation Loss: 0.4357163206736247 , Validation Accuracy 0.85183334


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.410902664621671 , Accuracy: 0.8584375 , Validation Loss: 0.4336517126560211 , Validation Accuracy 0.84975


<tensorflow.python.keras.callbacks.History at 0x7fab907fe748>

Here is an example of the ProgressCallback in a Jupyter Notebook:
![alt text](https://storage.googleapis.com/codein-prod.appspot.com/gci-2019/core_taskupdate/doc/6589474071379968_1576389634_Screen_Shot_2019-12-15_at_12.59.50_AM.png?Expires=1576476572&GoogleAccessId=codein-prod%40appspot.gserviceaccount.com&Signature=rZArwK18Tw%2FPc6MoB1huQPJ5PB7AaePcG0lLMN5UGX3TszyThGi5Vj5iOWQ6w4khjP808HWOrBO4x3JRNH8AL2cIjj6LR3FIrSFEv9AackRU5jVznROiTkcwLc9UbJiR3wWoYmHEXnhpLRKFjRLr%2B3Porqu%2BIeaIkY%2BPJTVvDPKdcJPu%2Frv0chfWzZrTppJWnbBb3p5p06oa1LoHdi8FoSj9SOaaRce4svWMk3YlUTy%2BmL6N9LgLT6vg7kNTwURqQCdzsEcxPQVBQQKgDGSJkRuN0cfsoISRgrekqgZaTgWgC%2Bqa8KrqGVu%2FeiZtG%2BnQ5DhFs3%2F%2BKmUSEpUEbi4rtA%3D%3D)