Skip to content

Conversation

@ChrisRackauckas
Copy link
Member

@ChrisRackauckas ChrisRackauckas commented Jan 22, 2022

MWE:

using DiffEqSensitivity, OrdinaryDiffEq
using Test

function fb(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]*t
  du[2] = dy = -p[3]*u[2] + t*p[4]*u[1]*u[2]
end

p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
probb = ODEProblem(fb,u0,(0.0,10.0),p)
solb = solve(probb,Tsit5(),abstol=1e-14,reltol=1e-14)

t = 0.0:0.5:10.0
# g(t,u,i) = (1-u)^2/2, L2 away from 1
function dg(out,u,p,t,i)
  (out.=2.0.-u)
end

_,dp1 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
                                 reltol=1e-14)

_,dp2 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
                              reltol=1e-14,
                              sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ZygoteVJP()))

_,dp3 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
                              reltol=1e-14,
                              sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ReverseDiffVJP(true))
                              )

_,dp4 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
                              reltol=1e-14,
                              sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.TrackerVJP())
                              )

isapprox(dp1, dp3, rtol = 1e-10) # yes
isapprox(dp1, dp4, rtol = 1e-10) # yes
isapprox(dp1, dp2, rtol = 1e-10) # no

When you run this branch, it has the TrackerVJP result right next to the ZygoteVJP result and shows Zygote just returns zeros for the gradients, silently, without any warning or error.

