From f908912c1784894f5334c9e9f636c0841dc032bc Mon Sep 17 00:00:00 2001 From: Contramundum Date: Thu, 28 Mar 2024 11:27:58 +0900 Subject: [PATCH 1/2] Fix average_is_best flag --- optuna/pruners/_wilcoxon.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optuna/pruners/_wilcoxon.py b/optuna/pruners/_wilcoxon.py index 950b79c781..096824bdaa 100644 --- a/optuna/pruners/_wilcoxon.py +++ b/optuna/pruners/_wilcoxon.py @@ -198,10 +198,10 @@ def prune(self, study: "optuna.study.Study", trial: FrozenTrial) -> bool: if study.direction == StudyDirection.MAXIMIZE: alt = "less" - average_is_best = best_trial.value <= sum(step_values) / len(step_values) + average_is_best = sum(best_step_values) / len(best_step_values) <= sum(step_values) / len(step_values) else: alt = "greater" - average_is_best = best_trial.value >= sum(step_values) / len(step_values) + average_is_best = sum(best_step_values) / len(best_step_values) >= sum(step_values) / len(step_values) # We use zsplit to avoid the problem when all values are zero. p = ss.wilcoxon(diff_values, alternative=alt, zero_method="zsplit").pvalue From 3981f6971956ca06f760c91504ab05003ed7061c Mon Sep 17 00:00:00 2001 From: Contramundum Date: Thu, 28 Mar 2024 11:29:05 +0900 Subject: [PATCH 2/2] format --- optuna/pruners/_wilcoxon.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/optuna/pruners/_wilcoxon.py b/optuna/pruners/_wilcoxon.py index 096824bdaa..7b08025611 100644 --- a/optuna/pruners/_wilcoxon.py +++ b/optuna/pruners/_wilcoxon.py @@ -198,10 +198,14 @@ def prune(self, study: "optuna.study.Study", trial: FrozenTrial) -> bool: if study.direction == StudyDirection.MAXIMIZE: alt = "less" - average_is_best = sum(best_step_values) / len(best_step_values) <= sum(step_values) / len(step_values) + average_is_best = sum(best_step_values) / len(best_step_values) <= sum( + step_values + ) / len(step_values) else: alt = "greater" - average_is_best = sum(best_step_values) / len(best_step_values) >= sum(step_values) / len(step_values) + average_is_best = sum(best_step_values) / len(best_step_values) >= sum( + step_values + ) / len(step_values) # We use zsplit to avoid the problem when all values are zero. p = ss.wilcoxon(diff_values, alternative=alt, zero_method="zsplit").pvalue