In [0]:
%tensorflow_version 2.x 

In [6]:
from tensorflow import keras
import tensorflow as tf
from tqdm import tqdm_notebook as tqdm
from IPython.display import display
import matplotlib.pyplot as plt

class ProgressCallback(keras.callbacks.Callback):
  """Keras callback showing tqdm progress bars and, optionally, live loss/accuracy plots.

  An outer bar counts completed epochs, and a fresh inner bar per epoch counts
  training batches. At the end of every epoch the epoch's metrics are printed and
  recorded; when ``plot`` is True the loss and accuracy curves are redrawn in place.

  Args:
    m: Total number of samples given to ``fit`` (before the validation split).
    epochs: Number of epochs ``fit`` runs; sets the outer bar total and the plot x-limits.
    batchSize: Batch size given to ``fit``.
    valSplit: Fraction given to ``fit`` as ``validation_split``.
    leavePlots: Forwarded as tqdm's ``leave`` to the per-epoch bars, i.e. whether
      finished epoch bars stay on screen.
    plot: If True, create a two-panel (loss, accuracy) figure and update it every epoch.

  ``on_epoch_end`` reads 'loss', 'accuracy', 'val_loss' and 'val_accuracy' from
  ``logs``, so the model must be compiled with ``metrics=['accuracy']`` and fit
  with validation data.
  """

  def __init__(self, m, epochs, batchSize, valSplit, leavePlots=True, plot=False):
    # Let the base Callback initialise the attributes Keras expects on every callback.
    super().__init__()
    self.plot = plot
    self.leavePlots = leavePlots
    self.epochs = epochs
    self.trainSize = int(m * (1 - valSplit))
    self.valSize = int(m * valSplit)
    self.batchSize = batchSize
    # Batches per epoch. The last batch may be partial, so round up with integer
    # ceil division rather than using a float total that under-counts it.
    self.stepsPerEpoch = -(-self.trainSize // self.batchSize)
    self.completed = 0
    self.trainErrors, self.valErrors, self.trainAcc, self.valAcc = [], [], [], []
    if self.plot:
      self.fig, (self.axLoss, self.axAcc) = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))
      # display_id=True returns a handle whose .update() redraws this same output area.
      self.ide = display(self.fig, display_id=True)

  def on_train_begin(self, logs=None):
    # Outer bar: one tick per completed epoch.
    self.initialProgbar = tqdm(total=self.epochs, desc="Epochs Completed... ", position=1)

  def on_train_end(self, logs=None):
    # Finish the epoch-level bar; previously it was never closed.
    self.initialProgbar.close()
    if self.plot:
      # The figure is already shown via the display handle. Closing it presumably stops
      # the inline backend from rendering a duplicate copy at cell end — confirm.
      plt.close(self.fig)

  def on_train_batch_end(self, batch, logs=None):
    self.progbar.update(1)

  def on_epoch_begin(self, epoch, logs=None):
    # Inner bar: one tick per training batch of this epoch.
    self.progbar = tqdm(total=self.stepsPerEpoch, position=0, leave=self.leavePlots)
    self.progbar.set_description("Epoch {}, Training... ".format(epoch))

  def on_epoch_end(self, epoch, logs=None):
    self.initialProgbar.update(1)
    self.progbar.close()
    print("Loss:", logs['loss'], ", Accuracy:", logs['accuracy'], ", Validation Loss:", logs['val_loss'], ", Validation Accuracy", logs['val_accuracy'])
    self.trainErrors.append(logs['loss'])
    self.valErrors.append(logs['val_loss'])
    self.trainAcc.append(logs['accuracy'])
    self.valAcc.append(logs['val_accuracy'])

    if self.plot:
      self.axLoss.cla()
      self.axLoss.plot(list(range(len(self.trainErrors))), self.trainErrors, label="Train Loss", color='blue')
      self.axLoss.plot(list(range(len(self.valErrors))), self.valErrors, label="Val Loss", color='red')
      self.axLoss.legend()

      self.axAcc.cla()
      self.axAcc.plot(list(range(len(self.trainAcc))), self.trainAcc, label="Train Accuracy", color='blue')
      self.axAcc.plot(list(range(len(self.valAcc))), self.valAcc, label="Val Accuracy", color='red')
      self.axAcc.legend()

      # Fixed axes so the curves don't rescale every epoch; accuracy lives in [0, 1].
      self.axAcc.set_ylim(ymin=0, ymax=1)
      self.axAcc.set_xlim(xmin=0, xmax=self.epochs)
      self.axLoss.set_xlim(xmin=0, xmax=self.epochs)
      self.axLoss.set_ylim(ymin=0)

      # Redraw the existing output in place instead of appending a new figure.
      self.ide.update(self.fig)