(tmp1, tmp12) = ([0.0, 0.0], [26.631836244556528, -59.276505258765454] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.864401590759809, -0.46328499411934765, -0.5016239375867565, 0.5567850853102051] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [26.631836244556528, -59.276505258765454] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.864401590759809, -0.46328499411934765, -0.5016239375867565, 0.5567850853102051] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.239759892814245, -64.92349793405678] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [5.342382748372607, -0.5088078592985804, -0.5486769599396241, 0.6090122999660315] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.29298962280317, -65.44712011526572] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.7737557956434475, -0.5216705467906235, -0.6424855333647359, 0.6244081606040183] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.361097784059247, -65.91604997082547] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.256393010680244, -0.541591584161458, -0.7642806891662404, 0.6482524400607981] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.746632620545274, -67.09795916159659] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.9093730111190297, -0.6660255541953102, -1.4860076220624616, 0.7971923923419794] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.826370888484554, -67.21982110170187] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.7662482681001164, -0.691436222195994, -1.6410179782414434, 0.8276074283520429] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.847562023470765, -67.24886639699635] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.732003675792577, -0.6981458462867473, -1.6824987027846035, 0.8356384434488817] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.847562023470765, -67.24886639699635] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.732003675792577, -0.6981458462867473, -1.6824987027846035, 0.8356384434488817] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.65231041349967, -72.99561289182861] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.995675486557265, -0.765525469930507, -1.823192757796991, 0.9055162751158428] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.84757009305544, -73.2149421811492] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.7154508868475986, -0.831261731883523, -2.235470687533171, 0.9832736553753629] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [33.09493861771779, -73.40860551400004] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.465594510079744, -0.9127640575496194, -2.7706489257299416, 1.0796802221707715] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.478421563934845, -73.87198707285454] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8663938331581926, -1.3598041561180667, -5.964979923865799, 1.6084700544932242] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.76266196733998, -73.91542550789954] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8122875993265848, -1.4538200950605515, -6.655100955467102, 1.7196785853347198] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.8381097793373, -73.92554421570522] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.79987646882824, -1.4790401810546734, -6.839931793384734, 1.749510640863286] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.8381097793373, -73.92554421570522] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.79987646882824, -1.4790401810546734, -6.839931793384734, 1.749510640863286] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [37.55519353359098, -79.15205293229714] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.9631265670333864, -1.6131901957852537, -7.312732552313099, 1.8704431272885764] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.20295942122944, -79.21985590426772] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8716213669317996, -1.8681806360820092, -9.127588733823748, 2.1660964965090557] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [39.01926648980144, -79.2718955247713] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8113403160333863, -2.2107034355180053, -11.4871241866697, 2.5632408740403383] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [43.488773916093315, -79.29749595923818] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.9518944136525855, -4.758192381606865, -25.521674609324524, 5.516973920214784] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.38537911729765, -79.2761670053049] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.0362976372072272, -5.461181202907928, -28.52577741349958, 6.332067275479824] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.6220448469952, -79.26950717301038] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.0611375046741944, -5.66065005834139, -29.328037615976378, 6.563345119052403] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.6220448469952, -79.26950717301038] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.0611375046741944, -5.66065005834139, -29.328037615976378, 6.563345119052403] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.71659209763698, -82.19967310954614] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.232156976071968, -6.13033312342071, -30.3183257604237, 6.784962499131496] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [46.246735382490535, -82.14351736770497] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.523772563410683, -8.381486405692192, -37.88184176666544, 9.276505828425474] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [48.12208259152519, -82.04204188875612] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [3.0201342060048035, -12.156942765640041, -47.546464235723974, 13.455125375459124] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [56.87032354456895, -80.66885364859259] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [9.715064886825084, -71.02380817665073, -98.42156579533089, 78.60810584387762] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [58.16915102272813, -80.138473660938] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [12.379517485886149, -96.5411812662228, -107.07830173514093, 106.85035891616823] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [58.47704253072354, -79.97838789429412] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [13.190904496310262, -104.39399165655345, -109.20859353986253, 115.54173390974331] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [58.47704253072354, -79.97838789429412] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [13.190904496310262, -104.39399165655345, -109.20859353986253, 115.54173390974331] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [33.55366541715509, -77.73410008088817] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [13.969059364344135, -110.55237850741402, -101.46448148094699, 107.34853128826781] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.98347559486527, -78.27007679071622] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [24.189946551213247, -209.46474121184778, -115.67273718552283, 203.39437856835428] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.92202056803882, -79.28035633641166] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [44.26541946933, -386.85627788167153, -122.02174752780171, 375.6450454610382] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [37.638011289933104, -84.0912917618233] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [159.54852959826255, -411.15798983322617, -42.6311784577402, 399.24248516347734] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.12312292545695, -84.11769812134202] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [164.37980227642223, -319.40735255840326, -32.99637945055528, 310.15081396479474] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.22424404025974, -84.09702123642069] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [164.99861526015644, -298.80741957262575, -30.9572893746732, 290.14787435810274] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.22424404025974, -84.09702123642069] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [164.99861526015644, -298.80741957262575, -30.9572893746732, 290.14787435810274] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [40.98088173036628, -39.227870259230855] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [138.45345883992013, -250.7349573911937, -31.914310762059685, 299.11757832361224] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [37.712244316539895, -41.716200404335325] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [137.22073086966876, -143.60303660069118, -19.48830866347262, 171.3131387614978] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [35.94618238427883, -45.862763782317984] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [129.6640077064968, -82.1383314171176, -12.529193697372602, 97.98800708390338] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.12594365800954, -60.38278343115164] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [93.37936953301197, -19.400422560788996, -5.230420482903112, 23.14399027250195] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.051041302878005, -62.02714127940237] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [88.64144209420428, -16.863261495028016, -4.979142207484118, 20.11725047640986] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.03504072459894, -62.42057196135026] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [87.48783486938594, -16.31640413621136, -4.929949106052264, 19.46486976907005] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.03504072459894, -62.42057196135026] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [87.48783486938594, -16.31640413621136, -4.929949106052264, 19.46486976907005] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.814176140237816, -53.008431702444355] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [79.79541973350233, -14.881775489520438, -5.268165422847731, 20.800245940015962] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.628793709610633, -58.5173110748337] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [71.71674376438688, -11.720995290261005, -5.02082187396316, 16.38242593236797] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.51509013600684, -63.40349175050008] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [64.14515832339075, -9.576103896884254, -5.041289540492184, 13.384512912629381] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41287558996019, -75.24009119949635] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [43.44496823075077, -5.933318174611237, -7.014199002562334, 8.29299416317544] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41301257908025, -76.39926487694274] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [41.138349915447044, -5.619002395297533, -7.565465703987922, 7.853675244733436] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41299860618759, -76.67329729054796] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [40.5819774248363, -5.542938074124349, -7.7164835747995815, 7.7473602026353525] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41299860618759, -76.67329729054796] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [40.5819774248363, -5.542938074124349, -7.7164835747995815, 7.7473602026353525] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.63794833648743, -80.11200799438166] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [40.56590783726544, -5.540743190226762, -8.188205213902684, 8.220969382008771] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.631902344342162, -82.54743572394703] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [36.34868775653611, -4.942240858483933, -9.715275172088273, 7.332953248541794] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.59916618731867, -84.64073474494845] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [32.4542838100408, -4.304010931982232, -11.813216642384194, 6.385992073061388] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.72428719937769, -89.39741187268406] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [21.75089580987753, -0.9372888222112805, -25.83319882402745, 1.3906839651203264] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.39789908846798, -89.82442312808675] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [20.506037994461035, -0.19876342414769294, -29.091196332323715, 0.2949113445762376] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.30317617895961, -89.9234904026111] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [20.20211745263974, 0.0, -29.974496800870366, 0.0] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.30317617895961, -89.9234904026111] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [20.20211745263974, 0.0, -29.974496800870366, 0.0] (tracked))

