Skip to content

Commit

Permalink
Add rrules for binary linear algebra operations
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Jun 10, 2019
1 parent 0e551fb commit ffce382
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Cassette = "^0.2"
FDM = "^0.5"
FDM = "^0.6"
julia = "^1.0"

[extras]
Expand Down
46 changes: 46 additions & 0 deletions src/rules/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,49 @@ end
frule(::typeof(tr), x) = (tr(x), Rule(Δx -> tr(extern(Δx))))

rrule(::typeof(tr), x) = (tr(x), Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1)))))

#####
##### `*`
#####

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
return A * B, (Rule(Ȳ ->* B'), Rule(Ȳ -> A' * Ȳ))
end

#####
##### `/`
#####

function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
Y = A / B
∂A = Rule(Ȳ ->/ B')
∂B = Rule(Ȳ -> -Y' * (Ȳ / B'))
return Y, (∂A, ∂B)
end

#####
##### `\`
#####

function rrule(::typeof(\), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
Y = A \ B
∂A = Rule(Ȳ -> -(A' \ Ȳ) * Y')
∂B = Rule(Ȳ -> A' \ Ȳ)
return Y, (∂A, ∂B)
end

#####
##### `norm`
#####

function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2)
y = norm(A, p)
u = y^(1-p)
∂A = Rule(ȳ ->.* u .* abs.(A).^p ./ A)
∂p = Rule(ȳ ->* (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p)
return y, (∂A, ∂p)
end

function rrule(::typeof(norm), x::Real, p::Real=2)
return norm(x, p), (Rule(ȳ ->* sign(x)), Rule(_ -> zero(x)))
end
30 changes: 30 additions & 0 deletions test/rules/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,34 @@ end
frule_test(tr, (randn(rng, N, N), randn(rng, N, N)))
rrule_test(tr, randn(rng), (randn(rng, N, N), randn(rng, N, N)))
end
@testset "*" begin
rng = MersenneTwister(123456)
dims = [3,4,5]
for n in dims, m in dims, p in dims
n > 3 && n == m == p && continue # don't need to test square case multiple times
A = randn(rng, m, n)
B = randn(rng, n, p)
= randn(rng, m, p)
rrule_test(*, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, n, p)))
end
end
@testset "$f" for f in [/, \]
rng = MersenneTwister(420)
for n in 3:5, m in 3:5
n == m || continue # TODO: Incorporate fixes for rectangular matrices
A = randn(rng, m, n)
B = randn(rng, m, n)
= randn(rng, size(f(A, B)))
rrule_test(f, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, m, n)), rtol=1e-6, atol=1e-6)
end
end
@testset "norm" begin
rng = MersenneTwister(3)
for dims in [(), (5,), (3, 2), (7, 3, 2)]
A = randn(rng, dims...)
p = randn(rng)
= randn(rng)
rrule_test(norm, ȳ, (A, randn(rng, dims...)), (p, randn(rng)))
end
end
end

0 comments on commit ffce382

Please sign in to comment.