Clearly show that Zygote.Buffer is incorrect and drops gradients #549
MWE:
```julia
using DiffEqSensitivity, OrdinaryDiffEq
using Test
function fb(du,u,p,t)
du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]*t
du[2] = dy = -p[3]*u[2] + t*p[4]*u[1]*u[2]
end
p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
probb = ODEProblem(fb,u0,(0.0,10.0),p)
solb = solve(probb,Tsit5(),abstol=1e-14,reltol=1e-14)
t = 0.0:0.5:10.0
# g(t,u,i) = (1-u)^2/2, L2 away from 1
function dg(out,u,p,t,i)
(out.=2.0.-u)
end
_,dp1 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14)
_,dp2 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,
sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ZygoteVJP()))
_,dp3 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,
sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ReverseDiffVJP(true))
)
_,dp4 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,
sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.TrackerVJP())
)
isapprox(dp1, dp3, rtol = 1e-10) # true: ReverseDiffVJP matches the default
isapprox(dp1, dp4, rtol = 1e-10) # true: TrackerVJP matches the default
isapprox(dp1, dp2, rtol = 1e-10) # false: ZygoteVJP does not match
```
Running this branch prints the ZygoteVJP result next to the TrackerVJP result (the one marked `(tracked)`). It shows that Zygote silently returns zeros for the gradients, without any warning or error.
```julia
(tmp1, tmp12) = ([0.0, 0.0], [26.631836244556528, -59.276505258765454] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.864401590759809, -0.46328499411934765, -0.5016239375867565, 0.5567850853102051] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [26.631836244556528, -59.276505258765454] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.864401590759809, -0.46328499411934765, -0.5016239375867565, 0.5567850853102051] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.239759892814245, -64.92349793405678] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [5.342382748372607, -0.5088078592985804, -0.5486769599396241, 0.6090122999660315] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.29298962280317, -65.44712011526572] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.7737557956434475, -0.5216705467906235, -0.6424855333647359, 0.6244081606040183] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.361097784059247, -65.91604997082547] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [4.256393010680244, -0.541591584161458, -0.7642806891662404, 0.6482524400607981] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.746632620545274, -67.09795916159659] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.9093730111190297, -0.6660255541953102, -1.4860076220624616, 0.7971923923419794] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.826370888484554, -67.21982110170187] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.7662482681001164, -0.691436222195994, -1.6410179782414434, 0.8276074283520429] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.847562023470765, -67.24886639699635] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.732003675792577, -0.6981458462867473, -1.6824987027846035, 0.8356384434488817] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [29.847562023470765, -67.24886639699635] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.732003675792577, -0.6981458462867473, -1.6824987027846035, 0.8356384434488817] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.65231041349967, -72.99561289182861] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.995675486557265, -0.765525469930507, -1.823192757796991, 0.9055162751158428] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.84757009305544, -73.2149421811492] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.7154508868475986, -0.831261731883523, -2.235470687533171, 0.9832736553753629] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [33.09493861771779, -73.40860551400004] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.465594510079744, -0.9127640575496194, -2.7706489257299416, 1.0796802221707715] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.478421563934845, -73.87198707285454] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8663938331581926, -1.3598041561180667, -5.964979923865799, 1.6084700544932242] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.76266196733998, -73.91542550789954] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8122875993265848, -1.4538200950605515, -6.655100955467102, 1.7196785853347198] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.8381097793373, -73.92554421570522] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.79987646882824, -1.4790401810546734, -6.839931793384734, 1.749510640863286] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.8381097793373, -73.92554421570522] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.79987646882824, -1.4790401810546734, -6.839931793384734, 1.749510640863286] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [37.55519353359098, -79.15205293229714] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.9631265670333864, -1.6131901957852537, -7.312732552313099, 1.8704431272885764] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.20295942122944, -79.21985590426772] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8716213669317996, -1.8681806360820092, -9.127588733823748, 2.1660964965090557] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [39.01926648980144, -79.2718955247713] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.8113403160333863, -2.2107034355180053, -11.4871241866697, 2.5632408740403383] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [43.488773916093315, -79.29749595923818] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [1.9518944136525855, -4.758192381606865, -25.521674609324524, 5.516973920214784] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.38537911729765, -79.2761670053049] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.0362976372072272, -5.461181202907928, -28.52577741349958, 6.332067275479824] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.6220448469952, -79.26950717301038] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.0611375046741944, -5.66065005834139, -29.328037615976378, 6.563345119052403] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.6220448469952, -79.26950717301038] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.0611375046741944, -5.66065005834139, -29.328037615976378, 6.563345119052403] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [44.71659209763698, -82.19967310954614] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.232156976071968, -6.13033312342071, -30.3183257604237, 6.784962499131496] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [46.246735382490535, -82.14351736770497] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [2.523772563410683, -8.381486405692192, -37.88184176666544, 9.276505828425474] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [48.12208259152519, -82.04204188875612] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [3.0201342060048035, -12.156942765640041, -47.546464235723974, 13.455125375459124] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [56.87032354456895, -80.66885364859259] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [9.715064886825084, -71.02380817665073, -98.42156579533089, 78.60810584387762] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [58.16915102272813, -80.138473660938] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [12.379517485886149, -96.5411812662228, -107.07830173514093, 106.85035891616823] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [58.47704253072354, -79.97838789429412] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [13.190904496310262, -104.39399165655345, -109.20859353986253, 115.54173390974331] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [58.47704253072354, -79.97838789429412] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [13.190904496310262, -104.39399165655345, -109.20859353986253, 115.54173390974331] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [33.55366541715509, -77.73410008088817] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [13.969059364344135, -110.55237850741402, -101.46448148094699, 107.34853128826781] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.98347559486527, -78.27007679071622] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [24.189946551213247, -209.46474121184778, -115.67273718552283, 203.39437856835428] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [32.92202056803882, -79.28035633641166] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [44.26541946933, -386.85627788167153, -122.02174752780171, 375.6450454610382] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [37.638011289933104, -84.0912917618233] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [159.54852959826255, -411.15798983322617, -42.6311784577402, 399.24248516347734] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.12312292545695, -84.11769812134202] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [164.37980227642223, -319.40735255840326, -32.99637945055528, 310.15081396479474] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.22424404025974, -84.09702123642069] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [164.99861526015644, -298.80741957262575, -30.9572893746732, 290.14787435810274] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [38.22424404025974, -84.09702123642069] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [164.99861526015644, -298.80741957262575, -30.9572893746732, 290.14787435810274] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [40.98088173036628, -39.227870259230855] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [138.45345883992013, -250.7349573911937, -31.914310762059685, 299.11757832361224] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [37.712244316539895, -41.716200404335325] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [137.22073086966876, -143.60303660069118, -19.48830866347262, 171.3131387614978] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [35.94618238427883, -45.862763782317984] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [129.6640077064968, -82.1383314171176, -12.529193697372602, 97.98800708390338] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.12594365800954, -60.38278343115164] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [93.37936953301197, -19.400422560788996, -5.230420482903112, 23.14399027250195] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.051041302878005, -62.02714127940237] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [88.64144209420428, -16.863261495028016, -4.979142207484118, 20.11725047640986] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.03504072459894, -62.42057196135026] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [87.48783486938594, -16.31640413621136, -4.929949106052264, 19.46486976907005] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [34.03504072459894, -62.42057196135026] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [87.48783486938594, -16.31640413621136, -4.929949106052264, 19.46486976907005] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.814176140237816, -53.008431702444355] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [79.79541973350233, -14.881775489520438, -5.268165422847731, 20.800245940015962] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.628793709610633, -58.5173110748337] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [71.71674376438688, -11.720995290261005, -5.02082187396316, 16.38242593236797] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.51509013600684, -63.40349175050008] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [64.14515832339075, -9.576103896884254, -5.041289540492184, 13.384512912629381] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41287558996019, -75.24009119949635] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [43.44496823075077, -5.933318174611237, -7.014199002562334, 8.29299416317544] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41301257908025, -76.39926487694274] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [41.138349915447044, -5.619002395297533, -7.565465703987922, 7.853675244733436] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41299860618759, -76.67329729054796] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [40.5819774248363, -5.542938074124349, -7.7164835747995815, 7.7473602026353525] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.41299860618759, -76.67329729054796] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [40.5819774248363, -5.542938074124349, -7.7164835747995815, 7.7473602026353525] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.63794833648743, -80.11200799438166] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [40.56590783726544, -5.540743190226762, -8.188205213902684, 8.220969382008771] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.631902344342162, -82.54743572394703] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [36.34868775653611, -4.942240858483933, -9.715275172088273, 7.332953248541794] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [31.59916618731867, -84.64073474494845] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [32.4542838100408, -4.304010931982232, -11.813216642384194, 6.385992073061388] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.72428719937769, -89.39741187268406] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [21.75089580987753, -0.9372888222112805, -25.83319882402745, 1.3906839651203264] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.39789908846798, -89.82442312808675] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [20.506037994461035, -0.19876342414769294, -29.091196332323715, 0.2949113445762376] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.30317617895961, -89.9234904026111] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [20.20211745263974, 0.0, -29.974496800870366, 0.0] (tracked))
(tmp1, tmp12) = ([0.0, 0.0], [30.30317617895961, -89.9234904026111] (tracked))
(tmp2, tmp22) = ([0.0, 0.0, 0.0, 0.0], [20.20211745263974, 0.0, -29.974496800870366, 0.0] (tracked))
```
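Silent zeros like the ones in the log above can be caught by comparing against a reference that does not depend on any AD package. The following is a minimal sketch, not part of the PR: `f_oop`, `fd_vjp_p`, and `silently_dropped` are hypothetical helper names, the finite-difference step `h` is an arbitrary choice, and the all-zero vector stands in for the Zygote result rather than being computed by Zygote.

```julia
# Out-of-place copy of the right-hand side from the MWE (hypothetical helper).
function f_oop(u, p, t)
    dx = p[1]*u[1] - p[2]*u[1]*u[2]*t
    dy = -p[3]*u[2] + t*p[4]*u[1]*u[2]
    return [dx, dy]
end

# Central finite-difference estimate of the vector-Jacobian product v' * (df/dp).
function fd_vjp_p(f, u, p, t, v; h = 1e-6)
    g = zeros(length(p))
    for i in eachindex(p)
        p_plus = copy(p);  p_plus[i]  += h
        p_minus = copy(p); p_minus[i] -= h
        g[i] = sum(v .* (f(u, p_plus, t) .- f(u, p_minus, t))) / (2h)
    end
    return g
end

# Flag a gradient that is identically zero although the reference is not.
function silently_dropped(ad_grad, ref_grad; atol = 1e-8)
    return all(iszero, ad_grad) && !all(x -> abs(x) <= atol, ref_grad)
end

u = [1.0, 1.0]; p = [1.5, 1.0, 3.0, 1.0]; t = 0.0; v = [2.0, 3.0]
ref = fd_vjp_p(f_oop, u, p, t, v)

zygote_like = zeros(4)  # stand-in for the all-zero tmp2 printed in the log
@show ref
@show silently_dropped(zygote_like, ref)
```

A check of this kind could sit next to the `isapprox` comparisons in the MWE, so that a VJP backend returning all zeros fails loudly instead of passing through unnoticed.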
This is probably the cause of many of the complaints in https://discourse.julialang.org/t/state-of-machine-learning-in-julia/74385. Specifically, https://discourse.julialang.org/t/zero-gradients-with-zygote-vs-correct-gradients-with-reversediff-using-diffeqflux/74398 is what tracked it down to this issue (after being told it's a DiffEq issue and not a Zygote issue...).
What was interesting is that this worked:

```julia
using Zygote, Tracker
function ff(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]*t
  du[2] = dy = -p[3]*u[2] + t*p[4]*u[1]*u[2]
end
p = [1.5,1.0,3.0,1.0]; y = [1.0;1.0]; t = 0.0
# Zygote: write into a Buffer, then copy it out
_dy, back = Zygote.pullback(y, p) do u, p
  out_ = Zygote.Buffer(similar(u))
  ff(out_, u, p, t)
  vec(copy(out_))
end
tmp1,tmp2 = back([2.0,3.0])
# Tracker: same function, same seed
_dy2, back = Tracker.forward(y, p) do u, p
  out_ = map(zero, u)
  ff(out_, u, p, t)
  Tracker.collect(out_)
end
tmp12,tmp22 = back([2.0,3.0])
@show _dy, _dy2
@show tmp1, tmp12
@show tmp2, tmp22
```

So the error is even weirder than I thought. It's not "just" buffer.
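For reference, the values this pullback should return can be worked out by hand. This is a worked calculation added for illustration, not something stated in the thread; the symbol $v$ for the seed `[2.0, 3.0]` is introduced here, and the numbers use `y = [1.0, 1.0]`, `p = [1.5, 1.0, 3.0, 1.0]`, `t = 0.0` from the snippet.

```latex
\begin{aligned}
f(u,p,t) &= \begin{pmatrix} p_1 u_1 - p_2 u_1 u_2 t \\ -p_3 u_2 + t\, p_4 u_1 u_2 \end{pmatrix},
\qquad v = (2, 3)^\top,\quad t = 0 \\[4pt]
\nabla_u \left(v^\top f\right) &=
\begin{pmatrix} v_1 (p_1 - p_2 u_2 t) + v_2\, t\, p_4 u_2 \\ -v_1 p_2 u_1 t + v_2 (-p_3 + t\, p_4 u_1) \end{pmatrix}
= \begin{pmatrix} 2 \cdot 1.5 \\ -3 \cdot 3.0 \end{pmatrix}
= \begin{pmatrix} 3 \\ -9 \end{pmatrix} \\[4pt]
\nabla_p \left(v^\top f\right) &=
\begin{pmatrix} v_1 u_1 \\ -v_1 u_1 u_2 t \\ -v_2 u_2 \\ v_2\, t\, u_1 u_2 \end{pmatrix}
= \begin{pmatrix} 2 \\ 0 \\ -3 \\ 0 \end{pmatrix}
\end{aligned}
```

Both gradients have nonzero entries at this point, so an all-zero result from either backend cannot be correct.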
Bingo, I found it, and I fixed the DiffEq case by doing so?

```julia
using Zygote, Tracker, SciMLBase
function ff(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]*t
  du[2] = dy = -p[3]*u[2] + t*p[4]*u[1]*u[2]
end
p = [1.5,1.0,3.0,1.0]; y = [1.0;1.0]; t = 0.0
# Same test as before, but calling through an ODEFunction wrapper
odef = ODEFunction(ff)
_dy, back = Zygote.pullback(y, p) do u, p
  out_ = Zygote.Buffer(similar(u))
  odef(out_, u, p, t)
  vec(copy(out_))
end
tmp1,tmp2 = back([2.0,3.0])
_dy2, back = Tracker.forward(y, p) do u, p
  out_ = map(zero, u)
  odef(out_, u, p, t)
  Tracker.collect(out_)
end
tmp12,tmp22 = back([2.0,3.0])
@show _dy, _dy2
@show tmp1, tmp12
@show tmp2, tmp22
```

I have no idea why that would cause a difference in the gradients. Am going to MWE that more.
Commenting out https://github.com/SciML/DiffEqBase.jl/blob/v6.81.0/src/chainrules.jl#L132-L138 fixes this, even though the overload https://github.com/SciML/DiffEqBase.jl/blob/v6.81.0/src/chainrules.jl#L124-L130 right above it is fine. This means that

SciML/DiffEqBase.jl@2328ebb removes that