This is probably the cause of many of the complaints in https://discourse.julialang.org/t/state-of-machine-learning-in-julia/74385. Specifically https://discourse.julialang.org/t/zero-gradients-with-zygote-vs-correct-gradients-with-reversediff-using-diffeqflux/74398 was what tracked it down to this issue (after being told it's a DiffEq issue and not a Zygote issue...............)

MWE:

```julia
using DiffEqSensitivity, OrdinaryDiffEq
using Test

function fb(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]*t
  du[2] = dy = -p[3]*u[2] + t*p[4]*u[1]*u[2]
end

p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
probb = ODEProblem(fb,u0,(0.0,10.0),p)
solb = solve(probb,Tsit5(),abstol=1e-14,reltol=1e-14)

t = 0.0:0.5:10.0
# g(t,u,i) = (1-u)^2/2, L2 away from 1
function dg(out,u,p,t,i)
  (out.=2.0.-u)
end

_,dp1 = adjoint_sensitivities(sol,Tsit5(),dg,t,abstol=1e-14,
                                 reltol=1e-14)

_,dp2 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
                              reltol=1e-14,
                              sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ZygoteVJP()))

_,dp3 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
                              reltol=1e-14,
                              sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ReverseDiffVJP(true))
                              )

_,dp4 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
                              reltol=1e-14,
                              sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.TrackerVJP())
                              )

isapprox(dp1, dp3, rtol = 1e-10) # yes
isapprox(dp1, dp4, rtol = 1e-10) # yes
isapprox(dp1, dp2, rtol = 1e-10) # no
```

When you run this branch, it has the TrackerVJP result right next to the ZygoteVJP result and shows Zygote just returns zeros for the gradients, silently, without any warning or error. 

