Skip to content

Commit

Permalink
Dispatch on StaticArray instead of SArray (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored and jrevels committed Feb 12, 2019
1 parent d49bd6d commit 8d12336
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/DiffResults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ Return `r::DiffResult`, with output value storage provided by `value` and output
storage provided by `derivs`.
In reality, `DiffResult` is an abstract supertype of two concrete types, `MutableDiffResult`
and `ImmutableDiffResult`. If all `value`/`derivs` are all `Number`s or `SArray`s, then `r`
will be immutable (i.e. `r::ImmutableDiffResult`). Otherwise, `r` will be mutable
and `ImmutableDiffResult`. If all `value`/`derivs` are all `Number`s or `StaticArray`s,
then `r` will be immutable (i.e. `r::ImmutableDiffResult`). Otherwise, `r` will be mutable
(i.e. `r::MutableDiffResult`).
Note that `derivs` can be provide in splatted form, i.e. `DiffResult(value, derivs...)`.
"""
DiffResult

DiffResult(value::Number, derivs::Tuple{Vararg{Number}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::Number, derivs::Tuple{Vararg{SArray}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::SArray, derivs::Tuple{Vararg{SArray}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::Number, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::StaticArray, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs)
DiffResult(value::Number, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs)
DiffResult(value::AbstractArray, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs)
DiffResult(value::Union{Number,AbstractArray}, derivs::Union{Number,AbstractArray}...) = DiffResult(value, derivs)
Expand All @@ -65,7 +65,7 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
constructor instead.
"""
GradientResult(x::AbstractArray) = DiffResult(first(x), similar(x))
GradientResult(x::SArray) = DiffResult(first(x), x)
GradientResult(x::StaticArray) = DiffResult(first(x), x)

"""
JacobianResult(x::AbstractArray)
Expand All @@ -79,7 +79,7 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
constructor instead.
"""
JacobianResult(x::AbstractArray) = DiffResult(similar(x), similar(x, length(x), length(x)))
JacobianResult(x::SArray{<:Any,T,<:Any,L}) where {T,L} = DiffResult(x, zeros(SMatrix{L,L,T}))
JacobianResult(x::StaticArray) = DiffResult(x, zeros(StaticArrays.similar_type(typeof(x), Size(length(x),length(x)))))

"""
JacobianResult(y::AbstractArray, x::AbstractArray)
Expand All @@ -92,7 +92,7 @@ Like the single argument version, `y` and `x` are only used for type and
shape information and are not stored in the returned `DiffResult`.
"""
JacobianResult(y::AbstractArray, x::AbstractArray) = DiffResult(similar(y), similar(y, length(y), length(x)))
JacobianResult(y::SArray{<:Any,<:Any,<:Any,Y}, x::SArray{<:Any,T,<:Any,X}) where {T,Y,X} = DiffResult(y, zeros(SMatrix{Y,X,T}))
JacobianResult(y::StaticArray, x::StaticArray) = DiffResult(y, zeros(StaticArrays.similar_type(typeof(x), Size(length(y),length(x)))))

"""
HessianResult(x::AbstractArray)
Expand All @@ -105,7 +105,7 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
constructor instead.
"""
HessianResult(x::AbstractArray) = DiffResult(first(x), similar(x), similar(x, length(x), length(x)))
HessianResult(x::SArray{<:Any,T,<:Any,L}) where {T,L} = DiffResult(first(x), x, zeros(SMatrix{L,L,T}))
HessianResult(x::StaticArray) = DiffResult(first(x), x, zeros(StaticArrays.similar_type(typeof(x), Size(length(x),length(x)))))

#############
# Interface #
Expand Down Expand Up @@ -203,7 +203,7 @@ function derivative!(r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Va
return r
end

function derivative!(r::ImmutableDiffResult, x::Union{Number,SArray}, ::Type{Val{i}} = Val{1}) where {i}
function derivative!(r::ImmutableDiffResult, x::Union{Number,StaticArray}, ::Type{Val{i}} = Val{1}) where {i}
return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, x, Val{i}))
end

Expand Down Expand Up @@ -232,7 +232,7 @@ function derivative!(f, r::ImmutableDiffResult, x::Number, ::Type{Val{i}} = Val{
return derivative!(r, f(x), Val{i})
end

function derivative!(f, r::ImmutableDiffResult, x::SArray, ::Type{Val{i}} = Val{1}) where {i}
function derivative!(f, r::ImmutableDiffResult, x::StaticArray, ::Type{Val{i}} = Val{1}) where {i}
return derivative!(r, map(f, x), Val{i})
end

Expand Down

0 comments on commit 8d12336

Please sign in to comment.