Skip to content

Commit

Permalink
tmp5
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Aug 24, 2018
1 parent dcd769e commit 8d2a2b0
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 113 deletions.
231 changes: 124 additions & 107 deletions src/derivatives/linalg/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,150 +157,167 @@ end
# multiplication (*) #
######################

# Pairs of function names, each written as (in-place name, out-of-place name),
# for plain multiplication and its transposed (`t`) / conjugate-transposed (`c`)
# variants. Code that iterates this tuple destructures each entry as `(f!, f)`.
const A_MUL_B_FUNCS = ((:mul!, :*),
(:A_mul_Bt!, :A_mul_Bt), (:At_mul_B!, :At_mul_B), (:At_mul_Bt!, :At_mul_Bt),
(:A_mul_Bc!, :A_mul_Bc), (:Ac_mul_B!, :Ac_mul_B), (:Ac_mul_Bc!, :Ac_mul_Bc))
# `mulargvalue(x)` returns the value of a multiplication operand.
# For an `Adjoint`/`Transpose` wrapper the wrapper is undone first, `value` is
# taken of what was wrapped, and the same wrapper is applied again, so the
# result keeps the orientation of `x`.
mulargvalue(x) = value(x)

function mulargvalue(x::Adjoint)
    wrapped = adjoint(x)
    return adjoint(value(wrapped))
end

function mulargvalue(x::Transpose)
    wrapped = transpose(x)
    return transpose(value(wrapped))
end

# `mulargpullvalue!(x)` forwards a multiplication operand to `pull_value!`.
# An `Adjoint`/`Transpose` wrapper is undone first so that `pull_value!` is
# called on what the wrapper holds rather than on the wrapper itself.
mulargpullvalue!(x) = pull_value!(x)

function mulargpullvalue!(x::Adjoint)
    wrapped = adjoint(x)
    return pull_value!(wrapped)
end

function mulargpullvalue!(x::Transpose)
    wrapped = transpose(x)
    return pull_value!(wrapped)
end

# recording pass #
#----------------#

for (f!, f) in A_MUL_B_FUNCS
record_f = Symbol(string("record_", f))
record_f! = Symbol(string("record_", f!))
# Out-of-place recording of `x * y`.
# Multiplies the operand values, tracks the product on the tape shared by `x`
# and `y`, and records a `SpecialInstruction` for `*` whose cache holds one
# `similar` buffer per operand (element type `D`). Returns the tracked product.
@inline function record_mul(x, y, ::Type{D}) where D
    shared_tape = tape(x, y)
    product = mulargvalue(x) * mulargvalue(y)
    result = track(product, D, shared_tape)
    scratch = (similar(x, D), similar(y, D))
    record!(shared_tape, SpecialInstruction, *, (x, y), result, scratch)
    return result
end

@eval begin
@inline function $(record_f)(x, y, ::Type{D}) where D
tp = tape(x, y)
out = track($(f)(value(x), value(y)), D, tp)
cache = (similar(x, D), similar(y, D))
record!(tp, SpecialInstruction, $(f), (x, y), out, cache)
return out
end
# In-place recording of `x * y` into the tracked array `out`.
# Copies the product of the operand values into the value of `out`, then records
# a `SpecialInstruction` for `*` on the tape of `x` and `y`, with one `similar`
# buffer per operand (element type `D`) as cache. Returns `out`.
@inline function record_mul!(out::TrackedArray{V,D}, x, y) where {V,D}
    destination = mulargvalue(out)
    product = mulargvalue(x) * mulargvalue(y)
    copyto!(destination, product)
    scratch = (similar(x, D), similar(y, D))
    shared_tape = tape(x, y)
    record!(shared_tape, SpecialInstruction, *, (x, y), out, scratch)
    return out
end


# Aliases for one- and two-dimensional `TrackedArray`s; the third type parameter
# of `TrackedArray` is fixed to 1 and 2 respectively.
const TrackedVector{V,D} = TrackedArray{V,D,1}
const TrackedMatrix{V,D} = TrackedArray{V,D,2}

