
define adjoint #72

Merged: 8 commits merged from the adjoint branch into master on Dec 12, 2020

Conversation

ChrisRackauckas
Member

Fixes SciML/DiffEqFlux.jl#381. MWE:

using OrdinaryDiffEq, DiffEqSensitivity, Flux, DiffEqGPU, StaticArrays, CUDA
CUDA.allowscalar(false)

function model()
  prob = ODEProblem((du, u, p, t) -> du[1] = 1.01 * u[1] * p[1], u0, (0.0, 1.0), pa)

  function prob_func(prob, i, repeat)
    remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
  end

  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  solve(ensemble_prob, Tsit5(), EnsembleGPUArray(), saveat = 0.1, trajectories = 10, sensealg = ForwardDiffSensitivity(convert_tspan=false))
end

# loss function
loss() = sum(abs2, 1.0 .- Array(model()))

data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
  @show loss()
end

pa = [1.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")

loss()

Flux.@epochs 10 Flux.train!(loss, params([pa]), data, opt; cb = cb)

@ChrisRackauckas
Member Author

@DhairyaLGandhi I think the map adjoint doesn't correctly ignore nothing values going backwards; could you take a look at this?
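
(For reference, "ignoring nothings" means: when no gradient flows back for some value, Zygote passes nothing as the cotangent, and a pullback has to propagate that through rather than treat it as a real gradient. A minimal sketch of that guard on a made-up function; mysum is purely illustrative and is neither DiffEqGPU's adjoint nor Zygote's actual map adjoint.)

using Zygote, ZygoteRules

mysum(xs) = sum(xs)

ZygoteRules.@adjoint function mysum(xs)
  y = sum(xs)
  function mysum_pullback(Δ)
    # no gradient came back for the output: pass nothing through instead of a dense array
    Δ === nothing && return (nothing,)
    (fill(Δ, length(xs)),)
  end
  return y, mysum_pullback
end

Zygote.gradient(x -> mysum(x .^ 2), [1.0, 2.0, 3.0])  # ([2.0, 4.0, 6.0],)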

@jc-audet

jc-audet commented Dec 8, 2020

I have a similar issue to the one in this DiffEqFlux thread:

https://github.com/SciML/DiffEqFlux.jl/issues/381

with something resembling the MWE in this thread. Has any progress been made in the past months?

Thank you

@DhairyaLGandhi
Member

Could we test with FluxML/Zygote.jl#846?

@ChrisRackauckas
Member Author

That branch doesn't seem to help. In fact, I'm a bit puzzled, so I made another similar example to work with first:

using OrdinaryDiffEq, DiffEqSensitivity, Flux, DiffEqGPU, StaticArrays, CUDA
CUDA.allowscalar(false)

function model()
  prob = ODEProblem((du, u, p, t) -> du[1] = 1.01 * u[1] * p[1] * p[2], u0, (0.0, 1.0), pa)

  function prob_func(prob, i, repeat)
    prob
  end

  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  solve(ensemble_prob, Tsit5(), EnsembleGPUArray(0.0), saveat = 0.1, trajectories = 10, sensealg = ForwardDiffSensitivity(convert_tspan=false))
end

# loss function
loss() = sum(abs2, 1.0 .- Array(model()))

data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
  @show loss()
end

pa = [1.0,2.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")

loss()

Flux.@epochs 10 Flux.train!(loss, params([pa]), data, opt; cb = cb)

In the adjoint I define, i.e. ZygoteRules.@adjoint function batch_solve_up(ensembleprob, probs, alg, ensemblealg, I, u0, p; kwargs...), I have that:

(size(Array(VectorOfArray(adj))), size(p)) = ((2, 10), (2, 10))

So I know that what I'm pulling back is the same size as p (correct? I assume Zygote doesn't do anything crazy with matrices?). You would think that's working, but by the time it gets all the way back to the Flux update! code, it sees

(x, gs[x]) = ([1.0, 2.0], [46839.635021615635; 23419.817510807818])

meaning the derivative somehow got adjointed (transposed) on its own... what?
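
(For context on where that output comes from: Flux's train! builds an implicit-params Grads object and update! looks each parameter up in it, so the same check can be run outside of train!. A sketch, assuming the loss and pa from the MWE above; it just surfaces the shapes and values Flux is about to apply.)

using Flux, Zygote

ps = Flux.params([pa])
gs = Zygote.gradient(() -> loss(), ps)  # the same Grads that train! hands to update!
for x in ps
  # if the gradient got transposed on the way back, it shows up here,
  # even though the pullback returned something of size(p)
  @show (size(x), size(gs[x]))
  @show (x, gs[x])
end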

@ChrisRackauckas
Member Author

Oh wait, remembering that Zygote's adjoints for comprehensions are incorrect, I got rid of the comprehensions. See the last commit: that's all I needed to fix this issue. So I think comprehensions incorrectly transpose variables being pulled back. @DhairyaLGandhi you might want to take a look at that today and try to find a smaller reproducer, since this is an issue that keeps coming up.
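
(As a starting point for a smaller reproducer, one thing to try is differentiating the same computation written once as a comprehension and once with map, with matrix parameters so a stray transpose would show up in the gradient's shape or values. This is only a sketch of the comparison; it isn't known to trigger the bug on its own.)

using Zygote

f_comp(p) = sum(abs2, [p[i] * i for i in 1:length(p)])
f_map(p)  = sum(abs2, map(i -> p[i] * i, 1:length(p)))

p = [1.0 2.0; 3.0 4.0]  # matrix parameters, since the suspicion is a transpose

g_comp = Zygote.gradient(f_comp, p)[1]
g_map  = Zygote.gradient(f_map, p)[1]

@show size(g_comp) size(g_map)
@show g_comp == g_map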

ChrisRackauckas changed the title from [WIP] define adjoint to define adjoint on Dec 11, 2020
@ChrisRackauckas
Member Author

ChrisRackauckas commented Dec 12, 2020

The error isn't reproducible so I'm just going to merge, but @vchuravy it would be good to know why KernelAbstractions.jl sometimes cannot compile, and why the point at which it decides it can't seems random, depending on the computer, how many functions were run before it, and how many times the code has been run. I don't remember it being unstable like that.

@ChrisRackauckas merged commit 221a452 into master on Dec 12, 2020
@ChrisRackauckas deleted the adjoint branch on December 12, 2020 at 03:50
@ChrisRackauckas
Member Author

Seems like the test issue was just a change in inbounds semantics between different environments.

@DhairyaLGandhi
Member

DhairyaLGandhi commented Dec 21, 2020

I'm having trouble reproducing the issue. I see you've gotten rid of the comprehensions, but is there a more minimal example that I can use?

@ChrisRackauckas
Member Author

There isn't a more minimal example I could find.
