Skip to content

Commit

Permalink
Minor perf improvements for gradients #2.
Browse files Browse the repository at this point in the history
  • Loading branch information
dextorious authored and ChrisRackauckas committed Feb 15, 2018
1 parent 0f8e5c9 commit c416bbf
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions src/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ function GradientCache(
if typeof(c2)!=Void && x<:StridedVector
warn("c2 cache isn't necessary when x<:StridedVector.")
end
if typeof(c2)==Void || eltype(c2)!=real(eltype(x)) && !(x<:StridedVector)
if (typeof(c2)==Void || eltype(c2)!=real(eltype(x))) && !(typeof(x)<:StridedVector)
_c2 = zeros(real(eltype(x)), size(x))
elseif typeof(x)<:StridedArray
_c2 = nothing
else
_c2 = c2
end
Expand Down Expand Up @@ -176,7 +178,7 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
end
df
end
#=

function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedVector{<:Number},
cache::GradientCache{T1,T2,T3,fdtype,returntype,inplace}) where {T1,T2,T3,fdtype,returntype,inplace}

Expand All @@ -189,24 +191,24 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
if fdtype == Val{:forward}
@inbounds for i eachindex(x)
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
c2_old = c2[i]
c2[i] += epsilon
c1_old = c1[i]
c1[i] += epsilon
if typeof(fx) != Void
df[i] = (f(c2) - fx) / epsilon
df[i] = (f(c1) - fx) / epsilon
else
df[i] = (f(c2) - f(x)) / epsilon
df[i] = (f(c1) - f(x)) / epsilon
end
c2[i] = c2_old
c1[i] = c1_old
end
elseif fdtype == Val{:central}
@inbounds for i eachindex(x)
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
c2_old = c2[i]
c2[i] += epsilon
c1_old = c1[i]
c1[i] += epsilon
x_old = x[i]
x[i] -= epsilon
df[i] = (f(c2) - f(x)) / (2*epsilon)
c2[i] = c2_old
df[i] = (f(c1) - f(x)) / (2*epsilon)
c1[i] = c1_old
x[i] = x_old
end
elseif fdtype == Val{:complex} && returntype <: Real
Expand All @@ -224,7 +226,7 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
end
df
end
=#

# vector of derivatives of a scalar->vector map
# this is effectively a vector of partial derivatives, but we still call it a gradient
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
Expand Down

0 comments on commit c416bbf

Please sign in to comment.