DIM = 28
EPOCHS = 10

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Give each image an explicit single channel axis -> (count, DIM, DIM, 1), then
# convert to float32 and scale pixel values by 1/255.
x_train, x_test = (
  pixels.reshape(pixels.shape[0], DIM, DIM, 1).astype('float32') / 255
  for pixels in (x_train, x_test)
)

# Very basic model to speed up training
def modelKeras(input_shape, numClasses=10):
  """Build a minimal softmax classifier: flatten the input, then a single Dense layer.

  Args:
    input_shape: Shape of one sample, excluding the batch dimension, e.g. (28, 28, 1).
    numClasses: Number of output units (classes). Defaults to 10, the class count
      of the Fashion-MNIST data this notebook trains on.

  Returns:
    An uncompiled ``keras.Model`` that outputs per-class softmax probabilities.
  """
  inputs = keras.layers.Input(input_shape)
  flat = keras.layers.Flatten()(inputs)
  probs = keras.layers.Dense(numClasses, activation='softmax')(flat)
  return keras.Model(inputs=inputs, outputs=probs)

# Shared by the progress callback and fit(): the callback sizes its per-epoch bar
# from these values, so they must match what fit() actually uses.
BATCH_SIZE = 128
VAL_SPLIT = 0.2

kerasModel = modelKeras((DIM, DIM, 1))
kerasModel.summary()

cb = ProgressCallback(x_train.shape[0], epochs=EPOCHS, valSplit=VAL_SPLIT, batchSize=BATCH_SIZE)

kerasModel.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# verbose=0 silences Keras's own progress output; the callback provides it instead.
kerasModel.fit(x_train, y_train, epochs=EPOCHS, validation_split=VAL_SPLIT, batch_size=BATCH_SIZE, shuffle=True, callbacks=[cb], verbose=0)

Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                7850      
Total params: 7,850
Trainable params: 7,850
Non-trainable params: 0
_________________________________________________________________


HBox(children=(IntProgress(value=0, description='Epochs Completed... ', max=10, style=ProgressStyle(descriptio…

HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.7837004326979319 , Accuracy: 0.7402292 , Validation Loss: 0.581614243666331 , Validation Accuracy 0.804


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.5345598978201548 , Accuracy: 0.8225625 , Validation Loss: 0.5062860030333202 , Validation Accuracy 0.8268333


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.487636226495107 , Accuracy: 0.83554167 , Validation Loss: 0.4780752960840861 , Validation Accuracy 0.8395


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.4642881722450256 , Accuracy: 0.84316665 , Validation Loss: 0.4707410374482473 , Validation Accuracy 0.83725


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.44708330543835956 , Accuracy: 0.849375 , Validation Loss: 0.4546132166385651 , Validation Accuracy 0.8441667


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.43708890422185265 , Accuracy: 0.8512083 , Validation Loss: 0.44200757972399396 , Validation Accuracy 0.8485


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.427922116279602 , Accuracy: 0.854625 , Validation Loss: 0.44106546211242675 , Validation Accuracy 0.85216665


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.4226781499385834 , Accuracy: 0.8551458 , Validation Loss: 0.4369556939601898 , Validation Accuracy 0.85366666


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.4164011974732081 , Accuracy: 0.85691667 , Validation Loss: 0.43779691990216574 , Validation Accuracy 0.84775


HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

Loss: 0.411271061817805 , Accuracy: 0.8597917 , Validation Loss: 0.4266095713774363 , Validation Accuracy 0.8538333


<tensorflow.python.keras.callbacks.History at 0x7f79953ac6a0>