@inline function $(record_f!)(out::TrackedArray{V,D}, x, y) where {V,D}
copyto!(value(out), $(f)(value(x), value(y)))
cache = (similar(x, D), similar(y, D))
record!(tape(x, y), SpecialInstruction, $(f), (x, y), out, cache)
return out
# Generate the `*` and `mul!` methods that route multiplication of tracked
# arrays (optionally wrapped in `Transpose`/`Adjoint`) to `record_mul` and
# `record_mul!`. The vector/matrix aliases are enumerated explicitly alongside
# `TrackedArray`, presumably to stay more specific than LinearAlgebra's own
# vector/matrix methods -- TODO confirm.
#
# The old definitions that interpolated `f`, `f!`, `record_f` and `record_f!`
# were removed: none of those names is a variable of this loop, and they belong
# to the superseded `A_mul_B` family.
#
# NOTE(review): no method is generated for mixed wrappers on two tracked
# operands (`Adjoint` * `Transpose` or `Transpose` * `Adjoint`), although
# `reverse_mul!` defines rules for those combinations -- confirm whether they
# are meant to be reachable through `*` / `mul!`.
for S1 in (:TrackedArray, :TrackedVector, :TrackedMatrix)
    # tracked (x) tracked
    for S2 in (:TrackedArray, :TrackedVector, :TrackedMatrix)
        @eval begin
            LinearAlgebra.:*(x::$(S1){X,D}, y::$(S2){Y,D}) where {X,Y,D} = record_mul(x, y, D)

            LinearAlgebra.:*(x::Transpose{<:Any,<:$(S1){X,D}}, y::Transpose{<:Any,<:$(S2){Y,D}}) where {X,Y,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::Adjoint{<:Any,<:$(S1){X,D}}, y::Adjoint{<:Any,<:$(S2){Y,D}}) where {X,Y,D} = record_mul(x, y, D)

            LinearAlgebra.:*(x::Transpose{<:Any,<:$(S1){X,D}}, y::$(S2){Y,D}) where {X,Y,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::$(S1){X,D}, y::Transpose{<:Any,<:$(S2){Y,D}}) where {X,Y,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::Adjoint{<:Any,<:$(S1){X,D}}, y::$(S2){Y,D}) where {X,Y,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::$(S1){X,D}, y::Adjoint{<:Any,<:$(S2){Y,D}}) where {X,Y,D} = record_mul(x, y, D)

            LinearAlgebra.mul!(out::TrackedArray{V,D}, x::$(S1){X,D}, y::$(S2){Y,D}) where {V,X,Y,D} = record_mul!(out, x, y)

            LinearAlgebra.mul!(out::TrackedArray{V,D}, x::Transpose{<:Any,<:$(S1){X,D}}, y::Transpose{<:Any,<:$(S2){Y,D}}) where {V,X,Y,D} = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray{V,D}, x::Adjoint{<:Any,<:$(S1){X,D}}, y::Adjoint{<:Any,<:$(S2){Y,D}}) where {V,X,Y,D} = record_mul!(out, x, y)

            LinearAlgebra.mul!(out::TrackedArray{V,D}, x::Transpose{<:Any,<:$(S1){X,D}}, y::$(S2){Y,D}) where {V,X,Y,D} = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray{V,D}, x::$(S1){X,D}, y::Transpose{<:Any,<:$(S2){Y,D}}) where {V,X,Y,D} = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray{V,D}, x::Adjoint{<:Any,<:$(S1){X,D}}, y::$(S2){Y,D}) where {V,X,Y,D} = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray{V,D}, x::$(S1){X,D}, y::Adjoint{<:Any,<:$(S2){Y,D}}) where {V,X,Y,D} = record_mul!(out, x, y)
        end
    end

    # tracked (x) untracked array type from ARRAY_TYPES, in either position
    for T in ARRAY_TYPES
        @eval begin
            LinearAlgebra.:*(x::$(S1){V,D}, y::$(T)) where {V,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::$(T), y::$(S1){V,D}) where {V,D} = record_mul(x, y, D)

            LinearAlgebra.:*(x::Transpose{<:Any,<:$(T)}, y::$(S1){V,D}) where {V,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::$(S1){V,D}, y::Transpose{<:Any,<:$(T)}) where {V,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::Adjoint{<:Any,<:$(T)}, y::$(S1){V,D}) where {V,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::$(S1){V,D}, y::Adjoint{<:Any,<:$(T)}) where {V,D} = record_mul(x, y, D)

            LinearAlgebra.:*(x::Transpose{<:Any,<:$(S1){V,D}}, y::$(T)) where {V,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::$(T), y::Transpose{<:Any,<:$(S1){V,D}}) where {V,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::Adjoint{<:Any,<:$(S1){V,D}}, y::$(T)) where {V,D} = record_mul(x, y, D)
            LinearAlgebra.:*(x::$(T), y::Adjoint{<:Any,<:$(S1){V,D}}) where {V,D} = record_mul(x, y, D)

            LinearAlgebra.mul!(out::TrackedArray, x::$(S1), y::$(T)) = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray, x::$(T), y::$(S1)) = record_mul!(out, x, y)

            LinearAlgebra.mul!(out::TrackedArray, x::$(S1), y::Transpose{<:Any,<:$(T)}) = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray, x::Transpose{<:Any,<:$(T)}, y::$(S1)) = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray, x::$(S1), y::Adjoint{<:Any,<:$(T)}) = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray, x::Adjoint{<:Any,<:$(T)}, y::$(S1)) = record_mul!(out, x, y)

            LinearAlgebra.mul!(out::TrackedArray, x::Transpose{<:Any,<:$(S1)}, y::$(T)) = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray, x::$(T), y::Transpose{<:Any,<:$(S1)}) = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray, x::Adjoint{<:Any,<:$(S1)}, y::$(T)) = record_mul!(out, x, y)
            LinearAlgebra.mul!(out::TrackedArray, x::$(T), y::Adjoint{<:Any,<:$(S1)}) = record_mul!(out, x, y)
        end
    end
