Description
SciML/SciMLSensitivity.jl#340 (comment) highlights an issue with the current adjoint definition that is used in Zygote. Essentially it's as follows. ArrayPartition is a type that wraps a tuple of arrays, and the object acts like a vector that is the concatenation of those arrays.
using RecursiveArrayTools
ap = ArrayPartition([1.0, 2.0],[3.0, 4.0])
ap[3] # 3.0
Simple? Yes, but enough to break this. Zygote's type handling allows it to pull back on some functions:
using Zygote
function f(ap)
    sum(ArrayPartition(ap.x[1],ap.x[2]))
end
ap = ArrayPartition([0.0, 0.0],[0.0, 0.0])
Zygote.gradient(f,ap) # ((x = ([1.0, 1.0], [1.0, 1.0]),),)
Notice that this works like we'd expect: 4 values in, 4 values out. Now let's try something else:
function f(ap)
    ap.x[2][1] + ap.x[2][2]
end
ap = ArrayPartition([0.0, 0.0],[0.0, 0.0])
Zygote.gradient(f,ap) # ((x = (nothing, [1.0, 1.0]),),)
And 💥. There are multiple issues here. First of all, having nothing in an array partition doesn't work and all operations will fail. But secondly, even if we manually interpret that nothing as a zero to fix it for Zygote, Zygote isn't giving us back an object that is sized, so the resulting derivative only has 3 values! In some sense, the missing value isn't nothing: it's the zero of a 2-element array, and without that size information we cannot appropriately interpret and reconstruct the vector to get the correct ArrayPartition!
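To make the size problem concrete, here is a small plain-Julia sketch. It uses only tuples and vectors, not Zygote or ArrayPartition; the names primal, grad, as_scalar and as_sized are made up for illustration, and grad simply mimics the shape of the result shown above.

```julia
# Stand-ins: `primal` plays the role of ap.x, `grad` the role of the
# (nothing, [1.0, 1.0]) tuple that came back from the pullback.
primal = ([0.0, 0.0], [0.0, 0.0])
grad   = (nothing, [1.0, 1.0])

# Reading `nothing` as a scalar zero loses the size of the first block.
as_scalar   = map(g -> g === nothing ? 0.0 : g, grad)
flat_scalar = vcat(as_scalar...)            # 3 elements

# Borrowing the size from the forward-pass value keeps every block intact.
as_sized   = map((g, x) -> g === nothing ? zero(x) : g, grad, primal)
flat_sized = vcat(as_sized...)              # 4 elements

println(length(flat_scalar), " vs ", length(flat_sized))
```

The only difference between the two versions is where the zero gets its size from, which is exactly the information a bare nothing does not carry.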
Now can I define my way out of this? No. Let's say we wanted to define a literal_getproperty:
using ZygoteRules
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
    function literal_ArrayPartition_x_adjoint(d)
        @show d
        ArrayPartition(d)
    end
    A.x,literal_ArrayPartition_x_adjoint
end
If we then run the function that worked:
function f(ap)
    sum(ArrayPartition(ap.x[1],ap.x[2]))
end
we'll notice that we get an error. The reason is that the ap.x[2] call causes only the second value in the tuple to be used, so the @show d shows d = (nothing, [1.0, 1.0]) and it errors because the ArrayPartition is malformed. Again, nothing does not give the user enough information to correct this! I can manually correct for it by using the forward pass value:
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
    function literal_ArrayPartition_x_adjoint(d)
        # Replace each missing block with a zero array sized like the forward-pass block
        (ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
    end
    A.x,literal_ArrayPartition_x_adjoint
end
and tada, it's handled correctly now.
function f(ap)
    ap.x[2][1] + ap.x[2][2]
end
ap = ArrayPartition([0.0, 0.0],[0.0, 0.0])
Zygote.gradient(f,ap) # ArrayPartition([0.0, 0.0],[1.0, 1.0])
So the moral of the story: using nothing there is just wrong, and if this is fixed, a lot of getproperty fallback definitions will have a much better chance of working. As a reference, here it's an array type, but #510's definition was constrained to only work on single partial dual numbers, and if you try to relax that constraint you'll see it's this same issue of having to handle nothing in odd ways. Somehow this should really be a zero.
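For what it's worth, the "should really be a zero" behaviour can be sketched as a tiny dispatch-based helper in plain Julia. This is only an illustration under my own assumptions: fill_nothing is a hypothetical name, not something Zygote or RecursiveArrayTools provides, and it works on bare tuples rather than on ArrayPartition.

```julia
# Hypothetical helper: replace `nothing` in a gradient by a zero that is
# sized like the matching forward-pass value.
fill_nothing(d, x) = d                          # a real gradient passes through
fill_nothing(::Nothing, x) = zero(x)            # missing leaf: sized zero
fill_nothing(d::Tuple, x::Tuple) = map(fill_nothing, d, x)
fill_nothing(::Nothing, x::Tuple) = map(xi -> fill_nothing(nothing, xi), x)

primal = ([0.0, 0.0], [0.0, 0.0])               # stands in for ap.x

partial = fill_nothing((nothing, [1.0, 1.0]), primal)
whole   = fill_nothing(nothing, primal)         # the whole field was unused

println(partial)
println(whole)
```

The last method also covers the case where the whole tuple comes back as nothing, which the per-element check in the adjoint above does not handle.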