Skip to content

Commit

Permalink
Merge 9e6e9a8 into 71c5ac0
Browse files Browse the repository at this point in the history
  • Loading branch information
bzinberg committed Oct 14, 2020
2 parents 71c5ac0 + 9e6e9a8 commit 694aba0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/derivatives/linalg/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ for A in ARRAY_TYPES
@eval @inline Base.:+(x::$(A), y::TrackedArray{V,D}) where {V,D} = record_plus(x, y, D)
end

@inline Base.:+(x::TrackedArray{V,D}, y::StaticArray) where {V,D} = record_plus(x, Array(y), D)
@inline Base.:+(x::StaticArray, y::TrackedArray{V,D}) where {V,D} = record_plus(Array(x), y, D)

function record_plus(x, y, ::Type{D}) where D
tp = tape(x, y)
out = track(value(x) + value(y), D, tp)
Expand Down Expand Up @@ -108,6 +111,9 @@ for A in ARRAY_TYPES
@eval Base.:-(x::$(A), y::TrackedArray{V,D}) where {V,D} = record_minus(x, y, D)
end

@inline Base.:-(x::TrackedArray{V,D}, y::StaticArray) where {V,D} = record_minus(x, Array(y), D)
@inline Base.:-(x::StaticArray, y::TrackedArray{V,D}) where {V,D} = record_minus(Array(x), y, D)

function Base.:-(x::TrackedArray{V,D}) where {V,D}
tp = tape(x)
out = track(-(value(x)), D, tp)
Expand Down
4 changes: 4 additions & 0 deletions src/tracked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ function deriv!(t::NTuple{N,Any}, v::NTuple{N,Any}) where N
return nothing
end

deriv!(t::StaticArray, v::AbstractArray) = deriv!(Tuple(t), Tuple(v))

# pulling values from origin #
#----------------------------#

Expand Down Expand Up @@ -223,6 +225,8 @@ unseed!(x::AbstractArray, i) = unseed!(x[i])
capture(t::TrackedReal) = ifelse(hastape(t), t, value(t))
capture(t::TrackedArray) = t
capture(t::AbstractArray) = istracked(t) ? map!(capture, similar(t), t) : copy(t)
# `StaticArray`s don't support mutation unless the eltype is a bits type (`isbitstype`).
capture(t::SA) where SA <: StaticArray = istracked(t) ? SA(map(capture, t)) : copy(t)

########################
# Conversion/Promotion #
Expand Down

0 comments on commit 694aba0

Please sign in to comment.