end

# forward pass #
#--------------#

for (f!, f) in A_MUL_B_FUNCS
@eval begin
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof($f)})
a, b = instruction.input
pull_value!(a)
pull_value!(b)
$(f!)(value(instruction.output), value(a), value(b))
return nothing
end
end
# Forward pass for a recorded `*` instruction.
# Calls `mulargpullvalue!` on both inputs, then recomputes the product in place
# into the value of the instruction's output via `mul!`.
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof(*)})
    lhs, rhs = instruction.input
    mulargpullvalue!(lhs)
    mulargpullvalue!(rhs)
    destination = mulargvalue(instruction.output)
    mul!(destination, mulargvalue(lhs), mulargvalue(rhs))
    return nothing
end

# reverse pass #
#--------------#

### *

# Reverse pass for a recorded `*` instruction.
# Unpacks the two inputs and the two cache buffers, hands the output derivative
# to `reverse_mul!` (which dispatches on whether the inputs are wrapped in
# `Adjoint`/`Transpose`), and finally unseeds the output.
#
# The derivative increments are performed only by `reverse_mul!`; the earlier
# direct calls to `A_mul_Bc!` / `Ac_mul_B!` were dropped because they would add
# each contribution a second time and rely on the superseded `A_mul_B` family.
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(*)})
    a, b = instruction.input
    a_tmp, b_tmp = instruction.cache
    output = instruction.output
    output_deriv = deriv(output)
    reverse_mul!(output, output_deriv, a, b, a_tmp, b_tmp)
    unseed!(output)
    return nothing
end

### A_mul_Bt
# a * b

