In [1]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

from tensorflow import keras

In [2]:
train_ds = keras.utils.image_dataset_from_directory(
    'data/cats_dogs/',
    batch_size=16,
    image_size=(256, 256),
    shuffle=True,
    seed=42,
    validation_split=0.3,
    subset='training')

val_ds = keras.utils.image_dataset_from_directory(
    'data/cats_dogs/',
    batch_size=16,
    image_size=(256, 256),
    shuffle=True,
    seed=42,
    validation_split=0.3,
    subset='validation')

Found 25000 files belonging to 2 classes.
Using 17500 files for training.
Found 25000 files belonging to 2 classes.
Using 7500 files for validation.


In [3]:
model = keras.Sequential()
model.add(keras.layers.InputLayer(input_shape=(256, 256, 3)))
model.add(keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(1, activation='sigmoid'))

model.compile(
    loss=keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.SGD(),
    metrics=[
        keras.metrics.BinaryAccuracy(),
        keras.metrics.Precision(),
        keras.metrics.Recall()
    ]
)

Andiamo ad inserire tre diversi callback:

1. un `ModelCheckpoint`, che salva i risultati intermedi nel file da noi specificati;
2. un `EarlyStopping`, che interrompe l'addestramento qualora i risultati ottenuti non migliorino;
3. un `TensorBoard`, che fa in modo che i risultati siano mostrati mediante TensorBoard.

In [4]:
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath='checkpoints/checkpoints',
        save_weights_only=True,
        monitor='val_binary_accuracy',
        save_best_only=True),
    keras.callbacks.EarlyStopping(
        monitor='val_binary_accuracy',
        min_delta=0.1,
        patience=3,
        restore_best_weights=True),
    keras.callbacks.TensorBoard(log_dir='logs')
]

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=callbacks
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10


<keras.callbacks.History at 0x2419d14ada0>