test_rrule doesn't catch returning wrong number of outputs #167

@oxinabox

Consider a primal function with 2 inputs.
Its pullback should thus have 3 outputs (one for each input, plus one for the function itself).

test_rrule doesn't notice the following mistake:

using ChainRulesTestUtils
using ChainRulesCore

foo(x, y) = x + 2y

function ChainRulesCore.rrule(::typeof(foo), x, y)
    foo_pullback(dz) = NoTangent(), dz

    # Should be
    #foo_pullback(dz) = NoTangent(), dz, 2dz

    return foo(x,y), foo_pullback
end

test_rrule(foo, 21.0, 32.0)

This is because zip in Julia truncates to the shortest of its arguments.
So as long as the tangents the pullback does return are right, the test passes.
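
A minimal illustration of that truncation, using only Base Julia. The variable names and the placeholder symbol are hypothetical stand-ins, not taken from ChainRulesTestUtils:

```julia
# The pullback returned two tangents, but three were expected.
# :no_tangent is just a stand-in for NoTangent() to avoid any dependency.
returned = (:no_tangent, 1.0)
expected = (:no_tangent, 1.0, 2.0)

# zip stops at the end of the shorter collection, so the third
# expected value is never compared against anything.
compared = collect(zip(returned, expected))

n_compared = length(compared)                    # 2, not 3
all_match = all(a == b for (a, b) in compared)   # true, despite the missing entry
```

Every pair that zip yields matches, so a loop over the zipped values has nothing to complain about.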

The mistake is in:

for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)

We just need to add some extra tests, such as a check on the number of outputs, before we get to that point.
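
One way such a check could look, sketched with only the Test standard library. The helper name check_pullback_arity is hypothetical and not part of ChainRulesTestUtils, and nothing stands in for NoTangent(); the actual fix may well look different:

```julia
using Test

# Hypothetical helper: a pullback should return one tangent per primal input,
# plus one for the function itself.
function check_pullback_arity(pullback_output, primal_args)
    return length(pullback_output) == length(primal_args) + 1
end

primal_args = (21.0, 32.0)

# `nothing` is only a placeholder for NoTangent() in this sketch.
@test !check_pullback_arity((nothing, 1.0), primal_args)       # one tangent missing
@test check_pullback_arity((nothing, 1.0, 2.0), primal_args)   # correct count
```

Running a length comparison like this before the zip loop would turn the silent truncation into an explicit test failure.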
