define adjoint #72
Conversation
@DhairyaLGandhi I think the
I have a similar issue to this thread in DiffEqFlux (https://github.com/SciML/DiffEqFlux.jl/issues/381), with something resembling the MWE in this thread. Was any progress made in the past months? Thank you.
Could we test with FluxML/Zygote.jl#846?
Force-pushed from adf621f to 5e37c2f
That branch doesn't seem to help. In fact, I'm a bit puzzled, so I made another similar example to work with first:

```julia
using OrdinaryDiffEq, DiffEqSensitivity, Flux, DiffEqGPU, StaticArrays, CUDA
CUDA.allowscalar(false)

function model()
    prob = ODEProblem((du, u, p, t) -> du[1] = 1.01 * u[1] * p[1] * p[2], u0, (0.0, 1.0), pa)
    function prob_func(prob, i, repeat)
        prob
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    solve(ensemble_prob, Tsit5(), EnsembleGPUArray(0.0), saveat = 0.1, trajectories = 10,
          sensealg = ForwardDiffSensitivity(convert_tspan = false))
end

# loss function
loss() = sum(abs2, 1.0 .- Array(model()))

data = Iterators.repeated((), 10)
cb = function () # callback function to observe training
    @show loss()
end

pa = [1.0, 2.0]
u0 = [3.0]
opt = ADAM(0.1)

println("Starting to train")
loss()
Flux.@epochs 10 Flux.train!(loss, params([pa]), data, opt; cb = cb)
```

In the adjoint I specify,

```julia
(size(Array(VectorOfArray(adj))), size(p)) = ((2, 10), (2, 10))
```

so I know that what I'm pulling back is the same size as `p`. But then

```julia
(x, gs[x]) = ([1.0, 2.0], [46839.635021615635; 23419.817510807818])
```

says that the derivative somehow got adjointed on its own... what?
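For context, here is a minimal sketch (a hypothetical `solve_like` standing in for the ensemble solve, not the PR's actual adjoint) of the shape contract being checked above: the cotangent returned by a custom Zygote pullback should match `size(p)`.

```julia
using Zygote

# Hypothetical stand-in for the ensemble solve: parameters -> a 2x3 "solution".
solve_like(p) = p * [1.0 2.0 3.0]

# Custom adjoint: A[i, j] = p[i] * r[j], so the pullback of ȳ is ȳ * r,
# which has the same shape as p.
Zygote.@adjoint solve_like(p) = solve_like(p), ȳ -> (ȳ * [1.0, 2.0, 3.0],)

p = [1.0, 2.0]
gs = Zygote.gradient(p -> sum(abs2, solve_like(p)), p)[1]
@assert size(gs) == size(p)  # the same shape check as in the comment above
```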
Oh wait, remembering that Zygote's adjoints for comprehensions are incorrect, I got rid of the comprehensions; see that last commit. That's all I needed to fix that issue. So I think comprehensions incorrectly transpose variables being pulled back. @DhairyaLGandhi you might want to take a look at that today and try to find a smaller reproducer, since this is an issue that keeps coming up.
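For illustration, a minimal sketch (mine, not from this PR) of the kind of rewrite described above: replacing a comprehension with `map` before differentiating with Zygote. This trivial case is not itself a reproducer of the bug; it only shows the shape of the workaround.

```julia
using Zygote

p = [1.0, 2.0, 3.0]

# Comprehension-based version: the style that was removed in the last commit.
f_comp(p) = sum([2.0 * x for x in p])

# map-based version: the workaround; the primal value is identical.
f_map(p) = sum(map(x -> 2.0 * x, p))

Zygote.gradient(f_comp, p)  # fine here, but comprehensions were the culprit above
Zygote.gradient(f_map, p)   # ([2.0, 2.0, 2.0],)
```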
The error isn't reproducible, so I'm just going to merge, but @vchuravy it would be good to know why KernelAbstractions.jl sometimes fails to compile, and why where it decides it can't is seemingly random: dependent on the computer, how many functions were run before it, and just how many times the code has been run. I don't remember it being unstable like that.
Seems like the test issue was just the change in `@inbounds` semantics between different environments.
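To spell that out: `Pkg.test` runs Julia with `--check-bounds=yes`, which overrides `@inbounds` annotations, so the same code can be bounds-checked in the test environment but not locally. A small sketch (hypothetical function, not from this repo):

```julia
# Under `julia --check-bounds=yes` (the Pkg.test default) the access below is
# always bounds-checked; under default flags `@inbounds` elides the check, and
# an out-of-bounds call is undefined behavior instead of a BoundsError.
function unsafe_first(v)
    @inbounds return v[1]
end

unsafe_first([10, 20])  # 10 in either mode
# unsafe_first(Int[])   # BoundsError only when bounds checks are forced on
```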
I'm having trouble reproducing the issue. I see you've gotten rid of the comprehensions, but is there a more minimal example that I can use?
There isn't a more minimal example I could find. |
Fixes SciML/DiffEqFlux.jl#381. MWE: