Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mul/ewise rules for basic arithmetic semiring #26

Merged
merged 20 commits into from
Jul 11, 2021
Merged

Conversation

rayegun
Copy link
Member

@rayegun rayegun commented Jul 6, 2021

I removed some I'm still testing to get feedback on these and avoid a monster PR.

Notes:

  1. You need to call test_*rule with check_inferred=false. Issue Output eltype inference is not type stable #25 will fix.
  2. Missing the kwargs. I want to get everything working first before trying those.
  3. rrules are incorrect according to tests. Some of these are just floating point issues. However for mul there's a deeper issue. I'm 85-90% sure the rules are correct, but the patterns are not the same as for FiniteDifferences, and occasionally there's different values.

@codecov-commenter

This comment has been minimized.

@rayegun rayegun requested a review from mzgubic July 6, 2021 06:10
@rayegun
Copy link
Member Author

rayegun commented Jul 6, 2021

I apologize for the messy PR, the only important parts are in the tests and chainrules folders.

I'm primarily interested in your thoughts about the rrules for mul, and in particular whether I'm wrong, FiniteDifferences is wrong, or I just haven't given FiniteDifferences the right information.

Everything works fine for dense. For sparse inputs though there's two problems:

  1. ∂A has a different sparsity pattern, and thus different values where the sparsity is different.
  2. ∂B is straight up wrong according to FiniteDifferences.

@mzgubic

Copy link
Collaborator

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't manage to finish it all today, but here's a few comments

Project.toml Outdated Show resolved Hide resolved
Project.toml Show resolved Hide resolved
src/chainrules/chainruleutils.jl Show resolved Hide resolved
src/chainrules/chainruleutils.jl Show resolved Hide resolved
test/runtests.jl Outdated
@@ -14,4 +14,5 @@ println("Testing SuiteSparseGraphBLAS.jl")
@testset "SuiteSparseGraphBLAS" begin
include_test("gbarray.jl")
include_test("operations.jl")
include_test("testrules.jl")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually the structure of test folder mirrors the src/ folder, which makes it easier to find things when the package grows.

src/chainrules/ewiserules.jl Show resolved Hide resolved
@mzgubic
Copy link
Collaborator

mzgubic commented Jul 7, 2021

Everything works fine for dense. For sparse inputs though there's two problems:
∂A has a different sparsity pattern, and thus different values where the sparsity is different.
∂B is straight up wrong according to FiniteDifferences.

I think the underlying issue is the same (and also the same one as in the elementwise rules). What it comes down is an instance of the "array dilemma", discussed in great detail over many issues and PRs. See JuliaDiff/ChainRulesCore.jl#347 (and related issues) for a discussion, but I warn you, it is a rabbit hole ;)

Essentially what it comes down to is whether you think of the input, say A::GBMatrix as an efficient representation of an array that just happens to be sparse, or whether you think of it as a struct. Consider y = A * B, where B::Matrix is a dense array, and y is therefore dense as well.

Primal computation will be fast because A is sparse. In fact A was probably chosen to be GBMatrix solely to get
that speedup. What happens in the backward pass depends on how you interpret A: an array, or a struct?

if you interpret it as an array, the dA = mul(ΔΩ, B') will be dense and you lose all the benefits of the speedup, but dA will match the dA you would have gotten with a dense A with the zeros in the same place as structural zeros of the sparse A.

If you interpret it as a struct, meaning that the zeros are structural, it doesn't make sense to compute the tangents to all the zeros, and you can compute the backward pass efficiently. Since dA for sparse A is sparse in this case, it is somewhat unintuitive that it is different to the dA that would be obtained if A was a dense array with the zeros in the same place.

Long story short, we are treating them as structs now in order to not completely kill efficiency. We should probably treat them as structs here as well.

Aside: projection, merged recently, was a way to make sure rules with abstractly typed arguments still return the correct tangent type. The classic example is Diagonal * Matrix where we project the dense gradient onto the Diagonal.


In this case, as you point out, masking is all we need to do, since we are writing dedicated rules for GBMatrix multiplication.

@rayegun rayegun merged commit f0dd5c9 into master Jul 11, 2021
@rayegun rayegun deleted the arithmeticchains branch July 11, 2021 02:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants