Start using the less verbose Lux.@compact API #917

Open
avik-pal opened this issue Apr 12, 2024 · 0 comments
Labels
good first issue Good for newcomers

avik-pal (Member) commented Apr 12, 2024

Current version

```julia
@concrete struct NeuralODE{M <: AbstractExplicitLayer} <: NeuralDELayer
    model::M
    tspan
    args
    kwargs
end

function NeuralODE(model, tspan, args...; kwargs...)
    !(model isa AbstractExplicitLayer) && (model = Lux.transform(model))
    return NeuralODE(model, tspan, args, kwargs)
end

function (n::NeuralODE)(x, p, st)
    model = StatefulLuxLayer(n.model, nothing, st)

    dudt(u, p, t) = model(u, p)
    ff = ODEFunction{false}(dudt; tgrad = basic_tgrad)
    prob = ODEProblem{false}(ff, x, n.tspan, p)

    return (
        solve(prob, n.args...;
            sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), n.kwargs...),
        model.st)
end
```
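The forward pass above wraps the model in an out-of-place `dudt(u, p, t)`, integrates it over `tspan`, and returns the solution together with the (possibly updated) state. To show just that structure without DiffEqFlux, Lux, or OrdinaryDiffEq, here is a dependency-free sketch. `toy_model`, `euler_solve`, and `toy_neural_ode` are hypothetical names, and a fixed-step explicit Euler loop stands in for the real `solve` call and its adjoint sensitivity machinery; this is not how DiffEqFlux actually solves the ODE.

```julia
# Hypothetical stand-in for a Lux model: a linear vector field with parameters p.
toy_model(u, p) = p.A * u

# Fixed-step explicit Euler, standing in for `solve(prob, solver; ...)`.
function euler_solve(dudt, u0, tspan, p; nsteps = 100)
    t0, t1 = tspan
    dt = (t1 - t0) / nsteps
    u = copy(u0)
    for i in 0:(nsteps - 1)
        t = t0 + i * dt
        u = u .+ dt .* dudt(u, p, t)
    end
    return u
end

# Mirrors the shape of `(n::NeuralODE)(x, p, st)`: build `dudt(u, p, t)` from the
# model, integrate over `tspan`, and return `(solution, state)`.
function toy_neural_ode(model, tspan, x, p, st)
    dudt(u, p, t) = model(u, p)
    return euler_solve(dudt, x, tspan, p), st
end

p = (A = [-1.0 0.0; 0.0 -2.0],)
u_end, st = toy_neural_ode(toy_model, (0.0, 1.0), [1.0, 1.0], p, NamedTuple())
```

With 100 Euler steps over `(0.0, 1.0)`, each component decays roughly like `exp(-t)` and `exp(-2t)`, which is the analytic solution of this linear toy system.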

This would become the following (splatting positional arguments via `args...` won't work, but splatting keyword arguments via `kwargs...` is fine):

```julia
function NeuralODE(model, tspan, solver = nothing; kwargs...)
    !(model isa AbstractExplicitLayer) && (model = Lux.transform(model))
    return @compact(; model, tspan, solver, sensealg=InterpolatingAdjoint(; autojacvec = ZygoteVJP()), kwargs...) do x, p
        dudt(u, p, t) = model(u, p)
        prob = ODEProblem(ODEFunction{false}(dudt; tgrad = basic_tgrad), x, tspan, p.model)
        return solve(prob, solver; sensealg, kwargs...)
    end
end
```

This also handles all the boxing issues automatically (which is why we had to add the `StatefulLuxLayer` in the first place).
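For context on "boxing": when a Julia closure captures a local variable that is reassigned, the variable is stored in an untyped `Core.Box`, which makes calls through the closure type-unstable. The Base-only snippet below demonstrates that general language behavior; `make_unboxed` and `make_boxed` are illustrative names, and the snippet does not reproduce the specific Lux code path that motivated `StatefulLuxLayer`.

```julia
# Captured variable assigned once: the closure stores it as a concrete, typed field.
function make_unboxed()
    scale = 2.0
    f(u) = scale * u
    return f
end

# Captured variable reassigned after the closure is created: Julia wraps it in a
# `Core.Box`, whose contents are untyped, so calls through `f` are type-unstable.
function make_boxed()
    scale = 2.0
    f(u) = scale * u
    scale = 3.0
    return f
end

fu = make_unboxed()
fb = make_boxed()
println(fieldtype(typeof(fu), :scale))  # Float64
println(fieldtype(typeof(fb), :scale))  # Core.Box
```

Wrapping the model, parameters, and state in a single object that is captured but never reassigned (the role `StatefulLuxLayer` plays in the current version) avoids the second pattern.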

I'm not sure whether this counts as breaking: after this change, end users won't be able to dispatch on the type, e.g. define `foo(::NeuralODE)`. However, we don't guarantee that (see the NonlinearSolve.jl precedent, where we turned algorithms into functions instead of types).
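The dispatch concern can be illustrated with plain Julia: once the constructor returns a generic layer type, there is no `NeuralODE` type left for user methods to target, so calls fall through to more generic methods. In this sketch `StructNeuralODE`, `GenericCompactLayer`, `compact_neural_ode`, and `describe` are all hypothetical names; `GenericCompactLayer` merely stands in for whatever type `@compact` actually returns and is not Lux's real type.

```julia
# Today: a dedicated struct, so downstream code can dispatch on it.
struct StructNeuralODE{M}
    model::M
    tspan::Tuple{Float64, Float64}
end

# Hypothetical stand-in for the generic layer type returned by `@compact`.
struct GenericCompactLayer{F}
    forward::F
end

# After the proposed change, the "constructor" is just a function returning a generic layer.
compact_neural_ode(model, tspan) = GenericCompactLayer(x -> (model, tspan, x))

# A user-defined method like `foo(::NeuralODE)` only matches the struct-based version.
describe(::StructNeuralODE) = "NeuralODE-specific method"
describe(::Any) = "fallback method"

old_layer = StructNeuralODE(identity, (0.0, 1.0))
new_layer = compact_neural_ode(identity, (0.0, 1.0))
```

Here `describe(old_layer)` hits the specialized method, while `describe(new_layer)` silently falls back to the generic one.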

This needs LuxDL/Lux.jl#584, which will be released later today.

avik-pal added the good first issue label Apr 12, 2024