Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rule for triangular solves #1264

Merged
merged 23 commits into from
Feb 23, 2024
Merged

Conversation

simsurace
Copy link
Contributor

So this is a first attempt by a complete beginner at this.
One issue is that the tests using EnzymeTestUtils fail because they seem to compare the data field of the cotangents:

A = UpperTriangular(rand(2, 2))
B = copy(A); B.data[2, 1] = 0
EnzymeTestUtils.test_approx(A, B) # Errors because the `data` field is not the same

The remaining tests currently pass, but they are not covering all cases.

@sethaxen
Copy link
Collaborator

sethaxen commented Feb 1, 2024

Actually, EnzymeTestUtils is doing the right thing. You don't want to end up pulling back a nonzero cotangent on the unused triangle, which is what you're doing here, since next Enzyme will pull back the cotangent through the constructor of the triangular matrix, and this pullback just extracts the data. So call tril! or triu! (whichever is correct) before constructing the cotangent.

@simsurace
Copy link
Contributor Author

Actually, EnzymeTestUtils is doing the right thing. You don't want to end up pulling back a nonzero cotangent on the unused triangle, which is what you're doing here, since next Enzyme will pull back the cotangent through the constructor of the triangular matrix, and this pullback just extracts the data. So call tril! or triu! (whichever is correct) before constructing the cotangent.

I think it relates to my question above, but if I'm assigning to the shadow with the UpperTriangular or LowerTriangular wrapper, I'm should not be changing the unused triangle.

A = UpperTriangular([1 2; 3 4])
B = UpperTriangular([0 0; 1 0])
A .+= B
A.data == [1 2; 3 4] # true

So I'm not sure where the non-zero entries come from. So your suggestion is to add a call to triu! after the assignment?

@sethaxen
Copy link
Collaborator

sethaxen commented Feb 1, 2024

Ah, I see now you're wrapping in AbstractTriangular before decrementing. No, that shouldn't be the problem, your rule is fine. The issue seems to be caused by https://github.com/JuliaDiff/FiniteDifferences.jl/blob/0766bbc7b81381b134835d31145ad4822fd7e65f/src/to_vec.jl#L98-L104, specifically, when FiniteDifferences.jl vectorizes a triangular matrix, it ignores the lower triangle. This is fine if one only ever interacts with the data via the abstract matrix API. If one ever accesses the unused triangle of data directly, then this will do the wrong thing.

We can show this by pirating this method. This overload causes tests to pass for me:

function FiniteDifferences.to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
    x_vec, back = FiniteDifferences.to_vec(parent(x))
    function AbstractTriangular_from_vec(x_vec)
        return T(reshape(back(x_vec), size(x)))
    end
    return x_vec, AbstractTriangular_from_vec
end

We could pirate this method here, but it's not the ideal way of handling this. An alternative is to manually create the tangents (i.e. construct the Duplicated yourself and pass to test_reverse), instantiating the unused triangle of the shadow of A.data as zeros.

A long-term solution would be for EnzymeTestUtils to roll its own to_vec alternative that always unwraps everything to scalars or Arrays.

src/internal_rules.jl Outdated Show resolved Hide resolved
@simsurace
Copy link
Contributor Author

We could pirate this method here, but it's not the ideal way of handling this. An alternative is to manually create the tangents (i.e. construct the Duplicated yourself and pass to test_reverse), instantiating the unused triangle of the shadow of A.data as zeros.

Sorry, but who is writing to that unused triangle? If I understand your explanation correctly, it is not FiniteDifferences.

@sethaxen
Copy link
Collaborator

sethaxen commented Feb 1, 2024

Sorry, but who is writing to that unused triangle? If I understand your explanation correctly, it is not FiniteDifferences.

