Skip to content

Adjoints give incorrect values if t values are not unique #335

@ChrisRackauckas

Description

@ChrisRackauckas

MWE https://discourse.julialang.org/t/different-results-between-zygote-forwarddiff-and-reversediff/46540

using DifferentialEquations, Flux, LinearAlgebra, DiffEqFlux, DiffEqSensitivity
using ForwardDiff
using Zygote
using ReverseDiff

function bimolecular!(du,u,p,t,dens,cons)
    # unpack rates and constants
    nᵣ = u[1]
    k₁,k₋₁  = p
    mᵣ,mₗ,A = dens
    # model
    du[1] = dnᵣ = A*k₁*mᵣ*mₗ - k₋₁*nᵣ

end

function run_model(model,p,data,densities,constants,t) # version of run function with multiple models

  Σ_sol_stack = zeros(1, size(data,2))

  for i in 1:size(data,1)
      # run model with given densities
      densities_i = densities[i,:]
      f = (du,u,p,t) -> model(du,u,10 .^ p,t,densities_i,constants)
      tmp_prob = ODEProblem(f,u₀,tspan,p)
      tmp_sol = solve(tmp_prob,Vern7(),saveat=t, abstol=1e-8,reltol=1e-8)
      # stack Σ of solution across n species
      Σ_sol = sum(Array(tmp_sol),dims=1)
      Σ_sol_stack = vcat(Σ_sol_stack,Σ_sol)
  end
  return Σ_sol_stack[2:end,:]
end


function loss(p,data,model,dens,cons,t)
  Σ_sol = run_model(model,p,data,dens,cons,t)
  sum(abs2, (Σ_sol - data)) #, Σ_sol
end

dataset = [  0.25  0.0618754
  0.25  0.040822
  0.5   0.127833
  0.5   0.198451
  1.0   0.274437
  1.0   0.223144
  2.0   0.579818
  2.0   0.653926
  4.0   0.693147
  4.0   0.776529
  6.0   0.820981
  6.0   0.776529
  8.0   0.653926
  8.0   0.776529
 16.0   0.820981
 16.0   0.733969]

t = dataset[:,1]
n = dataset[:,2]

densities = [25.0, 38.0, 1.0]
tspan = (0,maximum(t)+1)
u₀ = [0.0]
rates = [ -3.367837470456765, -0.2863777340019116]


loss_new = (p) -> loss(p,n',bimolecular!,densities',[],t)
loss_new(rates)

grad_zyg = Zygote.gradient(loss_new,rates)
grad_for = ForwardDiff.gradient(loss_new,rates)
grad_rev = ReverseDiff.gradient(loss_new,rates)
hes_for = ForwardDiff.hessian(loss_new,rates)
hes_rev = ReverseDiff.hessian(loss_new,rates)
hes_zyg = Zygote.hessian(loss_new,rates)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions