In [1]:
# Smoke-test cell: print() joins its arguments with a single space, emitting "Hello World".
print("Hello", "World")

Hello World


In [2]:
from pprint import pprint
from typing import Optional

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorboard import notebook
%load_ext tensorboard

In [3]:
class Model:
    """Namespace for building and training a small dense classifier on 28x28 inputs.

    All functionality lives in static methods; the class holds no state. Note that
    ``build`` returns a ``keras.Model`` (not an instance of this class), and that is
    what ``train`` expects as its ``model`` argument.
    """

    def __init__(self):
        # Nothing to initialise: the class is only a grouping of static helpers.
        pass

    @staticmethod
    def build() -> keras.Model:
        """Build and compile a fully connected classifier.

        Architecture: rescale inputs by 1/255 -> flatten 28x28 to 784 features ->
        two ReLU ``Dense(128)`` layers -> ``Dense(10)`` softmax output.

        The model is compiled with Adam and ``SparseCategoricalCrossentropy``, so
        labels must be integer class ids (not one-hot vectors). The loss keeps its
        default ``from_logits=False``, which matches the softmax output layer.

        Returns:
            A compiled ``keras.Model``.
        """
        # ``keras.layers.Rescaling`` is the stable name in newer TF/Keras releases;
        # ``experimental.preprocessing`` is the deprecated location (absent in newer
        # Keras). Prefer the stable class and fall back only for old installs.
        rescaling_layer = getattr(keras.layers, "Rescaling", None)
        if rescaling_layer is None:
            rescaling_layer = keras.layers.experimental.preprocessing.Rescaling

        inputs = keras.Input(shape=(28, 28))
        x = rescaling_layer(1.0 / 255)(inputs)
        x = keras.layers.Flatten()(x)
        x = keras.layers.Dense(128, activation="relu")(x)
        x = keras.layers.Dense(128, activation="relu")(x)
        outputs = keras.layers.Dense(10, activation="softmax")(x)

        model = keras.Model(inputs=inputs, outputs=outputs)
        model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.SparseCategoricalCrossentropy())

        return model

    @staticmethod
    def train(model: keras.Model,
              dataset: tf.data.Dataset,
              validation_dataset: Optional[tf.data.Dataset] = None,
              epochs: int = 1,
              log_dir: str = "./logs") -> keras.callbacks.History:
        """Fit ``model`` on ``dataset`` while logging to TensorBoard.

        Args:
            model: A compiled ``keras.Model``, e.g. the result of ``Model.build()``.
            dataset: Training data yielding batched ``(inputs, labels)`` pairs.
            validation_dataset: Optional data evaluated at the end of each epoch;
                ``None`` skips validation.
            epochs: Number of passes over ``dataset``.
            log_dir: Directory the TensorBoard callback writes event files to.
                Defaults to ``./logs``; pass a distinct per-run directory to keep
                runs from being mixed together in TensorBoard.

        Returns:
            The ``keras.callbacks.History`` returned by ``model.fit``.
        """
        callbacks = [keras.callbacks.TensorBoard(log_dir=log_dir)]
        return model.fit(
            dataset,
            epochs=epochs,
            callbacks=callbacks,
            validation_data=validation_dataset,
        )


# Example: build the classifier, train it on MNIST, and show the per-epoch metrics.
SEED = 42  # seeds TF's global RNG so weight init and shuffle order are reproducible
tf.random.set_seed(SEED)

model = Model.build()
model.summary()

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
batch_size = 64
epochs = 2
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    # Shuffle over the whole training set; tf.data reshuffles each epoch by default,
    # so SGD does not see the same fixed batch order every pass.
    .shuffle(len(x_train), seed=SEED)
    .batch(batch_size)
)
# NOTE: the MNIST test split doubles as the validation set here, so there is no
# untouched held-out test set left for a final unbiased evaluation.
val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
history = Model.train(model, dataset, validation_dataset=val_dataset, epochs=epochs)

pprint(history.history)

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
rescaling (Rescaling)        (None, 28, 28)            0         
_________________________________________________________________
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 128)               16512     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1290      
Total params: 118,282
Trainable params: 118,282
Non-trainable params: 0
_______________________________________________________

In [4]:
# Launch (or reuse, if one is already running) an inline TensorBoard over the event
# files in ./logs via the same entry point the %tensorboard magic wraps, then list
# the TensorBoard instances this notebook knows about.
notebook.start("--logdir ./logs")
notebook.list()

Reusing TensorBoard on port 6006 (pid 12263), started 0:56:28 ago. (Use '!kill 12263' to kill it.)