Skip to content

Commit

Permalink
Merge pull request #71 from huanglangwen/tridiagonal-diff
Browse files Browse the repository at this point in the history
add colored diff support for (tri/bi-)diagonal matrices
  • Loading branch information
ChrisRackauckas committed Jul 30, 2019
2 parents 9a87cc5 + 95809a0 commit cd4ee8f
Show file tree
Hide file tree
Showing 8 changed files with 371 additions and 183 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ version = "0.14.0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

[compat]
julia = "1"
ArrayInterface = "1.1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
10 changes: 2 additions & 8 deletions src/DiffEqDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,16 @@ __precompile__()

module DiffEqDiffTools

using LinearAlgebra, SparseArrays, StaticArrays
using LinearAlgebra, SparseArrays, StaticArrays, ArrayInterface

import Base: resize!

include("diffeqfastbc.jl")
include("function_wrappers.jl")
include("finitediff.jl")
include("derivatives.jl")
include("gradients.jl")
include("jacobians.jl")
include("hessians.jl")

# Piracy
function Base.setindex(x::Array,v,i::Int)
_x = copy(x)
_x[i] = v
_x
end

end # module
10 changes: 5 additions & 5 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,19 @@ function finite_difference_derivative!(

fx, epsilon = cache.fx, cache.epsilon
if typeof(epsilon) != Nothing
@. epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)
@.. epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)
end
if fdtype == Val{:forward}
if typeof(fx) == Nothing
@. df = (f(x+epsilon) - f(x)) / epsilon
@.. df = (f(x+epsilon) - f(x)) / epsilon
else
@. df = (f(x+epsilon) - fx) / epsilon
@.. df = (f(x+epsilon) - fx) / epsilon
end
elseif fdtype == Val{:central}
@. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
@.. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
elseif fdtype == Val{:complex} && returntype<:Real
epsilon_complex = eps(eltype(x))
@. df = imag(f(x+im*epsilon_complex)) / epsilon_complex
@.. df = imag(f(x+im*epsilon_complex)) / epsilon_complex
else
fdtype_error(returntype)
end
Expand Down
76 changes: 76 additions & 0 deletions src/diffeqfastbc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import Base.Broadcast: _broadcast_getindex, preprocess, preprocess_args, Broadcasted, broadcast_unalias, combine_axes, broadcast_axes, broadcast_shape, check_broadcast_axes, check_broadcast_shape, throwdm, broadcastable, AbstractArrayStyle, DefaultArrayStyle
import Base: copyto!, tail, axes, length, ndims
struct DiffBC{T}
x::T
end
@inline axes(b::DiffBC) = axes(b.x)
@inline length(b::DiffBC) = length(b.x)
@inline broadcastable(b::DiffBC) = b
@inline Base.ndims(b::Type{DiffBC{T}}) where T = ndims(T)
Base.@propagate_inbounds _broadcast_getindex(b::DiffBC, i) = _broadcast_getindex(b.x, i)
Base.@propagate_inbounds _broadcast_getindex(b::DiffBC{<:AbstractArray{<:Any,0}}, i) = b.x[]
Base.@propagate_inbounds _broadcast_getindex(b::DiffBC{<:AbstractVector}, i) = b.x[i[1]]
Base.@propagate_inbounds _broadcast_getindex(b::DiffBC{<:AbstractArray}, i) = b.x[i]
diffbc(x::Array) = DiffBC(x)
diffbc(x) = x

# Ensure inlining
@inline combine_axes(A, B) = broadcast_shape(broadcast_axes(A), broadcast_axes(B)) # Julia 1.0 compatible
@inline check_broadcast_axes(shp, A::Union{Number, Array, Broadcasted}) = check_broadcast_shape(shp, axes(A))

@inline preprocess(f, dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(f, dest, bc.args), bc.axes)
preprocess(f, dest, x) = f(broadcast_unalias(dest, x))

@inline preprocess_args(f, dest, args::Tuple) = (preprocess(f, dest, args[1]), preprocess_args(f, dest, tail(args))...)
@inline preprocess_args(f, dest, args::Tuple{Any}) = (preprocess(f, dest, args[1]),)
preprocess_args(f, dest, args::Tuple{}) = ()

# Performance optimization for the common identity scalar case: dest .= val
@inline copyto!(dest::DiffBC, bc::Broadcasted{<:AbstractArrayStyle{0}}) = copyto!(dest.x, bc)
@inline function copyto!(dest::DiffBC, bc::Broadcasted)
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
dest′ = dest.x
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && bc.args isa Tuple{AbstractArray} # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest′, A)
end
end
bcs′ = preprocess(diffbc, dest, bc)
@simd ivdep for I in eachindex(bcs′)
@inbounds dest′[I] = bcs′[I]
end
return dest′ # return the original array without the wrapper
end

