Skip to content

Commit

Permalink
_dot to _realdot
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Oct 12, 2021
1 parent b4343f5 commit 9bab121
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,28 +221,26 @@ _zero(::AbstractVector, d::AbstractMatrix) = zero(similar(d, size(d, 2)))
_zero(::AbstractMatrix, d::AbstractMatrix) = zero(d)
_zero(::Any, d::Any) = zero(d)

@inline _dot(x, y) = dot(x, y)
@inline function _dot(x::AbstractVector, y::UniformScaling)
# support pullbacks for complex numbers
@inline _realdot(x, y) = real(dot(x, y))
@inline _realdot(x::Number, y::Number) = muladd(real(x), real(y), imag(x) * imag(y))
@inline _realdot(x::Real, y::Number) = x * real(y)
@inline _realdot(x::Number, y::Real) = real(x) * y
@inline _realdot(x::Real, y::Real) = x * y
@inline function _realdot(x::AbstractVector, y::UniformScaling)
@assert length(x) == 1
return @inbounds dot(x[1], y.λ)
return @inbounds _realdot(x[1], y.λ)
end
@inline function _dot(x::AbstractVector, y::AbstractMatrix)
@inline function _realdot(x::AbstractVector, y::AbstractMatrix)
@assert size(y, 2) == 1
return dot(x, y)
return _realdot(x, vec(y))
end
@inline function _realdot(xs::NTuple{N}, ys::NTuple{N}) where {N}
return sum(Base.splat(_realdot), zip(xs, ys))
end

function pullback_function(ab::AbstractBackend, f, xs...)
return (ws) -> begin
return gradient(lowest(ab), (xs...,) -> begin
vs = f(xs...)
if ws isa Tuple
@assert length(vs) == length(ws)
return sum(Base.splat(_dot), zip(ws, vs))
else
return _dot(vs, ws)
end
end, xs...)
end
return (ws) -> gradient(lowest(ab), (xs...,) -> _realdot(ws, f(xs...)), xs...)
end
function value_and_pullback_function(
ab::AbstractBackend,
Expand Down

0 comments on commit 9bab121

Please sign in to comment.