Skip to content

Commit

Permalink
Merge pull request optuna#4666 from gen740/remove-tensorflow-integration
Browse files Browse the repository at this point in the history
Remove `tensorflow` integration
  • Loading branch information
toshihikoyanase authored May 22, 2023
2 parents 30fabb9 + c6035c7 commit 8c86672
Show file tree
Hide file tree
Showing 6 changed files with 2 additions and 154 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/tests-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ jobs:
--ignore tests/integration_tests/test_shap.py \
--ignore tests/integration_tests/test_skopt.py \
--ignore tests/integration_tests/test_tensorboard.py \
--ignore tests/integration_tests/test_tensorflow.py \
--ignore tests/integration_tests/lightgbm_tuner_tests \
--ignore tests/importance_tests/test_init.py \
--ignore tests/samplers_tests/test_samplers.py
Expand All @@ -109,7 +108,6 @@ jobs:
--ignore tests/integration_tests/test_shap.py \
--ignore tests/integration_tests/test_skopt.py \
--ignore tests/integration_tests/test_tensorboard.py \
--ignore tests/integration_tests/test_tensorflow.py \
--ignore tests/integration_tests/lightgbm_tuner_tests \
--ignore tests/importance_tests/test_init.py \
--ignore tests/samplers_tests/test_samplers.py
Expand Down
1 change: 0 additions & 1 deletion docs/source/reference/integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ TensorFlow
:nosignatures:

optuna.integration.TensorBoardCallback
optuna.integration.TensorFlowPruningHook

XGBoost
-------
Expand Down
82 changes: 2 additions & 80 deletions optuna/integration/tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,4 @@
import optuna
from optuna._imports import try_import
from optuna_integration.tensorflow import TensorFlowPruningHook


with try_import() as _imports:
import tensorflow as tf
from tensorflow.estimator import SessionRunHook
from tensorflow_estimator.python.estimator.early_stopping import read_eval_metrics

if not _imports.is_successful():
SessionRunHook = object # NOQA


class TensorFlowPruningHook(SessionRunHook):
"""TensorFlow SessionRunHook to prune unpromising trials.
See `the example <https://github.com/optuna/optuna-examples/tree/main/
tensorflow/tensorflow_estimator_integration.py>`_
if you want to add a pruning hook to TensorFlow's estimator.
Args:
trial:
A :class:`~optuna.trial.Trial` corresponding to the current evaluation of
the objective function.
estimator:
An estimator which you will use.
metric:
An evaluation metric for pruning, e.g., ``accuracy`` and ``loss``.
run_every_steps:
An interval to watch the summary file.
"""

def __init__(
self,
trial: optuna.trial.Trial,
estimator: "tf.estimator.Estimator",
metric: str,
run_every_steps: int,
) -> None:
_imports.check()

self._trial = trial
self._estimator = estimator
self._current_summary_step = -1
self._metric = metric
self._global_step_tensor = None
self._timer = tf.estimator.SecondOrStepTimer(every_secs=None, every_steps=run_every_steps)

def begin(self) -> None:
self._global_step_tensor = tf.compat.v1.train.get_global_step()

def before_run(
self, run_context: "tf.estimator.SessionRunContext"
) -> "tf.estimator.SessionRunArgs":
del run_context
return tf.estimator.SessionRunArgs(self._global_step_tensor)

def after_run(
self,
run_context: "tf.estimator.SessionRunContext",
run_values: "tf.estimator.SessionRunValues",
) -> None:
global_step = run_values.results
# Get eval metrics every n steps.
if self._timer.should_trigger_for_step(global_step):
self._timer.update_last_triggered_step(global_step)
eval_metrics = read_eval_metrics(self._estimator.eval_dir())
else:
eval_metrics = None
if eval_metrics:
summary_step = next(reversed(eval_metrics))
latest_eval_metrics = eval_metrics[summary_step]
# If there exists a new evaluation summary.
if summary_step > self._current_summary_step:
current_score = latest_eval_metrics[self._metric]
if current_score is None:
current_score = float("nan")
self._trial.report(float(current_score), step=summary_step)
self._current_summary_step = summary_step
if self._trial.should_prune():
message = "Trial was pruned at iteration {}.".format(self._current_summary_step)
raise optuna.TrialPruned(message)
__all__ = ["TensorFlowPruningHook"]
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ integration = [
"scikit-optimize; python_version<'3.11'",
"shap; python_version<'3.11'",
"tensorflow; python_version<'3.11'",
"tensorflow-datasets; python_version<'3.11'",
"torch; python_version<'3.11'",
"torchaudio; python_version<'3.11'",
"torchvision; python_version<'3.11'",
Expand Down
4 changes: 0 additions & 4 deletions tests/integration_tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ def test_import() -> None:
from optuna.integration import LightGBMPruningCallback # NOQA
from optuna.integration import mxnet # NOQA
from optuna.integration import MXNetPruningCallback # NOQA
from optuna.integration import tensorflow # NOQA
from optuna.integration import TensorFlowPruningHook # NOQA
from optuna.integration import xgboost # NOQA
from optuna.integration import XGBoostPruningCallback # NOQA

Expand All @@ -26,11 +24,9 @@ def test_module_attributes() -> None:
assert hasattr(optuna.integration, "dask")
assert hasattr(optuna.integration, "lightgbm")
assert hasattr(optuna.integration, "mxnet")
assert hasattr(optuna.integration, "tensorflow")
assert hasattr(optuna.integration, "xgboost")
assert hasattr(optuna.integration, "LightGBMPruningCallback")
assert hasattr(optuna.integration, "MXNetPruningCallback")
assert hasattr(optuna.integration, "TensorFlowPruningHook")
assert hasattr(optuna.integration, "XGBoostPruningCallback")

with pytest.raises(AttributeError):
Expand Down
66 changes: 0 additions & 66 deletions tests/integration_tests/test_tensorflow.py

This file was deleted.

0 comments on commit 8c86672

Please sign in to comment.