diff --git a/src/lbfgs.jl b/src/lbfgs.jl index 9724a0e6..778d3f93 100644 --- a/src/lbfgs.jl +++ b/src/lbfgs.jl @@ -246,13 +246,14 @@ end Resets the given LBFGS data. """ -function reset!(data :: LBFGSData) +function reset!(data :: LBFGSData{T}) where T fill!(data.s, 0) fill!(data.y, 0) fill!(data.ys, 0) fill!(data.α , 0) fill!(data.a, 0) fill!(data.b, 0) + data.scaling_factor = T(1) data.insert = 1 return data end diff --git a/src/lsr1.jl b/src/lsr1.jl index 3d7bdf06..2ba08a9e 100644 --- a/src/lsr1.jl +++ b/src/lsr1.jl @@ -172,12 +172,13 @@ end Reset the given LSR1 data. """ -function reset!(data :: LSR1Data) +function reset!(data :: LSR1Data{T}) where T fill!(data.s, 0) fill!(data.y, 0) fill!(data.ys, 0) fill!(data.a, 0) fill!(data.as, 0) + data.scaling_factor = T(1) data.insert = 1 return data end diff --git a/test/test_lbfgs.jl b/test/test_lbfgs.jl index a748c713..c69386c8 100644 --- a/test/test_lbfgs.jl +++ b/test/test_lbfgs.jl @@ -56,6 +56,8 @@ function test_lbfgs() @test norm(H * v - v) > rtol reset!(B) reset!(H) + @test B.data.scaling_factor == 1.0 + @test H.data.scaling_factor == 1.0 @test norm(B * v - v) < rtol @test norm(H * v - v) < rtol end diff --git a/test/test_lsr1.jl b/test/test_lsr1.jl index da6592b5..74439889 100644 --- a/test/test_lsr1.jl +++ b/test/test_lsr1.jl @@ -31,6 +31,7 @@ function test_lsr1() v = simple_vector(Float64, n) @test norm(B * v - v) > rtol reset!(B) + @test B.data.scaling_factor == 1.0 @test norm(B * v - v) < rtol end