Skip to content

Commit

Permalink
Improvements to pullback_function (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Oct 11, 2021
1 parent 3eb3fd1 commit b4343f5
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,31 +221,27 @@ _zero(::AbstractVector, d::AbstractMatrix) = zero(similar(d, size(d, 2)))
_zero(::AbstractMatrix, d::AbstractMatrix) = zero(d)
_zero(::Any, d::Any) = zero(d)

@inline _dot(x, y) = dot(x, y)
@inline function _dot(x::AbstractVector, y::UniformScaling)
@assert length(x) == 1
return @inbounds dot(x[1], y.λ)
end
@inline function _dot(x::AbstractVector, y::AbstractMatrix)
@assert size(y, 2) == 1
return dot(x, y)
end

function pullback_function(ab::AbstractBackend, f, xs...)
return (ws) -> begin
jacs = jacobian(lowest(ab), (xs...,) -> begin
return gradient(lowest(ab), (xs...,) -> begin
vs = f(xs...)
if ws isa Tuple
@assert length(vs) == length(ws)
return sum(zip(vs, ws)) do v, w
if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector
return w' * v
else
# for arbitrary arrays
return dot(w, v)
end
end
return sum(Base.splat(_dot), zip(ws, vs))
else
w, v = ws, vs
if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector
return w' * v
else
# for arbitrary arrays
return dot(w, v)
end
return _dot(vs, ws)
end
end, xs...)
return adjoint.(jacs)
end
end
function value_and_pullback_function(
Expand Down

0 comments on commit b4343f5

Please sign in to comment.