@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(A_mul_Bt)})
a, b = instruction.input
a_tmp, b_tmp = instruction.cache
output = instruction.output
output_deriv = deriv(output)
istracked(a) && increment_deriv!(a, A_mul_B!(a_tmp, output_deriv, value(b)))
istracked(b) && increment_deriv!(b, At_mul_B!(b_tmp, output_deriv, value(a)))
unseed!(output)
return nothing
# Reverse rule for `output = a * b` with unwrapped operands.
# For a tracked `a`, `output_deriv * transpose(value(b))` is computed into
# `a_tmp` and added to the derivative of `a`; for a tracked `b`,
# `transpose(value(a)) * output_deriv` is computed into `b_tmp` and added to the
# derivative of `b`.
function reverse_mul!(output, output_deriv, a, b, a_tmp, b_tmp)
    if istracked(a)
        increment_deriv!(a, mul!(a_tmp, output_deriv, transpose(value(b))))
    end
    return istracked(b) && increment_deriv!(b, mul!(b_tmp, transpose(value(a)), output_deriv))
end

### At_mul_B

# Reverse pass for an instruction recorded with `At_mul_B`.
# Reads the two inputs and the two cache buffers from the instruction, adds a
# contribution to the derivative of each tracked input, then unseeds the output.
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(At_mul_B)})
a, b = instruction.input
a_tmp, b_tmp = instruction.cache
output = instruction.output
output_deriv = deriv(output)
# contribution for `a`: A_mul_Bt! of value(b) and the output derivative, written into a_tmp
istracked(a) && increment_deriv!(a, A_mul_Bt!(a_tmp, value(b), output_deriv))
# contribution for `b`: A_mul_B! of value(a) and the output derivative, written into b_tmp
istracked(b) && increment_deriv!(b, A_mul_B!(b_tmp, value(a), output_deriv))
unseed!(output)
return nothing
end

### At_mul_Bt

# Reverse pass for an instruction recorded with `At_mul_Bt`.
# Reads the two inputs and the two cache buffers from the instruction, adds a
# contribution to the derivative of each tracked input, then unseeds the output.
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(At_mul_Bt)})
a, b = instruction.input
a_tmp, b_tmp = instruction.cache
output = instruction.output
output_deriv = deriv(output)
# contribution for `a`: At_mul_Bt! of value(b) and the output derivative, written into a_tmp
istracked(a) && increment_deriv!(a, At_mul_Bt!(a_tmp, value(b), output_deriv))
# contribution for `b`: At_mul_Bt! of the output derivative and value(a), written into b_tmp
istracked(b) && increment_deriv!(b, At_mul_Bt!(b_tmp, output_deriv, value(a)))
unseed!(output)
return nothing
end

### A_mul_Bc

