Skip to content

Commit

Permalink
Merge pull request optuna#4662 from toshihikoyanase/remove-tfkeras-in…
Browse files Browse the repository at this point in the history
…tegration

Remove `tf.keras` integration.
  • Loading branch information
HideakiImamura committed May 12, 2023
2 parents 9cb6ea8 + f8b651d commit 4ef1a10
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 128 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/tests-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ jobs:
--ignore tests/integration_tests/test_skopt.py \
--ignore tests/integration_tests/test_tensorboard.py \
--ignore tests/integration_tests/test_tensorflow.py \
--ignore tests/integration_tests/test_tfkeras.py \
--ignore tests/integration_tests/lightgbm_tuner_tests \
--ignore tests/importance_tests/test_init.py \
--ignore tests/samplers_tests/test_samplers.py
Expand All @@ -111,7 +110,6 @@ jobs:
--ignore tests/integration_tests/test_skopt.py \
--ignore tests/integration_tests/test_tensorboard.py \
--ignore tests/integration_tests/test_tensorflow.py \
--ignore tests/integration_tests/test_tfkeras.py \
--ignore tests/integration_tests/lightgbm_tuner_tests \
--ignore tests/importance_tests/test_init.py \
--ignore tests/samplers_tests/test_samplers.py
Expand Down
1 change: 0 additions & 1 deletion docs/source/reference/integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ TensorFlow

optuna.integration.TensorBoardCallback
optuna.integration.TensorFlowPruningHook
optuna.integration.TFKerasPruningCallback

XGBoost
-------
Expand Down
60 changes: 2 additions & 58 deletions optuna/integration/tfkeras.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,4 @@
from typing import Any
from typing import Dict
from typing import Optional
import warnings
from optuna_integration.tfkeras import TFKerasPruningCallback

import optuna


with optuna._imports.try_import() as _imports:
from tensorflow.keras.callbacks import Callback

if not _imports.is_successful():
Callback = object # NOQA


class TFKerasPruningCallback(Callback):
"""tf.keras callback to prune unpromising trials.
This callback is intend to be compatible for TensorFlow v1 and v2,
but only tested with TensorFlow v2.
See `the example <https://github.com/optuna/optuna-examples/blob/main/
tfkeras/tfkeras_integration.py>`__
if you want to add a pruning callback which observes the validation accuracy.
Args:
trial:
A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
objective function.
monitor:
An evaluation metric for pruning, e.g., ``val_loss`` or ``val_acc``.
"""

def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
super().__init__()

_imports.check()

self._trial = trial
self._monitor = monitor

def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
logs = logs or {}
current_score = logs.get(self._monitor)

if current_score is None:
message = (
"The metric '{}' is not in the evaluation logs for pruning. "
"Please make sure you set the correct metric name.".format(self._monitor)
)
warnings.warn(message)
return

# Report current score and epoch to Optuna's trial.
self._trial.report(float(current_score), step=epoch)

# Prune trial if needed
if self._trial.should_prune():
message = "Trial was pruned at epoch {}.".format(epoch)
raise optuna.TrialPruned(message)
__all__ = ["TFKerasPruningCallback"]
67 changes: 0 additions & 67 deletions tests/integration_tests/test_tfkeras.py

This file was deleted.

0 comments on commit 4ef1a10

Please sign in to comment.