From 12aa4a42eef80afd8c07ea0e4737e34ccad1c10c Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 11 Nov 2025 15:30:04 -0500
Subject: [PATCH 1/3] add throw in common if objective value is not finite

---
 src/algorithms/common.jl | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/algorithms/common.jl b/src/algorithms/common.jl
index c9afcd676..5086d2f0b 100644
--- a/src/algorithms/common.jl
+++ b/src/algorithms/common.jl
@@ -85,6 +85,14 @@ function step(
     params = apply(operator, typeof(q), opt_st, params, re)
     avg_st = apply(averager, avg_st, params)
 
+    if !isfinite(DiffResults.value(grad))
+        throw(
+            ErrorException(
+                "The objective value is $(DiffResults.value(grad)). This indicates that the optimization run diverged.",
+            ),
+        )
+    end
+
     state = (
         prob=prob,
         q=re(params),

From 8689daaa733bcd34c90c87302cf55b1a3a14e62d Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 11 Nov 2025 15:31:30 -0500
Subject: [PATCH 2/3] fix move objective finite check forward

---
 src/algorithms/common.jl | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/algorithms/common.jl b/src/algorithms/common.jl
index 5086d2f0b..30897da32 100644
--- a/src/algorithms/common.jl
+++ b/src/algorithms/common.jl
@@ -80,19 +80,19 @@ function step(
         rng, objective, adtype, grad_buf, obj_st, params, re, objargs...
     )
 
-    grad = DiffResults.gradient(grad_buf)
-    opt_st, params = Optimisers.update!(opt_st, params, grad)
-    params = apply(operator, typeof(q), opt_st, params, re)
-    avg_st = apply(averager, avg_st, params)
-
-    if !isfinite(DiffResults.value(grad))
+    if !isfinite(DiffResults.value(grad_buf))
         throw(
             ErrorException(
-                "The objective value is $(DiffResults.value(grad)). This indicates that the optimization run diverged.",
+                "The objective value is $(DiffResults.value(grad_buf)). This indicates that the optimization run diverged.",
             ),
         )
     end
 
+    grad = DiffResults.gradient(grad_buf)
+    opt_st, params = Optimisers.update!(opt_st, params, grad)
+    params = apply(operator, typeof(q), opt_st, params, re)
+    avg_st = apply(averager, avg_st, params)
+
     state = (
         prob=prob,
         q=re(params),

From db5b0f287eae30ce6d919bb6a61dbae7eea048a2 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 11 Nov 2025 15:36:44 -0500
Subject: [PATCH 3/3] update history

---
 HISTORY.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/HISTORY.md b/HISTORY.md
index 1c4e19251..131ae0e1b 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -5,6 +5,8 @@ Specifically, the following measure-space optimization algorithms have been added
 
 - `KLMinWassFwdBwd`
 
+In addition, `KLMinRepGradDescent`, `KLMinRepGradProxDescent`, and `KLMinScoreGradDescent` will now throw an `ErrorException` if the objective value estimated at each step turns out to be degenerate (`Inf` or `NaN`). Previously, the algorithms ran until `max_iter` even if the optimization run had failed.
+
 # Release 0.5
 
 ## Default Configuration Changes