Merged
39 changes: 20 additions & 19 deletions lib/OptimizationOptimisers/Project.toml
@@ -4,33 +4,34 @@ authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
version = "0.3.13"

[deps]
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Member: This should not be here 😅

Member: ugh, yeah

Member: #1083 will fix it

OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[sources]
OptimizationBase = {path = "../OptimizationBase"}

[compat]
julia = "1.10"
OptimizationBase = "4"
SciMLBase = "2.122.1"
Logging = "1.10"
Optimisers = "0.2, 0.3, 0.4"
OptimizationBase = "4"
Reexport = "1.2"
Logging = "1.10"
SciMLBase = "2.122.1"
julia = "1.10"

[sources]
OptimizationBase = {path = "../OptimizationBase"}
[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"]
10 changes: 8 additions & 2 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
@@ -67,6 +67,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
breakall = false
progress_id = :OptimizationOptimizersJL
for epoch in 1:epochs, d in data

if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
iterations += 1
@@ -106,7 +107,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
if cache.progress
message = "Loss: $(round(first(first(x)); digits = 3))"
@logmsg(LogLevel(-1), "Optimization", _id=progress_id,
message=message, progress=iterations / maxiters)
message=message, progress=iterations/maxiters)
end
if cache.solver_args.save_best
if first(x)[1] < first(min_err)[1] #found a better solution
@@ -129,7 +130,12 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule
break
end
end
state, θ = Optimisers.update(state, θ, G)
# Skip update if gradient contains NaN or Inf values
if all(isfinite, G)
state, θ = Optimisers.update(state, θ, G)
elseif cache.progress
@warn "Skipping parameter update due to NaN or Inf in gradients at iteration $iterations" maxlog=10
end
end
cache.progress && @logmsg(LogLevel(-1), "Optimization",
_id=progress_id, message="Done", progress=1.0)
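
The guard above is easy to exercise outside the solver loop. A minimal sketch using Optimisers.jl directly; the parameter vector, rule, and gradient sequence are made up for illustration:

using Optimisers

function guarded_steps(θ, grads)
    state = Optimisers.setup(Optimisers.Adam(0.01), θ)
    for G in grads
        if all(isfinite, G)
            # same update call as in __solve above
            state, θ = Optimisers.update(state, θ, G)
        else
            @warn "Skipping parameter update due to NaN or Inf in gradients"
        end
    end
    return θ
end

θ = guarded_steps([1.0, 2.0], ([0.1, -0.2], [NaN, 0.3], [0.0, Inf]))
all(isfinite, θ)  # true: the two non-finite gradients were never applied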
44 changes: 44 additions & 0 deletions lib/OptimizationOptimisers/test/runtests.jl
@@ -134,3 +134,47 @@ end

@test res.objective < 1e-4
end

@testset "NaN/Inf gradient handling" begin
# Test that optimizer skips updates when gradients contain NaN or Inf
# Function whose sqrt term has an unbounded derivative near x[1] = 0,
# so gradients can blow up to Inf/NaN during the run
function weird_nan_function(x, p)
val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
# d/du sqrt(u) = 1/(2*sqrt(u)) diverges as u -> 0+; max clamps negative
# x[1] to exactly 0, where the one-sided derivative is Inf
val += sqrt(max(x[1], 0.0)) * 0.01
return val
end

x0 = [-0.5, 0.1] # Start with negative x[1] to trigger sqrt of negative
_p = [1.0, 100.0]

optprob = OptimizationFunction(weird_nan_function, OptimizationBase.AutoZygote())
prob = OptimizationProblem(optprob, x0, _p)

# Should not throw error and should complete all iterations
sol = solve(prob, Optimisers.Adam(0.01), maxiters = 50, progress = false)

# Verify solution completed all iterations
@test sol.stats.iterations == 50

# Verify parameters are not NaN (would be NaN if updates were applied with NaN gradients)
@test all(!isnan, sol.u)
@test all(isfinite, sol.u)

# Function with a 1/abs(...) term whose gradient blows up near x[1] = 0.1
function weird_inf_function(x, p)
val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
# 0.01 / (abs(x[1] - 0.1) + 1e-8) has a very large gradient near x[1] = 0.1
val += 0.01 / (abs(x[1] - 0.1) + 1e-8)
return val
end

optprob_inf = OptimizationFunction(weird_inf_function, OptimizationBase.AutoZygote())
prob_inf = OptimizationProblem(optprob_inf, x0, _p)

sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 50, progress = false)

@test sol_inf.stats.iterations == 50
@test all(!isnan, sol_inf.u)
@test all(isfinite, sol_inf.u)
end
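
For contrast, a sketch of the failure mode these assertions rule out: without the guard, a single non-finite gradient permanently poisons the affected parameter, since Adam's moment buffers retain the NaN (made-up values, Optimisers.jl called directly):

using Optimisers

θ = [1.0, 2.0]
state = Optimisers.setup(Optimisers.Adam(0.01), θ)
state, θ = Optimisers.update(state, θ, [NaN, 0.0])  # an unguarded step
any(isnan, θ)  # true: θ[1] is now NaN, and every subsequent step keeps it NaN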