
[Nested AD] Incorrect gradient when taking a gradient over a gradient using StatefulLuxLayer #630


Description

@MatsudaHaruki

Hello,

I needed to compute the gradient of a gradient, and I was able to run the following code without error by using StatefulLuxLayer, as described in the "Nested Automatic Differentiation" chapter of the documentation. However, when I verified the result against FiniteDiff.jl, as the documentation does, the two results do not agree: Zygote returns an all-zero gradient, while FiniteDiff does not.

Code Example:

using Lux, Random, LinearAlgebra, Zygote, FiniteDiff

rng = Xoshiro(0)
model = Dense(3 => 3)    # very simple model
ps, st = Lux.setup(rng, model)

x = ones32(rng, 3, 1)

grads_1 = Zygote.gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    v = randn!(rng, similar(x))
    w = Zygote.gradient((z -> dot(v, z)) ∘ smodel, x) |> only
    return sum(w)
end |> only    # shows [0.0; 0.0; 0.0;;]

grads_2 = FiniteDiff.finite_difference_gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    v = randn!(rng, similar(x))
    w = Zygote.gradient((z -> dot(v, z)) ∘ smodel, x) |> only
    return sum(w)
end            # shows [-260.231; 114.75144; 21.052755;;]
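
For reference, here is a variant of the reproduction (a sketch I have not used for the numbers above) that draws v once outside the closures, so the Zygote and FiniteDiff calls differentiate exactly the same function rather than consuming fresh RNG state each time:

# Draw the perturbation vector once so both gradient computations share it.
v = randn(rng, Float32, 3, 1)

f = function (x)
    smodel = StatefulLuxLayer(model, ps, st)
    # Inner gradient: d/dx dot(v, smodel(x))
    w = Zygote.gradient((z -> dot(v, z)) ∘ smodel, x) |> only
    return sum(w)
end

Zygote.gradient(f, x) |> only                 # outer gradient via Zygote
FiniteDiff.finite_difference_gradient(f, x)   # outer gradient via finite differences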

Is this the expected behavior of this package, or is there something in my code that needs to be corrected?

Thank you.
