Skip to content

Commit

Permalink
Merge da5b4e7 into 1df670e
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Jun 19, 2019
2 parents 1df670e + da5b4e7 commit 39ffe44
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 41 deletions.
52 changes: 33 additions & 19 deletions src/rules/linalg/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,23 @@ end
##### `BLAS.gemv`
#####

function rrule(::typeof(BLAS.gemv), tA, α, A, x)
Ω = BLAS.gemv(tA, α, A, x)
∂α = ΔΩ -> dot(ΔΩ, Ω) / α
∂A = ΔΩ -> uppercase(tA) == 'N' ? α * ΔΩ * x' : α * x * ΔΩ'
∂x = ΔΩ -> gemv(uppercase(tA) == 'N' ? 'T' : 'N', α, A, ΔΩ)
return Ω, (DNERule(), _rule_via(∂α), _rule_via(∂A), _rule_via(∂x))
function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T},
x::AbstractVector{T}) where T<:BlasFloat
y = gemv(tA, α, A, x)
if uppercase(tA) === 'N'
∂A = Rule(ȳ -> α * ȳ * x', (Ā, ȳ) -> ger!(α, ȳ, x, Ā))
∂x = Rule(ȳ -> gemv('T', α, A, ȳ), (x̄, ȳ) -> gemv!('T', α, A, ȳ, one(T), x̄))
else
∂A = Rule(ȳ -> α * x * ȳ', (Ā, ȳ) -> ger!(α, x, ȳ, Ā))
∂x = Rule(ȳ -> gemv('N', α, A, ȳ), (x̄, ȳ) -> gemv!('N', α, A, ȳ, one(T), x̄))
end
return y, (DNERule(), Rule(ȳ -> dot(ȳ, y) / α), ∂A, ∂x)
end

function rrule(f::typeof(BLAS.gemv), tA, A, x)
Ω, (dtA, dα, dA, dx) = rrule(f, tA, one(eltype(A)), A, x)
return Ω, (dtA, dA, dx)
function rrule(::typeof(gemv), tA::Char, A::AbstractMatrix{T},
x::AbstractVector{T}) where T<:BlasFloat
y, (dtA, _, dA, dx) = rrule(gemv, tA, one(T), A, x)
return y, (dtA, dA, dx)
end

#####
Expand All @@ -82,25 +88,33 @@ end
function rrule(::typeof(gemm), tA::Char, tB::Char, α::T,
A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
C = gemm(tA, tB, α, A, B)
∂α = C̄ -> sum(C̄ .* C) / α
β = one(T)
if uppercase(tA) === 'N'
if uppercase(tB) === 'N'
∂A = C̄ -> gemm('N', 'T', α, C̄, B)
∂B = C̄ -> gemm('T', 'N', α, A, C̄)
∂A = Rule(C̄ -> gemm('N', 'T', α, C̄, B),
(Ā, C̄) -> gemm!('N', 'T', α, C̄, B, β, Ā))
∂B = Rule(C̄ -> gemm('T', 'N', α, A, C̄),
(B̄, C̄) -> gemm!('T', 'N', α, A, C̄, β, B̄))
else
∂A = C̄ -> gemm('N', 'N', α, C̄, B)
∂B = C̄ -> gemm('T', 'N', α, C̄, A)
∂A = Rule(C̄ -> gemm('N', 'N', α, C̄, B),
(Ā, C̄) -> gemm!('N', 'N', α, C̄, B, β, Ā))
∂B = Rule(C̄ -> gemm('T', 'N', α, C̄, A),
(B̄, C̄) -> gemm!('T', 'N', α, C̄, A, β, B̄))
end
else
if uppercase(tB) === 'N'
∂A = C̄ -> gemm('N', 'T', α, B, C̄)
∂B = C̄ -> gemm('N', 'N', α, A, C̄)
∂A = Rule(C̄ -> gemm('N', 'T', α, B, C̄),
(Ā, C̄) -> gemm!('N', 'T', α, B, C̄, β, Ā))
∂B = Rule(C̄ -> gemm('N', 'N', α, A, C̄),
(B̄, C̄) -> gemm!('N', 'N', α, A, C̄, β, B̄))
else
∂A = C̄ -> gemm('T', 'T', α, B, C̄)
∂B = C̄ -> gemm('T', 'T', α, C̄, A)
∂A = Rule(C̄ -> gemm('T', 'T', α, B, C̄),
(Ā, C̄) -> gemm!('T', 'T', α, B, C̄, β, Ā))
∂B = Rule(C̄ -> gemm('T', 'T', α, C̄, A),
(B̄, C̄) -> gemm!('T', 'T', α, C̄, A, β, B̄))
end
end
return C, (DNERule(), DNERule(), _rule_via(∂α), _rule_via(∂A), _rule_via(∂B))
return C, (DNERule(), DNERule(), Rule(C̄ -> dot(C̄, C) / α), ∂A, ∂B)
end

function rrule(::typeof(gemm), tA::Char, tB::Char,
Expand Down
29 changes: 15 additions & 14 deletions test/rules/blas.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using LinearAlgebra.BLAS: gemm

@testset "BLAS" begin
@testset "gemm" begin
rng = MersenneTwister(1)
Expand All @@ -9,18 +7,21 @@ using LinearAlgebra.BLAS: gemm
A = randn(rng, tA === 'N' ? (m, n) : (n, m))
B = randn(rng, tB === 'N' ? (n, p) : (p, n))
C = gemm(tA, tB, α, A, B)
fAB, (dtA, dtB, dα, dA, dB) = rrule(gemm, tA, tB, α, A, B)
@test C ≈ fAB
@test dtA isa ChainRules.DNERule
@test dtB isa ChainRules.DNERule
for (f, x, dx) in [(X->gemm(tA, tB, X, A, B), α, dα),
(X->gemm(tA, tB, α, X, B), A, dA),
(X->gemm(tA, tB, α, A, X), B, dB)]
ȳ = randn(rng, size(C)...)
x̄_ad = dx(ȳ)
x̄_fd = j′vp(central_fdm(5, 1), f, ȳ, x)
@test x̄_ad ≈ x̄_fd rtol=1e-9 atol=1e-9
end
ȳ = randn(rng, size(C)...)
rrule_test(gemm, ȳ, (tA, nothing), (tB, nothing), (α, randn(rng)),
(A, randn(rng, size(A))), (B, randn(rng, size(B))))
end
end
@testset "gemv" begin
rng = MersenneTwister(2)
for n in 3:5, m in 3:5, t in ('N', 'T')
α = randn(rng)
A = randn(rng, m, n)
x = randn(rng, t === 'N' ? n : m)
y = α * (t === 'N' ? A : A') * x
ȳ = randn(rng, size(y)...)
rrule_test(gemv, ȳ, (t, nothing), (α, randn(rng)), (A, randn(rng, size(A))),
(x, randn(rng, size(x))))
end
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# TODO: more tests!

using ChainRules, Test, FDM, LinearAlgebra, Random
using ChainRules, Test, FDM, LinearAlgebra, LinearAlgebra.BLAS, Random
using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,
DNE, Thunk, Casted, DNERule
using Base.Broadcast: broadcastable
import LinearAlgebra: dot

include("test_util.jl")

Expand Down
14 changes: 7 additions & 7 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,41 +119,41 @@ function Base.isapprox(d_ad::Thunk, d_fd; kwargs...)
end

function test_accumulation(x̄, dx, ȳ, partial)
@test all(extern(ChainRules.add(x̄, partial)) .== extern(x̄) .+ extern(partial))
@test all(extern(ChainRules.add(x̄, partial)) .≈ extern(x̄) .+ extern(partial))
test_accumulate(x̄, dx, ȳ, partial)
test_accumulate!(x̄, dx, ȳ, partial)
test_store!(x̄, dx, ȳ, partial)
return nothing
end

function test_accumulate(x̄::Zero, dx, ȳ, partial)
@test extern(accumulate(x̄, dx, ȳ)) == extern(partial)
@test extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial)
return nothing
end

function test_accumulate(x̄::Number, dx, ȳ, partial)
@test extern(accumulate(x̄, dx, ȳ)) == extern(x̄) + extern(partial)
@test extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial)
return nothing
end

function test_accumulate(x̄::AbstractArray, dx, ȳ, partial)
x̄_old = copy(x̄)
@test all(extern(accumulate(x̄, dx, ȳ)) .== (extern(x̄) .+ extern(partial)))
@test all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial)))
@test x̄ == x̄_old
return nothing
end

test_accumulate!(x̄::Zero, dx, ȳ, partial) = nothing

function test_accumulate!(x̄::Number, dx, ȳ, partial)
@test accumulate!(x̄, dx, ȳ) == accumulate(x̄, dx, ȳ)
@test accumulate!(x̄, dx, ȳ) ≈ accumulate(x̄, dx, ȳ)
return nothing
end

function test_accumulate!(x̄::AbstractArray, dx, ȳ, partial)
x̄_copy = copy(x̄)
accumulate!(x̄_copy, dx, ȳ)
@test extern(x̄_copy) == (extern(x̄) .+ extern(partial))
@test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial))
return nothing
end

Expand All @@ -163,6 +163,6 @@ test_store!(x̄::Number, dx, ȳ, partial) = nothing
function test_store!(x̄::AbstractArray, dx, ȳ, partial)
x̄_copy = copy(x̄)
store!(x̄_copy, dx, ȳ)
@test all(x̄_copy .== extern(partial))
@test all(x̄_copy .≈ extern(partial))
return nothing
end

0 comments on commit 39ffe44

Please sign in to comment.