Adjoints for LinearSolve #832

Closed
avik-pal opened this issue Jun 7, 2023 · 1 comment

@avik-pal
Member

avik-pal commented Jun 7, 2023

A minimal example that fails:

```julia
using Zygote, ForwardDiff, SciMLSensitivity, SciMLBase, LinearSolve, ComponentArrays,
      FiniteDiff

# Loss that goes through LinearSolve.jl
function loss_function(θ)
    (; A, b) = θ
    prob = LinearProblem(A, b)
    sol = solve(prob, nothing)  # `nothing` selects the default algorithm
    return sum(sol.u)
end

# Same loss using Base `\`, which already has ChainRules support
function loss_function_chainrules(θ)
    (; A, b) = θ
    x = A \ b
    return sum(x)
end

A = Float32[1 0; 1 -2]; b = Float32[32; -4];
θ = ComponentArray(; A, b)

loss_function(θ) ≈ loss_function_chainrules(θ)  # true

Zygote.gradient(loss_function, θ)  # fails
Zygote.gradient(loss_function_chainrules, θ)  # works

ForwardDiff.gradient(loss_function, θ)  # fails
ForwardDiff.gradient(loss_function_chainrules, θ)  # works

FiniteDiff.finite_difference_gradient(loss_function, θ)  # works
```
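For reference, the gradient that the working calls should return follows from the standard adjoint of a linear solve. Differentiating $Ax = b$ gives $A\,dx = db - dA\,x$. With $L = \mathbf{1}^\top x$ and the adjoint vector $\lambda$ defined by $A^\top \lambda = \mathbf{1}$, this becomes:

```math
\begin{aligned}
dL &= \mathbf{1}^\top dx = \mathbf{1}^\top A^{-1}(db - dA\,x) = \lambda^\top db - \lambda^\top dA\,x,\\
\frac{\partial L}{\partial b} &= \lambda, \qquad \frac{\partial L}{\partial A} = -\lambda\,x^\top .
\end{aligned}
```

For the `A` and `b` above, hand computation gives $x = (32, 18)$, $L = 50$, and $\lambda = (3/2, -1/2)$. Therefore $\partial L/\partial b = (3/2, -1/2)$ and $\partial L/\partial A = \begin{pmatrix} -48 & -27 \\ 16 & 9 \end{pmatrix}$. These values are worked out here and were not copied from the outputs in the issue.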

We need SciML/LinearSolve.jl#322 before we can work on the adjoint code.
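The issue does not show the adjoint code itself. The sketch below is hypothetical and uses only the `LinearAlgebra` standard library. It is not the rule that LinearSolve.jl or SciMLSensitivity.jl actually ships. It hand-codes the reverse pass for `sum(A \ b)`, which needs one extra solve with `A'`, and then checks the result against central finite differences. The helper names `loss_and_adjoint` and `fd_gradient` are invented for this illustration. The sketch uses `Float64` instead of the issue's `Float32` so that the finite-difference check stays accurate.

```julia
using LinearAlgebra

# Forward pass of the loss from the example: L(A, b) = sum(A \ b).
loss(A, b) = sum(A \ b)

# Hypothetical hand-written adjoint for x = A \ b followed by sum.
# The cotangent of `sum` is all ones; one solve with A' gives λ,
# then ∂L/∂b = λ and ∂L/∂A = -λ * x'.
function loss_and_adjoint(A, b)
    x = A \ b
    xbar = ones(eltype(x), length(x))  # ∂sum(x)/∂x
    λ = A' \ xbar                      # adjoint solve with the transpose
    return sum(x), -λ * x', λ          # L, ∂L/∂A (outer product), ∂L/∂b
end

# Hypothetical helper: central finite differences, one entry at a time.
function fd_gradient(f, A, b; h = 1e-6)
    gA = similar(A); gb = similar(b)
    for i in eachindex(A)
        Ap = copy(A); Ap[i] += h
        Am = copy(A); Am[i] -= h
        gA[i] = (f(Ap, b) - f(Am, b)) / (2h)
    end
    for i in eachindex(b)
        bp = copy(b); bp[i] += h
        bm = copy(b); bm[i] -= h
        gb[i] = (f(A, bp) - f(A, bm)) / (2h)
    end
    return gA, gb
end

A = Float64[1 0; 1 -2]; b = Float64[32, -4]
L, gA, gb = loss_and_adjoint(A, b)
gA_fd, gb_fd = fd_gradient(loss, A, b)
```

Reusing `x` from the forward pass and performing a single transposed solve is what makes this approach cheaper than differentiating through the factorization itself.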

@avik-pal
Member Author

This is done.