No-one's writing to it. I'll explain. We want to test that a given rule only overwrites the shadow of an argument if that argument is itself overwritten by the primal. In all other cases, it should just increment the shadow. The way to test this is to start with a shadow filled with random numbers. There's then some complexity with how we get FiniteDifferences to work with this, since it doesn't support mutating functions. In short, we replace the mutating function with one that copies all arguments first, and then all arguments are also returned as outputs. Then for these returned arguments we pull back as cotangents the exact same random shadows we use with Enzyme. This causes FD to increment these random cotangents in the same way Enzyme should. It's a little complicated and happens here:

#=
_wrap_reverse_function(f, xs, ignores)
Return a new version of `f`, `fnew`, that ignores some of the arguments `xs` and returns
also non-ignored arguments.
All arguments are copied before being passed to `f`, so that `fnew` is non-mutating.
# Arguments
- `f`: The function to be wrapped.
- `xs`: Inputs to `f`, such that `y = f(xs...)`.
- `ignores`: Collection of `Bool`s, the same length as `xs`.
If `ignores[i] === true`, then `xs[i]` is ignored; `∂xs[i] === NoTangent()`.
=#
function _wrap_reverse_function(active_return, f, xs, ignores)
function fnew(sigargs...)
callargs = Any[]
retargs = Any[]
j = 1
for (i, (x, ignore)) in enumerate(zip(xs, ignores))
if ignore
push!(callargs, deepcopy(x))
else
arg = deepcopy(sigargs[j])
push!(callargs, arg)
push!(retargs, arg)
j += 1
end
end
@assert j == length(sigargs) + 1
@assert length(callargs) == length(xs)
@assert length(retargs) == count(!, ignores)
# if an arg and a return alias, do not consider the contribution from the arg as returned here,
# it will already be taken into account. This is implemented using the deepcopy_internal, which
# will add all objects inside the return into the dict `zeros`.
zeros = IdDict()
origRet = Base.deepcopy_internal(deepcopy(f)(callargs...), zeros)
# we will now explicitly zero all objects returned, and replace any of the args with this
# zero, if the input and output alias.
if active_return
for k in keys(zeros)
zeros[k] = zero_tangent(k)
end
end
return (origRet, Base.deepcopy_internal(retargs, zeros)...)
end
return fnew
end

But in this case, FD is intentionally ignoring the unused triangle stored by the triangular matrix, which is a problem for us, because a user need not ignore it, and Enzyme won't either. So FD is not pulling back the random values in the unused triangle, and that causes Enzyme and FD to disagree. This is why either modifying FD to pull back these values as I did above or using zeros instead of random values will work just fine.

@simsurace
Copy link
Contributor Author

Ok but it sounds like if we want to test both aspects we need to keep the random vectors. Looks like FD is providing the wrong abstractions for this. Instead of pirating to_vec, why don't we introduce interfaces that are better geared towards Enzyme?

@simsurace
Copy link
Contributor Author

I think I got it wrong. You meant to initialize randomly but only on the used triangle?

@sethaxen
Copy link
Collaborator

sethaxen commented Feb 1, 2024

Ok but it sounds like if we want to test both aspects we need to keep the random vectors.

If we initialize the shadow of the used diagonal randomly, then we still test this for the used diagonal at least, just not the unused one. 🤷‍♂️

Instead of pirating to_vec, why don't we introduce interfaces that are better geared towards Enzyme?

Yes, this is what I meant by the final sentence in #1264 (comment) . It just takes more time than I have right now to work on.

@simsurace
Copy link
Contributor Author

What if we add a test that ADs through the constructor as well?

@sethaxen
Copy link
Collaborator

sethaxen commented Feb 1, 2024

What if we add a test that ADs through the constructor as well?

Yeah I think that's a good way to go, but for some reason Enzyme errored. But sure, if you can get it working, it's probably best.

@simsurace
Copy link
Contributor Author

What is the API for injecting a custom cotangent into test_reverse?

@sethaxen
Copy link
Collaborator

sethaxen commented Feb 3, 2024