@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(A_mul_Bc)})
a, b = instruction.input
a_tmp, b_tmp = instruction.cache
output = instruction.output
output_deriv = deriv(output)
istracked(a) && increment_deriv!(a, A_mul_B!(a_tmp, output_deriv, value(b)))
istracked(b) && increment_deriv!(b, Ac_mul_B!(b_tmp, output_deriv, value(a)))
unseed!(output)
return nothing
# Reverse rules for `output = A * B` where one or both operands are wrapped in
# the same lazy wrapper `F` (`Transpose` or `Adjoint`), with `f` the function
# that applies/undoes that wrapper.
#
# Notation: `A`/`B` are the operands as they enter the product, `_a`/`_b` are
# the arrays underneath a wrapper. The rules follow the unwrapped case,
#     dA = output_deriv * transpose(B)      dB = transpose(A) * output_deriv
# Each product is written so that it has the shape of the wrapped operand,
# which is the shape of the cache buffer (`a_tmp` / `b_tmp` are created with
# `similar` of the operands as passed to `*`). The result is then passed
# through `f` once to give it the shape of the array underneath the wrapper.
# The previous operand order produced the transposed shape, which fails in
# `mul!` for non-square inputs and accumulates a transposed gradient for square
# ones.
#
# NOTE(review): the `adjoint` variants treat `adjoint` like `transpose`, which
# is only equivalent for real element types -- confirm that tracked values are
# always real.
for (f, F) in ((:transpose, :Transpose), (:adjoint, :Adjoint))
    @eval begin
        # a * f(_b)
        function reverse_mul!(output, output_deriv, a, b::$F, a_tmp, b_tmp)
            _b = ($f)(b)
            # dA = dOut * transpose(f(_b)) = dOut * _b
            istracked(a) && increment_deriv!(a, mul!(a_tmp, output_deriv, value(_b)))
            # dB = f(a) * dOut has the shape of b_tmp; d_b = f(dB)
            istracked(_b) && increment_deriv!(_b, ($f)(mul!(b_tmp, ($f)(value(a)), output_deriv)))
            return nothing
        end
        # f(_a) * b
        function reverse_mul!(output, output_deriv, a::$F, b, a_tmp, b_tmp)
            _a = ($f)(a)
            # dA = dOut * f(b) has the shape of a_tmp; d_a = f(dA)
            istracked(_a) && increment_deriv!(_a, ($f)(mul!(a_tmp, output_deriv, ($f)(value(b)))))
            # dB = transpose(f(_a)) * dOut = _a * dOut
            istracked(b) && increment_deriv!(b, mul!(b_tmp, value(_a), output_deriv))
            return nothing
        end
        # f(_a) * f(_b)
        function reverse_mul!(output, output_deriv, a::$F, b::$F, a_tmp, b_tmp)
            _a = ($f)(a)
            _b = ($f)(b)
            # dA = dOut * _b has the shape of a_tmp; d_a = f(dA)
            istracked(_a) && increment_deriv!(_a, ($f)(mul!(a_tmp, output_deriv, value(_b))))
            # dB = _a * dOut has the shape of b_tmp; d_b = f(dB)
            istracked(_b) && increment_deriv!(_b, ($f)(mul!(b_tmp, value(_a), output_deriv)))
            return nothing
        end
    end
end

### Ac_mul_B
# adjoint(a) * transpose(b)

@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(Ac_mul_B)})
a, b = instruction.input
a_tmp, b_tmp = instruction.cache
output = instruction.output
output_deriv = deriv(output)
istracked(a) && increment_deriv!(a, A_mul_Bc!(a_tmp, value(b), output_deriv))
istracked(b) && increment_deriv!(b, A_mul_B!(b_tmp, value(a), output_deriv))
unseed!(output)
return nothing
# Reverse rule for `output = adjoint(_a) * transpose(_b)`.
#
# Both contributions are computed directly and independently, so the rule no
# longer depends on which operand is tracked or on another `reverse_mul!`
# method. Each product has the shape of the wrapped operand (the shape of its
# cache buffer) and is then passed through the operand's own wrapper function
# to match the array underneath. The previous fallback branch multiplied
# `adjoint(output_deriv)` by `value(_a)`, whose inner dimensions do not agree
# for non-square inputs.
#
# NOTE(review): treats `adjoint` like `transpose`, which is only equivalent for
# real element types -- confirm that tracked values are always real.
function reverse_mul!(output, output_deriv, a::Adjoint, b::Transpose, a_tmp, b_tmp)
    _a = adjoint(a)
    _b = transpose(b)
    # dA = dOut * transpose(transpose(_b)) = dOut * _b; d_a = adjoint(dA)
    istracked(_a) && increment_deriv!(_a, adjoint(mul!(a_tmp, output_deriv, value(_b))))
    # dB = transpose(adjoint(_a)) * dOut = _a * dOut; d_b = transpose(dB)
    istracked(_b) && increment_deriv!(_b, transpose(mul!(b_tmp, value(_a), output_deriv)))
    return nothing
end

### Ac_mul_Bc
# transpose(a) * adjoint(b)

@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(Ac_mul_Bc)})
a, b = instruction.input
a_tmp, b_tmp = instruction.cache
output = instruction.output
output_deriv = deriv(output)
istracked(a) && increment_deriv!(a, Ac_mul_Bc!(a_tmp, value(b), output_deriv))
istracked(b) && increment_deriv!(b, Ac_mul_Bc!(b_tmp, output_deriv, value(a)))
unseed!(output)
return nothing
# Reverse rule for `output = transpose(_a) * adjoint(_b)`.
#
# Both contributions are computed directly and independently, so the rule no
# longer depends on which operand is tracked or on another `reverse_mul!`
# method. Each product has the shape of the wrapped operand (the shape of its
# cache buffer) and is then passed through the operand's own wrapper function
# to match the array underneath. The previous fallback branch multiplied
# `value(_b)` by `adjoint(output_deriv)`, whose inner dimensions do not agree
# for non-square inputs.
#
# NOTE(review): treats `adjoint` like `transpose`, which is only equivalent for
# real element types -- confirm that tracked values are always real.
function reverse_mul!(output, output_deriv, a::Transpose, b::Adjoint, a_tmp, b_tmp)
    _a = transpose(a)
    _b = adjoint(b)
    # dA = dOut * transpose(adjoint(_b)) = dOut * _b; d_a = transpose(dA)
    istracked(_a) && increment_deriv!(_a, transpose(mul!(a_tmp, output_deriv, value(_b))))
    # dB = transpose(transpose(_a)) * dOut = _a * dOut; d_b = adjoint(dB)
    istracked(_b) && increment_deriv!(_b, adjoint(mul!(b_tmp, value(_a), output_deriv)))
    return nothing
end
4 changes: 2 additions & 2 deletions src/derivatives/linalg/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ end
output = instruction.output
output_value, output_deriv = value(output), deriv(output)
output_tmp1, output_tmp2 = instruction.cache
A_mul_Bc!(output_tmp1, output_deriv, output_value)
Ac_mul_B!(output_tmp2, output_value, output_tmp1)
mul!(output_tmp1, output_deriv, adjoint(output_value))
mul!(output_tmp2, adjoint(output_value), output_tmp1)
decrement_deriv!(instruction.input, output_tmp2)
unseed!(output)
return nothing
Expand Down
27 changes: 23 additions & 4 deletions test/derivatives/LinAlgTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,29 @@ for f in (+, -)
test_arr2arr(f, a, b, tp)
end

for (f!, f) in ReverseDiff.A_MUL_B_FUNCS
test_println("A_mul_B functions", f)
test_arr2arr(eval(LinearAlgebra, f), a, b, tp)
test_arr2arr_inplace(eval(LinearAlgebra, f!), eval(LinearAlgebra, f), x, a, b, tp)
# Plain product: out-of-place `*` and in-place `mul!`.
test_println("*(A, B) functions", "*(a, b)")

test_arr2arr(*, a, b, tp)
test_arr2arr_inplace(mul!, *, x, a, b, tp)

# One wrapper at a time: left operand, right operand, then both.
for wrap in (transpose, adjoint)
    test_println("*(A, B) functions", "*($(wrap)(a), b)")
    test_arr2arr(*, wrap(a), b, tp)
    test_arr2arr_inplace(mul!, *, x, wrap(a), b, tp)
    test_println("*(A, B) functions", "*(a, $(wrap)(b))")
    test_arr2arr(*, a, wrap(b), tp)
    test_arr2arr_inplace(mul!, *, x, a, wrap(b), tp)
    test_println("*(A, B) functions", "*($(wrap)(a), $(wrap)(b))")
    test_arr2arr(*, wrap(a), wrap(b), tp)
    test_arr2arr_inplace(mul!, *, x, wrap(a), wrap(b), tp)
end

# Mixed wrappers: adjoint on the left with transpose on the right, then swapped.
for (wrap_a, wrap_b) in ((adjoint, transpose), (transpose, adjoint))
    test_println("*(A, B) functions", "*($(wrap_a)(a), $(wrap_b)(b))")
    test_arr2arr(*, wrap_a(a), wrap_b(b), tp)
    test_arr2arr_inplace(mul!, *, x, wrap_a(a), wrap_b(b), tp)
end

end # module

0 comments on commit 8d2a2b0

Please sign in to comment.