```julia
(tmp1, tmp12) = ([0.0, 0.0], [26.631836244556528, -59.276505258765454] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.864401590759809, -0.46328499411934765, -0.5016239375867565, 0.5567850853102051] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [26.631836244556528, -59.276505258765454] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.864401590759809, -0.46328499411934765, -0.5016239375867565, 0.5567850853102051] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.239759892814245, -64.92349793405678] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [5.342382748372607, -0.5088078592985804, -0.5486769599396241, 0.6090122999660315] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.29298962280317, -65.44712011526572] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.7737557956434475, -0.5216705467906235, -0.6424855333647359, 0.6244081606040183] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.361097784059247, -65.91604997082547] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.256393010680244, -0.541591584161458, -0.7642806891662404, 0.6482524400607981] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.746632620545274, -67.09795916159659] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.9093730111190297, -0.6660255541953102, -1.4860076220624616, 0.7971923923419794] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.826370888484554, -67.21982110170187] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.7662482681001164, -0.691436222195994, -1.6410179782414434, 0.8276074283520429] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.847562023470765, -67.24886639699635] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.732003675792577, -0.6981458462867473, -1.6824987027846035, 0.8356384434488817] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.847562023470765, -67.24886639699635] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.732003675792577, -0.6981458462867473, -1.6824987027846035, 0.8356384434488817] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.65231041349967, -72.99561289182861] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.995675486557265, -0.765525469930507, -1.823192757796991, 0.9055162751158428] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.84757009305544, -73.2149421811492] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.7154508868475986, -0.831261731883523, -2.235470687533171, 0.9832736553753629] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [33.09493861771779, -73.40860551400004] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.465594510079744, -0.9127640575496194, -2.7706489257299416, 1.0796802221707715] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.478421563934845, -73.87198707285454] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8663938331581926, -1.3598041561180667, -5.964979923865799, 1.6084700544932242] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.76266196733998, -73.91542550789954] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8122875993265848, -1.4538200950605515, -6.655100955467102, 1.7196785853347198] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.8381097793373, -73.92554421570522] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.79987646882824, -1.4790401810546734, -6.839931793384734, 1.749510640863286] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.8381097793373, -73.92554421570522] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.79987646882824, -1.4790401810546734, -6.839931793384734, 1.749510640863286] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [37.55519353359098, -79.15205293229714] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.9631265670333864, -1.6131901957852537, -7.312732552313099, 1.8704431272885764] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.20295942122944, -79.21985590426772] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8716213669317996, -1.8681806360820092, -9.127588733823748, 2.1660964965090557] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [39.01926648980144, -79.2718955247713] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8113403160333863, -2.2107034355180053, -11.4871241866697, 2.5632408740403383] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [43.488773916093315, -79.29749595923818] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.9518944136525855, -4.758192381606865, -25.521674609324524, 5.516973920214784] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.38537911729765, -79.2761670053049] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.0362976372072272, -5.461181202907928, -28.52577741349958, 6.332067275479824] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.6220448469952, -79.26950717301038] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.0611375046741944, -5.66065005834139, -29.328037615976378, 6.563345119052403] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.6220448469952, -79.26950717301038] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.0611375046741944, -5.66065005834139, -29.328037615976378, 6.563345119052403] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.71659209763698, -82.19967310954614] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.232156976071968, -6.13033312342071, -30.3183257604237, 6.784962499131496] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [46.246735382490535, -82.14351736770497] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.523772563410683, -8.381486405692192, -37.88184176666544, 9.276505828425474] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [48.12208259152519, -82.04204188875612] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [3.0201342060048035, -12.156942765640041, -47.546464235723974, 13.455125375459124] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [56.87032354456895, -80.66885364859259] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [9.715064886825084, -71.02380817665073, -98.42156579533089, 78.60810584387762] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [58.16915102272813, -80.138473660938] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [12.379517485886149, -96.5411812662228, -107.07830173514093, 106.85035891616823] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [58.47704253072354, -79.97838789429412] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [13.190904496310262, -104.39399165655345, -109.20859353986253, 115.54173390974331] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [58.47704253072354, -79.97838789429412] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [13.190904496310262, -104.39399165655345, -109.20859353986253, 115.54173390974331] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [33.55366541715509, -77.73410008088817] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [13.969059364344135, -110.55237850741402, -101.46448148094699, 107.34853128826781] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.98347559486527, -78.27007679071622] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [24.189946551213247, -209.46474121184778, -115.67273718552283, 203.39437856835428] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.92202056803882, -79.28035633641166] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [44.26541946933, -386.85627788167153, -122.02174752780171, 375.6450454610382] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [37.638011289933104, -84.0912917618233] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [159.54852959826255, -411.15798983322617, -42.6311784577402, 399.24248516347734] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.12312292545695, -84.11769812134202] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [164.37980227642223, -319.40735255840326, -32.99637945055528, 310.15081396479474] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.22424404025974, -84.09702123642069] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [164.99861526015644, -298.80741957262575, -30.9572893746732, 290.14787435810274] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.22424404025974, -84.09702123642069] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [164.99861526015644, -298.80741957262575, -30.9572893746732, 290.14787435810274] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [40.98088173036628, -39.227870259230855] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [138.45345883992013, -250.7349573911937, -31.914310762059685, 299.11757832361224] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [37.712244316539895, -41.716200404335325] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [137.22073086966876, -143.60303660069118, -19.48830866347262, 171.3131387614978] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [35.94618238427883, -45.862763782317984] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [129.6640077064968, -82.1383314171176, -12.529193697372602, 97.98800708390338] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.12594365800954, -60.38278343115164] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [93.37936953301197, -19.400422560788996, -5.230420482903112, 23.14399027250195] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.051041302878005, -62.02714127940237] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [88.64144209420428, -16.863261495028016, -4.979142207484118, 20.11725047640986] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.03504072459894, -62.42057196135026] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [87.48783486938594, -16.31640413621136, -4.929949106052264, 19.46486976907005] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.03504072459894, -62.42057196135026] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [87.48783486938594, -16.31640413621136, -4.929949106052264, 19.46486976907005] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.814176140237816, -53.008431702444355] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [79.79541973350233, -14.881775489520438, -5.268165422847731, 20.800245940015962] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.628793709610633, -58.5173110748337] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [71.71674376438688, -11.720995290261005, -5.02082187396316, 16.38242593236797] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.51509013600684, -63.40349175050008] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [64.14515832339075, -9.576103896884254, -5.041289540492184, 13.384512912629381] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41287558996019, -75.24009119949635] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [43.44496823075077, -5.933318174611237, -7.014199002562334, 8.29299416317544] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41301257908025, -76.39926487694274] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [41.138349915447044, -5.619002395297533, -7.565465703987922, 7.853675244733436] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41299860618759, -76.67329729054796] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [40.5819774248363, -5.542938074124349, -7.7164835747995815, 7.7473602026353525] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41299860618759, -76.67329729054796] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [40.5819774248363, -5.542938074124349, -7.7164835747995815, 7.7473602026353525] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.63794833648743, -80.11200799438166] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [40.56590783726544, -5.540743190226762, -8.188205213902684, 8.220969382008771] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.631902344342162, -82.54743572394703] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [36.34868775653611, -4.942240858483933, -9.715275172088273, 7.332953248541794] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.59916618731867, -84.64073474494845] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [32.4542838100408, -4.304010931982232, -11.813216642384194, 6.385992073061388] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.72428719937769, -89.39741187268406] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [21.75089580987753, -0.9372888222112805, -25.83319882402745, 1.3906839651203264] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.39789908846798, -89.82442312808675] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [20.506037994461035, -0.19876342414769294, -29.091196332323715, 0.2949113445762376] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.30317617895961, -89.9234904026111] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [20.20211745263974, 0.0, -29.974496800870366, 0.0] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.30317617895961, -89.9234904026111] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [20.20211745263974, 0.0, -29.974496800870366, 0.0] (tracked))
```

