diff --git a/HISTORY.md b/HISTORY.md
index 1c4e19251..131ae0e1b 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -5,6 +5,8 @@ Specifically, the following measure-space optimization algorithms have been adde
 - `KLMinWassFwdBwd`
 
+In addition, `KLMinRepGradDescent`, `KLMinRepGradProxDescent`, `KLMinScoreGradDescent` will now throw an `ErrorException` if the objective value estimated at each step turns out to be degenerate (`Inf` or `NaN`). Previously, the algorithms ran until `max_iter` even if the optimization run had failed.
+
 # Release 0.5
 
 ## Default Configuration Changes
 
diff --git a/src/algorithms/common.jl b/src/algorithms/common.jl
index c9afcd676..30897da32 100644
--- a/src/algorithms/common.jl
+++ b/src/algorithms/common.jl
@@ -80,6 +80,14 @@ function step(
         rng, objective, adtype, grad_buf, obj_st, params, re, objargs...
     )
 
+    if !isfinite(DiffResults.value(grad_buf))
+        throw(
+            ErrorException(
+                "The objective value is $(DiffResults.value(grad_buf)). This indicates that the optimization run diverged.",
+            ),
+        )
+    end
+
     grad = DiffResults.gradient(grad_buf)
     opt_st, params = Optimisers.update!(opt_st, params, grad)
     params = apply(operator, typeof(q), opt_st, params, re)