ReverseDiff doesn't like DynamicPPL.ReshapeTransform #698

@penelopeysm

Description

#555 introduced DynamicPPL.ReshapeTransform, which is very nice, but there's what seems to be a bug in ReverseDiff.jl which causes it to fail when ReshapeTransform is composed with a broadcasted function.

I reported the upstream bug at JuliaDiff/ReverseDiff.jl#265. In the context of DynamicPPL, this occurs when we have something like the following:

using DynamicPPL: invlink_transform, ReshapeTransform
using Distributions: InverseGamma
using ReverseDiff

f(x) = invlink_transform(InverseGamma(2, 3)).(x)
g(x) = ReshapeTransform(())(x)
h = f ∘ g
ReverseDiff.gradient(h, [1.0])

I suspect we should be able to change the implementation of ReshapeTransform to circumvent this, though. I don't actually know all the possible shapes that ReshapeTransform handles, or whether different input/output shapes would trigger different ReverseDiff errors. However, I dug into a couple of the failing tests in Turing.jl, and both of them stem from ReshapeTransform being given singleton arrays (e.g. [1.0] above). Furthermore, the error message observed in all the other failing tests is the same (although I didn't verify that they ultimately stem from singleton arrays). So I think we could special-case this behaviour to keep ReverseDiff on our side.
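As a rough illustration of what such a special case could look like (a hypothetical sketch, not DynamicPPL's actual ReshapeTransform implementation — the function name and signature here are made up for illustration): when the target shape is `()` and the input is a singleton vector, return the scalar directly instead of a zero-dimensional array, so ReverseDiff never hits the broadcast-over-0-dim path.

```julia
# Hypothetical helper, not part of DynamicPPL: reshape with a special case
# for the singleton-vector-to-scalar path.
function maybe_reshape(x::AbstractVector, output_size::Tuple)
    if output_size == () && length(x) == 1
        # Return the plain scalar rather than reshape(x, ()), which would
        # produce a 0-dim array and (in the failing tests) trip ReverseDiff
        # when a broadcasted function is composed on top.
        return x[1]
    end
    return reshape(x, output_size)
end

maybe_reshape([1.0], ())              # plain scalar 1.0
maybe_reshape([1.0, 2.0, 3.0, 4.0], (2, 2))  # ordinary 2×2 reshape
```

All other shapes fall through to the usual `reshape`, so only the singleton case observed in the failing tests changes behaviour.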
