In [1]:
!pip install optuna -q

Optuna example that demonstrates a pruner for Keras.
In this example, we optimize the validation accuracy of hand-written digit recognition using
Keras and MNIST, where the architecture of the neural network and the learning rate of optimizer
is optimized. Throughout the training of neural networks, a pruner observes intermediate
results and stops unpromising trials.


In [2]:
import optuna
from optuna.integration import KerasPruningCallback
from optuna.trial import TrialState

from tensorflow import keras
from keras.datasets import mnist
from keras.layers import Dense
from keras.layers import Dropout
from keras.models import Sequential

In [3]:
N_TRAIN_EXAMPLES = 3000
N_VALID_EXAMPLES = 1000
BATCHSIZE = 64
CLASSES = 10
EPOCHS = 20


In [4]:
def create_model(trial):
    # We optimize the number of layers, hidden units and dropout in each layer and
    # the learning rate of RMSProp optimizer.

    # We define our MLP.
    n_layers = trial.suggest_int("n_layers", 1, 5)
    model = Sequential()
    for i in range(n_layers):
        num_hidden = trial.suggest_int("n_units_l{}".format(i), 4, 128, log=True)
        model.add(Dense(num_hidden, activation="relu"))
        dropout = trial.suggest_float("dropout_l{}".format(i), 0.2, 0.5)
        model.add(Dropout(rate=dropout))
    model.add(Dense(CLASSES, activation="softmax"))

    # We compile our model with a sampled learning rate.
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    model.compile(
        loss="categorical_crossentropy",
        optimizer=keras.optimizers.RMSprop(learning_rate=lr),
        metrics=["accuracy"],
    )

    return model

In [5]:
def objective(trial):
    # Clear clutter from previous session graphs.
    keras.backend.clear_session()

    # The data is split between train and validation sets.
    (x_train, y_train), (x_valid, y_valid) = mnist.load_data()
    x_train = x_train.reshape(60000, 784)[:N_TRAIN_EXAMPLES].astype("float32") / 255
    x_valid = x_valid.reshape(10000, 784)[:N_VALID_EXAMPLES].astype("float32") / 255

    # Convert class vectors to binary class matrices.
    y_train = keras.utils.to_categorical(y_train[:N_TRAIN_EXAMPLES], CLASSES)
    y_valid = keras.utils.to_categorical(y_valid[:N_VALID_EXAMPLES], CLASSES)

    # Generate our trial model.
    model = create_model(trial)

    # Fit the model on the training data.
    # The KerasPruningCallback checks for pruning condition every epoch.
    model.fit(
        x_train,
        y_train,
        batch_size=BATCHSIZE,
        callbacks=[KerasPruningCallback(trial, "val_accuracy")],
        epochs=EPOCHS,
        validation_data=(x_valid, y_valid),
        verbose=1,
    )

    # Evaluate the model accuracy on the validation set.
    score = model.evaluate(x_valid, y_valid, verbose=1)
    return score[1]

In [6]:
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=15)

pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])




[32m[I 2021-08-25 06:31:33,863][0m A new study created in memory with name: no-name-9d1b87b7-1f73-4b9a-b11f-940dc0f6be80[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:31:37,596][0m Trial 0 finished with value: 0.625 and parameters: {'n_layers': 1, 'n_units_l0': 12, 'dropout_l0': 0.36240951209791317, 'learning_rate': 0.04232331604266482}. Best is trial 0 with value: 0.625.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:31:41,622][0m Trial 1 finished with value: 0.503000020980835 and parameters: {'n_layers': 4, 'n_units_l0': 23, 'dropout_l0': 0.24011347155728746, 'n_units_l1': 24, 'dropout_l1': 0.4435396301097096, 'n_units_l2': 18, 'dropout_l2': 0.3084350845830808, 'n_units_l3': 10, 'dropout_l3': 0.2389261391581259, 'learning_rate': 0.0001658588129720693}. Best is trial 0 with value: 0.625.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:31:45,285][0m Trial 2 finished with value: 0.5120000243186951 and parameters: {'n_layers': 3, 'n_units_l0': 8, 'dropout_l0': 0.4229473858654087, 'n_units_l1': 5, 'dropout_l1': 0.22991187912055755, 'n_units_l2': 4, 'dropout_l2': 0.47316646820151875, 'learning_rate': 0.005283566645130581}. Best is trial 0 with value: 0.625.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:31:48,699][0m Trial 3 finished with value: 0.593999981880188 and parameters: {'n_layers': 2, 'n_units_l0': 8, 'dropout_l0': 0.431425445244245, 'n_units_l1': 25, 'dropout_l1': 0.27896167141392825, 'learning_rate': 0.00014881789737974626}. Best is trial 0 with value: 0.625.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:31:52,372][0m Trial 4 finished with value: 0.3440000116825104 and parameters: {'n_layers': 3, 'n_units_l0': 4, 'dropout_l0': 0.4818642823408836, 'n_units_l1': 15, 'dropout_l1': 0.32705251378767586, 'n_units_l2': 5, 'dropout_l2': 0.21983563040785067, 'learning_rate': 0.022556406045604636}. Best is trial 0 with value: 0.625.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:31:56,419][0m Trial 5 finished with value: 0.8769999742507935 and parameters: {'n_layers': 1, 'n_units_l0': 96, 'dropout_l0': 0.44463510235119474, 'learning_rate': 0.00021830036801948215}. Best is trial 5 with value: 0.8769999742507935.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:31:59,636][0m Trial 6 finished with value: 0.8360000252723694 and parameters: {'n_layers': 1, 'n_units_l0': 8, 'dropout_l0': 0.2115072068105493, 'learning_rate': 0.002267102605844644}. Best is trial 5 with value: 0.8769999742507935.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20


[32m[I 2021-08-25 06:32:01,196][0m Trial 7 pruned. Trial was pruned at epoch 0.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20


[32m[I 2021-08-25 06:32:02,778][0m Trial 8 pruned. Trial was pruned at epoch 0.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20


[32m[I 2021-08-25 06:32:04,209][0m Trial 9 pruned. Trial was pruned at epoch 0.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20


[32m[I 2021-08-25 06:32:05,450][0m Trial 10 pruned. Trial was pruned at epoch 0.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:32:08,899][0m Trial 11 finished with value: 0.9129999876022339 and parameters: {'n_layers': 1, 'n_units_l0': 48, 'dropout_l0': 0.20675427068205948, 'learning_rate': 0.0015381421266250687}. Best is trial 11 with value: 0.9129999876022339.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:32:14,907][0m Trial 12 finished with value: 0.9020000100135803 and parameters: {'n_layers': 1, 'n_units_l0': 70, 'dropout_l0': 0.31829945314476926, 'learning_rate': 0.001042841395458935}. Best is trial 11 with value: 0.9129999876022339.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[32m[I 2021-08-25 06:32:18,677][0m Trial 13 finished with value: 0.9010000228881836 and parameters: {'n_layers': 2, 'n_units_l0': 50, 'dropout_l0': 0.305077642159821, 'n_units_l1': 52, 'dropout_l1': 0.48758701961707995, 'learning_rate': 0.0013537280916659262}. Best is trial 11 with value: 0.9129999876022339.[0m

KerasPruningCallback has been deprecated in v2.1.0. This feature will be removed in v4.0.0. See https://github.com/optuna/optuna/releases/tag/v2.1.0. Recent Keras release (2.4.0) simply redirects all APIs in the standalone keras package to point to tf.keras. There is now only one Keras: tf.keras. There may be some breaking changes for some workflows by upgrading to keras 2.4.0. Test before upgrading. REF:https://github.com/keras-team/keras/releases/tag/2.4.0



Epoch 1/20


[32m[I 2021-08-25 06:32:19,763][0m Trial 14 pruned. Trial was pruned at epoch 0.[0m


In [7]:
print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
  print("    {}: {}".format(key, value))

Study statistics: 
  Number of finished trials:  15
  Number of pruned trials:  5
  Number of complete trials:  10
Best trial:
  Value:  0.9129999876022339
  Params: 
    n_layers: 1
    n_units_l0: 48
    dropout_l0: 0.20675427068205948
    learning_rate: 0.0015381421266250687
