Is there a recommended (and preferably generic) mechanism to construct an array y that is a function of a variable x without running into the error that array mutation is not supported? Conceptually, this array is constructed once and never changed afterwards, i.e. it is immutable. In practice, however, many functions for constructing or populating an array mutate it. Of course, one can resort to defining custom adjoints for all of these, but that is not a desirable situation for an autodiff package.
As a concrete example, consider a function f that constructs the matrix A out of a vector a. As background, this function computes the A matrix corresponding to the first-order vector representation of a Stochastic Differential Equation (SDE).
using LinearAlgebra, Random, Zygote
# Desired function - autodiff does not work
function f(a::AbstractArray{T}) where {T}
p = length(a)
A = [zeros(T,p-1,1) Matrix{T}(I,p-1,p-1)
-a']
end
x = [0.2,0.1,-0.3]
@show f(x)
# y = [0.0 1.0 0.0; 0.0 0.0 1.0; -0.2 -0.1 0.3]
rng = MersenneTwister(58634)
ȳ = randn(rng,3,3)
y,back = Zygote.forward(f,x)
@show back(ȳ)[1]
# ERROR: LoadError: Mutating arrays is not supported
In this case, I have been able to hack a working solution, but it ain't pretty, nor generalizable:
Atop(p) = hcat(zeros(p-1,1),Matrix(I,p-1,p-1))
Zygote.@nograd Atop
function f_working(a::AbstractArray{T}) where {T}
p = length(a)
A = [Atop(p);-a']
end
y,back = Zygote.forward(f_working,x)
@show back(ȳ)[1]
# (back(ȳ))[1] = [-0.579952; -0.51652; -0.520227]
Here, the pattern of where a elements occur in y is not too complicated, so this workaround is feasible, but the pattern can be more complex.
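For reference, the gradient of this particular construction can be worked out by hand: only the last row of A depends on a, with A[p, j] = -a[j], so the pullback of a cotangent matrix is simply its negated last row. The sketch below checks this against central finite differences using only the standard library. The names companion, companion_pullback and Ybar are illustrative and not taken from the code above, and Ybar is an arbitrary cotangent, not the random ȳ used earlier.

```julia
using LinearAlgebra

# Same construction as f_working, written without Zygote (illustrative name).
function companion(a::AbstractVector{T}) where {T}
    p = length(a)
    top = hcat(zeros(T, p - 1, 1), Matrix{T}(I, p - 1, p - 1))  # does not depend on a
    return [top; -a']
end

# Hand-derived pullback: only the last row depends on a, and A[p, j] = -a[j].
companion_pullback(Ybar::AbstractMatrix) = -Ybar[end, :]

x = [0.2, 0.1, -0.3]
Ybar = reshape(collect(1.0:9.0), 3, 3)   # arbitrary cotangent, for illustration only

# Central finite differences of the scalar sum(Ybar .* A(a)) with respect to a.
loss(a) = sum(Ybar .* companion(a))
h = 1e-6
basis = Matrix{Float64}(I, 3, 3)
fd = [(loss(x + h * e) - loss(x - h * e)) / (2h) for e in eachcol(basis)]

@show companion_pullback(Ybar)
@show fd
```

If the derivation is right, the two printed vectors agree up to rounding. A hand-written rule like this is easy here, but it is exactly the kind of per-function work the question is trying to avoid.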
I hope that there is a way to do this using the existing code base. If not, perhaps one idea is to define a setindex that returns a modified copy, instead of the in-place setindex!, thus preventing mutation of the original array. Possibly there could even be a macro that would automatically replace all setindex! calls with the copying version when the function is passed through Zygote.forward. Not terribly efficient perhaps, but at least a working solution, and probably not too bad for small arrays.
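To make the copying setindex idea concrete, here is a minimal sketch in plain Julia. Both setindex_copy and setindex_copy_pullback are hypothetical helpers, not part of Base or Zygote, and the pullback is written by hand for a scalar value under the assumption that an AD package would need a rule of roughly this shape.

```julia
# Hypothetical helper (not in Base or Zygote): out-of-place version of setindex!.
function setindex_copy(A::AbstractArray, v, inds...)
    B = copy(A)        # the caller's array is never modified
    B[inds...] = v     # mutation is confined to the fresh copy
    return B
end

# Hand-written pullback for B = setindex_copy(A, v, inds...) with a scalar v.
function setindex_copy_pullback(Bbar::AbstractArray, inds...)
    Abar = copy(Bbar)
    Abar[inds...] = zero(eltype(Bbar))  # the overwritten entry does not depend on A
    vbar = Bbar[inds...]                # it depends only on v
    return Abar, vbar
end

A = zeros(2, 2)
B = setindex_copy(A, 5.0, 1, 2)
Abar, vbar = setindex_copy_pullback(ones(2, 2), 1, 2)

@show A B Abar vbar
```

Each call copies the whole array, so filling n entries this way costs O(n) copies; that matches the expectation above that the approach is only reasonable for small arrays.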