Skip to content

Commit

Permalink
bump scimlbase and fix Optimisers err
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jan 5, 2024
1 parent 08b130e commit c0e59f6
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Printf = "1.9"
ProgressLogging = "0.1"
Reexport = "1.2"
ReverseDiff = "1.14"
SciMLBase = "2.16.2"
SciMLBase = "2.16.3"
SparseArrays = "1.9, 1.10"
SparseDiffTools = "2.14"
SymbolicIndexingInterface = "0.3"
Expand Down
4 changes: 2 additions & 2 deletions lib/OptimizationFlux/src/OptimizationFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
end

t1 = time()
stats = Optimization.OptimizationStats(; iterations = i,
time = t1 - t0, fevals = i, gevals = i)
stats = Optimization.OptimizationStats(; iterations = maxiters,

Check warning on line 105 in lib/OptimizationFlux/src/OptimizationFlux.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationFlux/src/OptimizationFlux.jl#L105

Added line #L105 was not covered by tests
time = t1 - t0, fevals = maxiters, gevals = maxiters)
SciMLBase.build_solution(cache, opt, θ, x[1], stats = stats)

Check warning on line 107 in lib/OptimizationFlux/src/OptimizationFlux.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationFlux/src/OptimizationFlux.jl#L107

Added line #L107 was not covered by tests
# here should be build_solution to create the output message
end
Expand Down
8 changes: 5 additions & 3 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
P,
C,
}
local i
if cache.data != Optimization.DEFAULT_DATA
maxiters = length(cache.data)
data = cache.data
else
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
if maxiters === nothing
throw(ArgumentError("The number of iterations must be specified as the maxiters kwarg."))

Check warning on line 50 in lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl#L50

Added line #L50 was not covered by tests
end
data = Optimization.take(cache.data, maxiters)
end
opt = cache.opt
Expand Down Expand Up @@ -96,8 +98,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
end

t1 = time()
stats = Optimization.OptimizationStats(; iterations = i,
time = t1 - t0, fevals = i, gevals = i)
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = t1 - t0, fevals = maxiters, gevals = maxiters)
SciMLBase.build_solution(cache, cache.opt, θ, first(x)[1], stats = stats)
# here should be build_solution to create the output message
end
Expand Down
2 changes: 2 additions & 0 deletions lib/OptimizationOptimisers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,6 @@ using Zygote
end
sol = solve(prob, Optimisers.Adam(0.1), maxiters = 1000, progress = false, callback = callback)
end

@test_throws ArgumentError sol = solve(prob, Optimisers.Adam())
end

0 comments on commit c0e59f6

Please sign in to comment.