Skip to content

Commit

Permalink
Test thunked inputs to SVD
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 13, 2020
1 parent f9f3e0c commit 45bbb57
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
F = svd(X)
function svd_pullback::Composite)
# svd_rev does a lot of linear algebra, it it is efficient to unthunk before
# svd_rev does a lot of linear algebra, it is is efficient to unthunk before
∂X = svd_rev(F, unthunk.U), unthunk.S), unthunk.V))
return (NO_FIELDS, ∂X)
end
Expand Down
24 changes: 22 additions & 2 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
@test dself1 === NO_FIELDS
@test dp === DoesNotExist()

ΔF = unthunk(dF)
dself2, dX = dX_pullback(ΔF)
dself2, dX = dX_pullback(dF)
@test dself2 === NO_FIELDS
X̄_ad = unthunk(dX)
X̄_fd = only(j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X))
@test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6))

end
@testset "Vt" begin
Y, dF_pullback = rrule(getproperty, F, :Vt)
Expand All @@ -27,6 +27,26 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
end
end

@testset "Thunked inputs" begin
X = randn(4, 3)
F, dX_pullback = rrule(svd, X)
for p in [:U, :S, :V]
Y, dF_pullback = rrule(getproperty, F, p)
= randn(size(Y)...)

_, dF_unthunked, _ = dF_pullback(Ȳ)

@assert !(getproperty(dF_unthunked, p) isa AbstractThunk)
dF_thunked = map(f->Thunk(()->f), dF_unthunked)
@assert getproperty(dF_thunked, p) isa AbstractThunk

dself_thunked, dX_thunked = dX_pullback(dF_thunked)
dself_unthunked, dX_unthunked = dX_pullback(dF_unthunked)
@test dself_thunked == dself_unthunked
@test dX_thunked == dX_unthunked
end
end

@testset "+" begin
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
F, dX_pullback = rrule(svd, X)
Expand Down

0 comments on commit 45bbb57

Please sign in to comment.