# Forcing `broadcasted` to inline is not necessary, since `Vern9` plays well
# with the Base implementation, and `Feagin`s do not use broadcasting.
#
#import Base.Broadcast: broadcasted, combine_styles
#map_nostop(f, t::Tuple{}) = ()
#map_nostop(f, t::Tuple{Any,}) = (f(t[1]),)
#map_nostop(f, t::Tuple{Any, Any}) = (f(t[1]), f(t[2]))
#map_nostop(f, t::Tuple{Any, Any, Any}) = (f(t[1]), f(t[2]), f(t[3]))
#map_nostop(f, t::Tuple) = (Base.@_inline_meta; (f(t[1]), map_nostop(f,tail(t))...))
#@inline function broadcasted(f::Union{typeof(*), typeof(+), typeof(muladd)}, arg1, arg2, args...)
# arg1′ = broadcastable(arg1)
# arg2′ = broadcastable(arg2)
# args′ = map_nostop(broadcastable, args)
# broadcasted(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...)
#end

macro ..(x)
expr = Base.Broadcast.__dot__(x)
if expr.head in (:(.=), :(.+=), :(.-=), :(.*=), :(./=), :(.\=), :(.^=)) # we exclude `÷=` `%=` `&=` `|=` `⊻=` `>>>=` `>>=` `<<=` because they are for integers
name = gensym()
dest = :(DiffEqDiffTools.diffbc($(expr.args[1])))
expr.args[1] = name
return esc(quote
$name = $dest
$expr
end)
else
return esc(expr)
end
end
103 changes: 59 additions & 44 deletions src/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@ function GradientCache(
if fdtype!=Val{:complex} # complex-mode FD only needs one cache, for x+eps*im
if typeof(x)<:StridedVector
if eltype(df)<:Complex && !(eltype(x)<:Complex)
_c1 = fill(zero(Complex{eltype(x)}), size(x))
_c1 = zero(Complex{eltype(x)}) .* x
_c2 = nothing
else
_c1 = nothing
_c2 = nothing
end
else
_c1 = similar(x)
_c2 = fill(zero(real(eltype(x))), size(x))
_c2 = zero(real(eltype(x))) .* x
end
else
if !(returntype<:Real)
fdtype_error(returntype)
else
_c1 = x .+ 0*im
_c1 = x .+ zero(eltype(x)) .* im
_c2 = nothing
end
end
Expand All @@ -39,7 +39,7 @@ function GradientCache(
_c1 = similar(df)
_c2 = similar(df)
else
_c1 = fill(zero(Complex{eltype(x)}), size(df))
_c1 = zero(Complex{eltype(x)}) .* df
_c2 = nothing
end
end
Expand Down Expand Up @@ -70,7 +70,7 @@ function GradientCache(
if fdtype!=Val{:complex} # complex-mode FD only needs one cache, for x+eps*im
if typeof(x)<:StridedVector
if eltype(df)<:Complex && !(eltype(x)<:Complex)
_c1 = fill(zero(Complex{eltype(x)}), size(x))
_c1 = zero(Complex{eltype(x)}) .* x
_c2 = nothing
else
_c1 = nothing
Expand All @@ -87,7 +87,7 @@ function GradientCache(
if !(returntype<:Real)
fdtype_error(returntype)
else
_c1 = x + 0*im
_c1 = x + zero(eltype(x)) .* im
_c2 = nothing
end
end
Expand Down Expand Up @@ -119,7 +119,7 @@ function finite_difference_gradient(
dir=true) where {T1,T2,T3}

if typeof(x) <: AbstractArray
df = fill(zero(returntype), size(x))
df = zero(returntype) .* x
else
if inplace == Val{true}
if typeof(fx)==Nothing && typeof(c1)==Nothing && typeof(c2)==Nothing
Expand Down Expand Up @@ -164,7 +164,7 @@ function finite_difference_gradient(
dir=true) where {T1,T2,T3,fdtype,returntype,inplace}

if typeof(x) <: AbstractArray
df = fill(zero(returntype), size(x))
df = zero(returntype) .* x
else
df = zero(cache.c1)
end
Expand All @@ -186,61 +186,72 @@ function finite_difference_gradient!(
# NOTE: in this case epsilon is a vector, we need two arrays for epsilon and x1
# c1 denotes x1, c2 is epsilon
fx, c1, c2 = cache.fx, cache.c1, cache.c2
if fdtype != Val{:complex}
@. c2 = compute_epsilon(fdtype, x, relstep, absstep, dir)
if fdtype != Val{:complex} && ArrayInterface.fast_scalar_indexing(c2)
@.. c2 = compute_epsilon(fdtype, x, relstep, absstep, dir)
copyto!(c1,x)
end
if fdtype == Val{:forward}
@inbounds for i eachindex(x)
epsilon = c2[i]*dir
c1_old = c1[i]
c1[i] += epsilon
if ArrayInterface.fast_scalar_indexing(c2)
epsilon = ArrayInterface.allowed_getindex(c2,i)*dir
else
epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)*dir
end
c1_old = ArrayInterface.allowed_getindex(c1,i)
ArrayInterface.allowed_setindex!(c1,c1_old + epsilon,i)
if typeof(fx) != Nothing
dfi = (f(c1) - fx) / epsilon
else
fx0 = f(x)
dfi = (f(c1) - fx0) / epsilon
end
df[i] = real(dfi)
c1[i] = c1_old
df_tmp = real(dfi)
if eltype(df)<:Complex
c1[i] += im * epsilon
ArrayInterface.allowed_setindex!(c1,c1_old + im * epsilon,i)
if typeof(fx) != Nothing
dfi = (f(c1) - fx) / (im*epsilon)
else
dfi = (f(c1) - fx0) / (im*epsilon)
end
c1[i] = c1_old
df[i] -= im * imag(dfi)
ArrayInterface.allowed_setindex!(c1,c1_old,i)
ArrayInterface.allowed_setindex!(df, df_tmp - im * imag(dfi), i)
else
ArrayInterface.allowed_setindex!(df, df_tmp, i)
ArrayInterface.allowed_setindex!(c1,c1_old,i)
end
end
elseif fdtype == Val{:central}
@inbounds for i eachindex(x)
epsilon = c2[i]
c1_old = c1[i]
c1[i] += epsilon
x_old = x[i]
x[i] -= epsilon
df[i] = real((f(c1) - f(x)) / (2*epsilon))
c1[i] = c1_old
x[i] = x_old
if ArrayInterface.fast_scalar_indexing(c2)
epsilon = ArrayInterface.allowed_getindex(c2,i)*dir
else
epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)*dir
end
c1_old = ArrayInterface.allowed_getindex(c1,i)
ArrayInterface.allowed_setindex!(c1,c1_old + epsilon, i)
x_old = ArrayInterface.allowed_getindex(x,i)
ArrayInterface.allowed_setindex!(x,x_old - epsilon,i)
df_tmp = real((f(c1) - f(x)) / (2*epsilon))
if eltype(df)<:Complex
c1[i] += im*epsilon
x[i] -= im*epsilon
df[i] -= im*imag( (f(c1) - f(x)) / (2*im*epsilon) )
c1[i] = c1_old
x[i] = x_old
ArrayInterface.allowed_setindex!(c1,c1_old + im*epsilon,i)
ArrayInterface.allowed_setindex!(x,x_old - im*epsilon,i)
df_tmp2 = im*imag( (f(c1) - f(x)) / (2*im*epsilon) )
ArrayInterface.allowed_setindex!(df,df_tmp-df_tmp2,i)
else
ArrayInterface.allowed_setindex!(df,df_tmp,i)
end
ArrayInterface.allowed_setindex!(c1,c1_old, i)
ArrayInterface.allowed_setindex!(x,x_old,i)
end
elseif fdtype == Val{:complex} && returntype <: Real
copyto!(c1,x)
epsilon_complex = eps(real(eltype(x)))
# we use c1 here to avoid typing issues with x
@inbounds for i eachindex(x)
c1_old = c1[i]
c1[i] += im*epsilon_complex
df[i] = imag(f(c1)) / epsilon_complex
c1[i] = c1_old
c1_old = ArrayInterface.allowed_getindex(c1,i)
ArrayInterface.allowed_setindex!(c1,c1_old+im*epsilon_complex,i)
ArrayInterface.allowed_setindex!(df,imag(f(c1)) / epsilon_complex,i)
ArrayInterface.allowed_setindex!(c1,c1_old,i)
end
else
fdtype_error(returntype)
Expand Down Expand Up @@ -360,41 +371,45 @@ function finite_difference_gradient!(
# c1 denotes fx1, c2 is fx2, sizes guaranteed by the cache constructor
fx, c1, c2 = cache.fx, cache.c1, cache.c2

if inplace == Val{true}
_c1, _c2 = c1, c2
end

if fdtype == Val{:forward}
epsilon = compute_epsilon(Val{:forward}, x, relstep, absstep, dir)
if inplace == Val{true}
f(c1, x+epsilon)
else
c1 .= f(x+epsilon)
_c1 = f(x+epsilon)
end
if typeof(fx) != Nothing
@. df = (c1 - fx) / epsilon
@.. df = (_c1 - fx) / epsilon
else
if inplace == Val{true}
f(c2, x)
else
c2 .= f(x)
_c2 = f(x)
end
@. df = (c1 - c2) / epsilon
@.. df = (_c1 - _c2) / epsilon
end
elseif fdtype == Val{:central}
epsilon = compute_epsilon(Val{:central}, x, relstep, absstep, dir)
if inplace == Val{true}
f(c1, x+epsilon)
f(c2, x-epsilon)
else
c1 .= f(x+epsilon)
c2 .= f(x-epsilon)
_c1 = f(x+epsilon)
_c2 = f(x-epsilon)
end
@. df = (c1 - c2) / (2*epsilon)
@.. df = (_c1 - _c2) / (2*epsilon)
elseif fdtype == Val{:complex} && returntype <: Real
epsilon_complex = eps(real(eltype(x)))
if inplace == Val{true}
f(c1, x+im*epsilon_complex)
else
c1 .= f(x+im*epsilon_complex)
_c1 = f(x+im*epsilon_complex)
end
@. df = imag(c1) / epsilon_complex
@.. df = imag(_c1) / epsilon_complex
else
fdtype_error(returntype)
end
Expand Down

0 comments on commit cd4ee8f

Please sign in to comment.