What is the API for injecting a custom cotangent into test_reverse?

Instead of passing a tuple (A, TA), construct an activity (e.g. Duplicated) where the shadow contains the custom cotangent and pass that instead. See

- `args`: Each entry is either an argument to `f`, an activity type accepted by `autodiff`,
or a tuple of the form `(arg, Activity)`, where `Activity` is the activity type of
`arg`. If the activity type specified requires a shadow, one will be automatically
generated.

@simsurace
Copy link
Contributor Author

The tests should now pass, maybe we can trigger the CI.
I would keep extending and refactoring the existing rule for later.
A remaining doubt I had, @sethaxen: it this the right rule to implement, or should we implement one for ldiv! with three arguments instead?

@codecov-commenter
Copy link

codecov-commenter commented Feb 3, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (2c1fb5d) 75.93% compared to head (3d37fc4) 93.54%.

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1264       +/-   ##
===========================================
+ Coverage   75.93%   93.54%   +17.61%     
===========================================
  Files          35        7       -28     
  Lines       10543      248    -10295     
===========================================
- Hits         8006      232     -7774     
+ Misses       2537       16     -2521     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@simsurace
Copy link
Contributor Author

I only ran the tests in a local script, so I forgot to include EnzymeTestUtils. Should be fine now.

Copy link
Collaborator

@sethaxen sethaxen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If \ is under the hood just calling ldiv!, and all other function calls are differentiable by Enzyme, then I agree it makes more sense to target ldiv!.

test/internal_rules.jl Outdated Show resolved Hide resolved
test/internal_rules.jl Outdated Show resolved Hide resolved
test/internal_rules.jl Outdated Show resolved Hide resolved
src/internal_rules.jl Outdated Show resolved Hide resolved
src/internal_rules.jl Outdated Show resolved Hide resolved
@simsurace
Copy link
Contributor Author

I will try to re-write the rule in terms of ldiv! next.

@simsurace
Copy link
Contributor Author

Maybe I should add some tests for batched inputs? Otherwise feel free to merge.

@simsurace
Copy link
Contributor Author

Let's trigger CI to get a full picture of the additional tests.

src/internal_rules.jl Outdated Show resolved Hide resolved
@simsurace
Copy link
Contributor Author

I've stared at the errors for a while but I cannot make sense of them. They are mostly inside the EnzymeTestUtils logic for calling FiniteDifferences for BatchDuplicated. Is everything fine in there? @sethaxen could you please have another look?

test/internal_rules.jl Outdated Show resolved Hide resolved
Comment on lines 415 to 418
for Tret in (Const, Active),
TY in (Const, Duplicated, BatchDuplicated),
TA in (Const, Duplicated, BatchDuplicated),
TB in (Const, Duplicated, BatchDuplicated)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wsmoses in the past, mixing and matching Duplicated and BatchDuplicated caused errors. Is this no longer the case? I'm surprised this succeeds.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, looking at the CI logs, it seems it doesn't succeed.

test/internal_rules.jl Show resolved Hide resolved
test/internal_rules.jl Outdated Show resolved Hide resolved
simsurace and others added 4 commits February 22, 2024 17:23
Co-authored-by: Seth Axen <seth@sethaxen.com>
Co-authored-by: Seth Axen <seth@sethaxen.com>
Co-authored-by: Seth Axen <seth@sethaxen.com>
@simsurace
Copy link
Contributor Author

simsurace commented Feb 22, 2024

Despite the check for compatibility of activities, I still get some failures.
EDIT: nvm, I think I got it

Copy link
Collaborator

@sethaxen sethaxen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@wsmoses wsmoses merged commit e8ede63 into EnzymeAD:main Feb 23, 2024
40 of 43 checks passed
@simsurace simsurace deleted the triangular-solve branch February 24, 2024 15:41
@sethaxen sethaxen mentioned this pull request Mar 15, 2024
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants