diff --git a/src/gradients.jl b/src/gradients.jl index 273c555..dbadc04 100644 --- a/src/gradients.jl +++ b/src/gradients.jl @@ -33,8 +33,10 @@ function GradientCache( if typeof(c2)!=Void && x<:StridedVector warn("c2 cache isn't necessary when x<:StridedVector.") end - if typeof(c2)==Void || eltype(c2)!=real(eltype(x)) && !(x<:StridedVector) + if (typeof(c2)==Void || eltype(c2)!=real(eltype(x))) && !(typeof(x)<:StridedVector) _c2 = zeros(real(eltype(x)), size(x)) + elseif typeof(x)<:StridedArray + _c2 = nothing else _c2 = c2 end @@ -176,7 +178,7 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract end df end -#= + function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedVector{<:Number}, cache::GradientCache{T1,T2,T3,fdtype,returntype,inplace}) where {T1,T2,T3,fdtype,returntype,inplace} @@ -189,24 +191,24 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV if fdtype == Val{:forward} @inbounds for i ∈ eachindex(x) epsilon = compute_epsilon(fdtype, x[i], epsilon_factor) - c2_old = c2[i] - c2[i] += epsilon + c1_old = c1[i] + c1[i] += epsilon if typeof(fx) != Void - df[i] = (f(c2) - fx) / epsilon + df[i] = (f(c1) - fx) / epsilon else - df[i] = (f(c2) - f(x)) / epsilon + df[i] = (f(c1) - f(x)) / epsilon end - c2[i] = c2_old + c1[i] = c1_old end elseif fdtype == Val{:central} @inbounds for i ∈ eachindex(x) epsilon = compute_epsilon(fdtype, x[i], epsilon_factor) - c2_old = c2[i] - c2[i] += epsilon + c1_old = c1[i] + c1[i] += epsilon x_old = x[i] x[i] -= epsilon - df[i] = (f(c2) - f(x)) / (2*epsilon) - c2[i] = c2_old + df[i] = (f(c1) - f(x)) / (2*epsilon) + c1[i] = c1_old x[i] = x_old end elseif fdtype == Val{:complex} && returntype <: Real @@ -224,7 +226,7 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV end df end -=# + # vector of derivatives of a scalar->vector map # this is effectively a vector of partial derivatives, but we still call it a gradient function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,