This is probably the cause of many of the complaints in https://discourse.julialang.org/t/state-of-machine-learning-in-julia/74385. Specifically https://discourse.julialang.org/t/zero-gradients-with-zygote-vs-correct-gradients-with-reversediff-using-diffeqflux/74398 was what tracked it down to this issue (after being told it's a DiffEq issue and not a Zygote issue...............)
@ChrisRackauckas
Copy link
Member Author

What was interesting is that this worked:

using Zygote, Tracker
function ff(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]*t
  du[2] = dy = -p[3]*u[2] + t*p[4]*u[1]*u[2]
end
p = [1.5,1.0,3.0,1.0]; y = [1.0;1.0]; t = 0.0
_dy, back = Zygote.pullback(y, p) do u, p
  out_ = Zygote.Buffer(similar(u))
  ff(out_, u, p, t)
  vec(copy(out_))
end
tmp1,tmp2 = back([2.0,3.0])

_dy2, back = Tracker.forward(y, p) do u, p
  out_ = map(zero, u)
  f(out_, u, p, t)
  Tracker.collect(out_)
end
tmp12,tmp22 = back([2.0,3.0])
@show _dy, _dy2
@show tmp1, tmp12
@show tmp2, tmp22

So the error is even weirder than I thought. It's not "just" buffer.

@ChrisRackauckas
Copy link
Member Author

ChrisRackauckas commented Jan 22, 2022

Bingo, I found it, and I fixed the DiffEq case by doing so?

using Zygote, Tracker, SciMLBase
function ff(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]*t
  du[2] = dy = -p[3]*u[2] + t*p[4]*u[1]*u[2]
end
p = [1.5,1.0,3.0,1.0]; y = [1.0;1.0]; t = 0.0
odef = ODEFunction(ff)
_dy, back = Zygote.pullback(y, p) do u, p
  out_ = Zygote.Buffer(similar(u))
  odef(out_, u, p, t)
  vec(copy(out_))
end
tmp1,tmp2 = back([2.0,3.0])

_dy2, back = Tracker.forward(y, p) do u, p
  out_ = map(zero, u)
  odef(out_, u, p, t)
  Tracker.collect(out_)
end
tmp12,tmp22 = back([2.0,3.0])
@show _dy, _dy2
@show tmp1, tmp12
@show tmp2, tmp22

I have no idea why that would cause a difference in the gradients. Am going to MWE that more.

@ChrisRackauckas
Copy link
Member Author

Commenting out https://github.com/SciML/DiffEqBase.jl/blob/v6.81.0/src/chainrules.jl#L132-L138 fixes this, even though the overload https://github.com/SciML/DiffEqBase.jl/blob/v6.81.0/src/chainrules.jl#L124-L130 right above it is fine. This means that @adjoint! is a bit finicky and should probably be avoided for now.

ChrisRackauckas added a commit to SciML/DiffEqBase.jl that referenced this pull request Jan 22, 2022
@ChrisRackauckas
Copy link
Member Author

SciML/DiffEqBase.jl@2328ebb removes that @adjoint! usage

@ChrisRackauckas ChrisRackauckas deleted the allow_buffer branch January 22, 2022 20:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants