diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml
index d5a921f13..05e07b123 100644
--- a/lib/OptimizationOptimisers/Project.toml
+++ b/lib/OptimizationOptimisers/Project.toml
@@ -4,33 +4,34 @@ authors = ["Vaibhav Dixit and contributors"]
 version = "0.3.13"
 
 [deps]
-OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
-SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
-Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
+SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 
-[extras]
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
-MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
-Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
-Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+[sources]
+OptimizationBase = {path = "../OptimizationBase"}
 
 [compat]
-julia = "1.10"
-OptimizationBase = "4"
-SciMLBase = "2.122.1"
+Logging = "1.10"
 Optimisers = "0.2, 0.3, 0.4"
+OptimizationBase = "4"
 Reexport = "1.2"
-Logging = "1.10"
+SciMLBase = "2.122.1"
+julia = "1.10"
 
-[sources]
-OptimizationBase = {path = "../OptimizationBase"}
+[extras]
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
 test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"]
diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
index b1713244d..2476e5743 100644
--- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
+++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
@@ -67,6 +67,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
     breakall = false
     progress_id = :OptimizationOptimizersJL
     for epoch in 1:epochs, d in data
+
         if cache.f.fg !== nothing && dataiterate
             x = cache.f.fg(G, θ, d)
             iterations += 1
@@ -106,7 +107,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
         if cache.progress
             message = "Loss: $(round(first(first(x)); digits = 3))"
             @logmsg(LogLevel(-1), "Optimization", _id=progress_id,
-                message=message, progress=iterations / maxiters)
+                message=message, progress=iterations/maxiters)
         end
         if cache.solver_args.save_best
             if first(x)[1] < first(min_err)[1] #found a better solution
@@ -129,7 +130,12 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
                 break
             end
         end
-        state, θ = Optimisers.update(state, θ, G)
+        # Skip update if gradient contains NaN or Inf values
+        if all(isfinite, G)
+            state, θ = Optimisers.update(state, θ, G)
+        elseif cache.progress
+            @warn "Skipping parameter update due to NaN or Inf in gradients at iteration $iterations" maxlog=10
+        end
     end
     cache.progress && @logmsg(LogLevel(-1), "Optimization", _id=progress_id,
         message="Done", progress=1.0)
diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl
index ad754cf74..269d01932 100644
--- a/lib/OptimizationOptimisers/test/runtests.jl
+++ b/lib/OptimizationOptimisers/test/runtests.jl
@@ -134,3 +134,47 @@ end
 
     @test res.objective < 1e-4
 end
+
+@testset "NaN/Inf gradient handling" begin
+    # Test that the optimizer skips updates when gradients contain NaN or Inf
+    # Function whose AD gradient can contain NaN due to the sqrt(max(..., 0.0)) term
+    function weird_nan_function(x, p)
+        val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
+        # sqrt(max(x, 0.0)) is non-smooth at zero, so its AD gradient can become NaN/Inf
+        val += sqrt(max(x[1], 0.0)) * 0.01
+        return val
+    end
+
+    x0 = [-0.5, 0.1] # Start with negative x[1] to exercise the non-smooth sqrt(max(..., 0.0)) term
+    _p = [1.0, 100.0]
+
+    optprob = OptimizationFunction(weird_nan_function, OptimizationBase.AutoZygote())
+    prob = OptimizationProblem(optprob, x0, _p)
+
+    # Should not throw an error and should complete all iterations
+    sol = solve(prob, Optimisers.Adam(0.01), maxiters = 50, progress = false)
+
+    # Verify the solve completed all iterations
+    @test sol.stats.iterations == 50
+
+    # Verify parameters are not NaN (they would be if updates were applied with NaN gradients)
+    @test all(!isnan, sol.u)
+    @test all(isfinite, sol.u)
+
+    # Function with a 1/(|x[1] - 0.1| + 1e-8) term whose gradient can blow up near x[1] = 0.1
+    function weird_inf_function(x, p)
+        val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
+        # 0.01 / (abs(x[1] - 0.1) + 1e-8) has a very large gradient near x[1] = 0.1
+        val += 0.01 / (abs(x[1] - 0.1) + 1e-8)
+        return val
+    end
+
+    optprob_inf = OptimizationFunction(weird_inf_function, OptimizationBase.AutoZygote())
+    prob_inf = OptimizationProblem(optprob_inf, x0, _p)
+
+    sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 50, progress = false)
+
+    @test sol_inf.stats.iterations == 50
+    @test all(!isnan, sol_inf.u)
+    @test all(isfinite, sol_inf.u)
+end