Consistency in the type behavior of restructure #95

Open
ChrisRackauckas opened this issue Jul 1, 2022 · 6 comments

@ChrisRackauckas
Member

This was discovered in SciML/NeuralPDE.jl#533 as an issue that only showed itself as an incorrect gradient: the primal pass of what was being trained was in Float64, the reverse pass gave a Float64, the loss function printout gave a Float64, and everything looked fine, except that the Flux neural network was mysteriously "a bit more janky", in that it had a much higher probability of failing CI tests for a reason nobody could figure out for 5 months. Finally it was discovered that parts of the gradient were calculated in Float32 because the Flux.Chain had Float32 parameters in there. This showcased that re(p) does not "always" respect the types of p.

But it doesn't "always" respect the types of the Flux.Chain either. For example, for a standard Flux.Chain of Dense layers with Float32 parameters, you get:

  • re(p::Vector{Float64}) computes in Float32
  • re(p::CuVector{Float32}) computes on the GPU in Float32
  • re(p::Vector{Dual}) computes with Dual numbers
  • re(p::Vector{ComplexF32}) computes with Float32

And now let's have some fun:

  • re(p::CuVector{Float64}) computes ???. My guess is CuVector{Float32}?
  • re(p::ReverseDiff.TrackedArray) computes ??? My guess is Array{TrackedReal{Float32}}?

I understand that this isn't intended behavior and comes out of some quirks of ProjectTo, which exposes some (IMO odd) behaviors of a ChainRules internal to users who are likely not experts in the autodiff system.
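As a rough illustration of the element-type cases listed above, here is a minimal sketch in plain Julia. The function `project_like` is a hypothetical stand-in written for this example, not the actual `ProjectTo` from ChainRulesCore, and it only imitates the Float64 and ComplexF32 cases (Dual numbers and GPU arrays are not covered).

```julia
# Hypothetical stand-in for the projection step; NOT the real ChainRulesCore.ProjectTo.
# `stored` plays the role of a parameter array inside the model,
# `flat` plays the role of the matching slice of the vector passed to re(p).

# Real new values: converted to the stored element type (Float64 -> Float32 here).
project_like(stored::AbstractArray{T}, flat::AbstractArray{<:Real}) where {T<:AbstractFloat} =
    T.(flat)

# Complex new values for a real parameter: the imaginary part is dropped.
project_like(stored::AbstractArray{T}, flat::AbstractArray{<:Complex}) where {T<:AbstractFloat} =
    T.(real.(flat))

stored = rand(Float32, 3)

w64 = project_like(stored, [1.0, 2.0, 3.0])               # Float64 input
wc  = project_like(stored, ComplexF32[1 + 2im, 3, 4im])   # ComplexF32 input

println(eltype(w64))  # Float32
println(eltype(wc))   # Float32
```

In both cases the result silently takes the stored element type rather than the type of the vector that was passed in, which is the demotion being discussed.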

Now the problem that I have with it is that discovering this behavior is rather hard, because if you do anything other than the simplest "just use the neural network", almost any case will not expose to the user that this behavior exists. For example,

  • (p[end] .* re(p))::typeof(p)
  • (p[end] .+ re(p))::typeof(p)
  • ...

all hold in the examples I described because the type demotion is countered by the type promotion that's applied by essentially any other computation that uses things with the eltype(p). Thus unless re(p) is the only operation that is used (in which case, you probably don't need to be using restructure/destructure), some other operation in the primal will mask the demotion and your forward pass will look like it computed using typeof(p). It will only present itself to a user in the gradient pass.
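The masking effect can be reproduced with plain arrays. In this sketch `y32` is a made-up stand-in for the output of a network that silently computed in Float32; no Flux or Optimisers call is involved.

```julia
p = [0.1, 0.2, 0.3, 0.4]       # Float64 parameters the user believes drive the computation
y32 = Float32.(p[1:2])         # stand-in: a result that was actually computed in Float32

# Mixing with anything of eltype(p) promotes the result back to Float64,
# so the forward pass looks as if everything ran in Float64.
out_mul = p[end] .* y32
out_add = p[end] .+ y32

println(eltype(y32))      # Float32
println(eltype(out_mul))  # Float64
println(eltype(out_add))  # Float64

# The precision was already lost before the promotion happened.
println(Float64(y32[1]) == p[1])  # false
```

The output type alone gives no hint that a Float32 step happened in between, which is why the problem tends to surface only when someone inspects the gradient.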

Thus I understand @mcabbott's reasoning behind saying it's not a gradient correctness issue (since it's correctly calculating the gradients of the object that is actually reconstructed), but I have now isolated many different cases that I thought were just "Flux janky behavior" and "I don't know why FastChain works here but Flux.Chain doesn't" all back to this same behavior. It may not be a gradient correctness issue, but it only presents itself as one in downstream libraries where I have found this, it only really exposes itself if you try to look into a seemingly incorrect gradient, and if it quacks like 🦆?

I understand that this behavior is now documented, but I'm not sure a behavior that presents itself like that is sufficiently handled just by documentation because it's hard to even figure out that something is going wrong without investigating the gradient calculation.

What could be done?

I would propose that we should just make the behavior undeniably straightforward and consistent. Either always make re(p) compute using values of typeof(p), or make it so it always computes using the values from the original Flux.Chain. Either choice is an easily explainable and predictable behavior. This middle ground is not easy to explain or predict.

Always matching p is the more predictable behavior in the Julia ecosystem. If you stick a complex number in as the initial condition of an ODE solver, as the initial guess for a value in Optim, as the starting point for IterativeSolvers or NLsolve, or into any other generic code that I can think of, it will carry out the computation in the type that p provides. In many cases generic codes will just error if they can't handle it, but they try to compute using p. Non-generic codes immediately throw method errors describing what the allowed inputs are. I cannot think of another example in the Julia ecosystem where the "computation type" for f(p) matches neither p nor a fixed type, but instead sometimes matches the internal types of the fields of f and other times matches p.

If it always matches the Flux.Chain, at least that would be clearly visible, since when you do it on a CuArray you see you get an Array and you're like: oh, I see how this works. If I want to use the GPU, then I |> gpu the chain, because it doesn't convert to p. Got it. With the current behavior, you see that re(p) works on the GPU, so why not just do re(p::Array{Float64}) as a quick way to convert to Float64? And if you think like that, you get burned.

Another behavior would be to throw an error in any case where a type conversion is necessary. If you want re(p::Array{Float64}) to work, go back and |> f64 the neural network. Now, this will cause some issues with making libraries work, but it's a nice (overly) safe option that would ensure there are no surprises.

Or, as @ToucheSir suggested, maybe these are two different functions, or two different options, and you should be required to choose which behavior you want. Some kind of re(p,Optimisers.NoConvert()) and re(p,Optimisers.Convert()).

Those 4 behaviors would be clear and easily predictable. I think the only option I would be adamantly against is the current behavior.
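To make the options concrete, here is a minimal sketch of the error-throwing behavior and the explicit-choice behavior. Everything in it is hypothetical: the `Strict`, `NoConvert` and `Convert` types and the `prepare` helper are defined locally for illustration and are not part of Optimisers.jl. The meanings assigned to `NoConvert` and `Convert` are also an assumption of this sketch, since the thread does not spell them out.

```julia
# All names below are hypothetical and defined only for this sketch;
# none of them exist in Optimisers.jl.
struct Strict end      # refuse any element-type mismatch
struct NoConvert end   # assumed meaning: keep p as given, compute in eltype(p)
struct Convert end     # assumed meaning: convert p to the model's stored eltype

# `T` stands for the element type stored in the model, `p` for the flat vector.
function prepare(::Strict, ::Type{T}, p::AbstractVector) where {T}
    eltype(p) === T ||
        throw(ArgumentError("expected eltype $T, got $(eltype(p)); convert the model or p explicitly"))
    return p
end
prepare(::NoConvert, ::Type{T}, p::AbstractVector) where {T} = p
prepare(::Convert, ::Type{T}, p::AbstractVector) where {T} = T.(p)

p64 = [1.0, 2.0, 3.0]

println(eltype(prepare(NoConvert(), Float32, p64)))  # Float64
println(eltype(prepare(Convert(), Float32, p64)))    # Float32

# The strict variant errors instead of silently changing precision.
try
    prepare(Strict(), Float32, p64)
catch err
    println(err isa ArgumentError)  # true
end
```

Whichever rule is picked, the point is that the resulting element type follows from one stated policy rather than from a mix of the two.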

@ToucheSir
Member

ToucheSir commented Jul 1, 2022

I would actually be in favour of behaviour 3: destructure is fundamentally a function that promises too much, and even after the effort made towards tightening that (+ improving correctness) when porting over to Optimisers, one could argue it still does. Disallowing promotion while destructuring would also reduce some internal complexity.

Now, another tricky thing is what to do about structured array types. Here I think we just have to enumerate as many weird cases as we can think of and come to an agreement on how to handle them all consistently. One such example:

julia> d = Dense(Diagonal(rand(Float32, 3)), false)
Dense(3 => 3; bias=false)  # 9 parameters

julia> d.weight
3×3 Diagonal{Float32, Vector{Float32}}:
 0.24043   ⋅         ⋅ 
  ⋅        0.657887  ⋅ 
  ⋅         ⋅        0.52947

julia> p, re = destructure(d)
(  [1]  =  0.24043
  [5]  =  0.657887
  [9]  =  0.52947, Restructure(Dense, ..., 9))

julia> p
9-element SparseArrays.SparseVector{Float32, Int64} with 3 stored entries:
  [1]  =  0.24043
  [5]  =  0.657887
  [9]  =  0.52947

julia> re(p)
Dense(3 => 3; bias=false)  # 9 parameters

julia> re(p) |> dump
Dense{typeof(identity), Diagonal{Float32, SparseArrays.SparseVector{Float32, Int64}}, Bool}
  weight: Diagonal{Float32, SparseArrays.SparseVector{Float32, Int64}}
    diag: SparseArrays.SparseVector{Float32, Int64}
      n: Int64 3
      nzind: Array{Int64}((3,)) [1, 2, 3]
      nzval: Array{Float32}((3,)) Float32[0.24042994, 0.6578865, 0.52947]
  bias: Bool false
  σ: identity (function of type typeof(identity))

And another one:

julia> d = Dense(rand(Float32, 3, 2), @SArray ones(3))
Dense(2 => 3)       # 9 parameters

julia> p, re = destructure(d)
(Float32[0.9659148, -0.7210188, 0.20607175, 0.7583495, 0.35627228, -0.5444089, 0.0, 0.0, 0.0], Restructure(Dense, ..., 9))

julia> re(p)
Dense(2 => 3)       # 9 parameters

julia> re(p) |> dump
Dense{typeof(identity), Matrix{Float32}, SizedVector{3, Float32, Vector{Float32}}}
  weight: Array{Float32}((3, 2)) Float32[0.9659148 0.7583495; -0.7210188 0.35627228; 0.20607175 -0.5444089]
  bias: SizedVector{3, Float32, Vector{Float32}}
    data: Array{Float32}((3,)) Float32[0.0, 0.0, 0.0]
  σ: identity (function of type typeof(identity))

julia> cu_p = cu(p)
9-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
  0.9659148
 -0.7210188
  0.20607175
  0.7583495
  0.35627228
 -0.5444089
  0.0
  0.0
  0.0

julia> re(cu_p) |> dump
Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, SizedVector{3, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}
  weight: CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
    storage: CUDA.ArrayStorage{CUDA.Mem.DeviceBuffer}
      buffer: CUDA.Mem.DeviceBuffer
        ctx: CuContext
          handle: Ptr{Nothing} @0x0000000002ab0400
          valid: Bool true
        ptr: CuPtr{Nothing} CuPtr{Nothing}(0x0000000701bc0800)
        bytesize: Int64 24
        async: Bool false
      refcount: Base.Threads.Atomic{Int64}
        value: Int64 1
    maxsize: Int64 24
    offset: Int64 0
    dims: Tuple{Int64, Int64}
      1: Int64 3
      2: Int64 2
  bias: SizedVector{3, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}
    data: CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}
      storage: CUDA.ArrayStorage{CUDA.Mem.DeviceBuffer}
        buffer: CUDA.Mem.DeviceBuffer
          ctx: CuContext
            handle: Ptr{Nothing} @0x0000000002ab0400
            valid: Bool true
          ptr: CuPtr{Nothing} CuPtr{Nothing}(0x0000000701bc0a00)
          bytesize: Int64 12
          async: Bool false
        refcount: Base.Threads.Atomic{Int64}
          value: Int64 1
      maxsize: Int64 12
      offset: Int64 0
      dims: Tuple{Int64}
        1: Int64 3
  σ: identity (function of type typeof(identity))

@mcabbott
Member

mcabbott commented Jul 1, 2022

I repeat that no incorrect gradients have been displayed here. Calling other features you happen to dislike in some context gradient bugs is just muddying the waters. (There are known gradient bugs, they are marked "bug" in the issues here.)

Maybe it's helpful to understand what the goals are of the present design:

  1. Don't assume that all arrays have the same type. The optimisation rules don't need this, nor does destructure.
  2. If some parameters of the model are real, they must never be made complex. This is a correctness question.
  3. One of the use cases for destructure is to make something ForwardDiff.jl can understand. Thus Dual numbers ought to propagate everywhere.
  4. Try to preserve element types: Unintended promotion to Float64 is an easy source of major performance problems. (And, point 1, you are welcome to have different precisions in different parts of the model; destructure will assume that wasn't an accident.)

For 3., you may recall that #66 was precisely to address your complaint that Dual numbers did not propagate through some operations.

Since ReverseDiff.jl also likes flat arrays, not nested trees, the same should go for its tracked arrays. If they don't propagate, I think that's a bug. But no need to guess. Tracker's arrays seem to work fine; something seems to make a Vector{ReverseDiff.TrackedReal}, but surely that could be solved.

At present, this package does not know about GPU arrays, and thus makes no distinctions. If you think it's confusing that re from a CPU model can be fed a GPU array and construct a GPU model, it would not be very hard to forbid that. (Models with a mix of GPU and CPU arrays won't work very well right now. Various policies could be adopted, but nobody has got around to it.)

@mcabbott
Member

mcabbott commented Jul 1, 2022

Re structured arrays, I suspect most of them should be marked @functor. I think you are suggesting that the sparse array outcome is undesirable, but I can't reproduce it on Julia nightly, so I suspect some of the weirdness about Diagonal being sometimes seen as sparse has gone away (with SparseArrays.jl moving out?).

julia> v, re = destructure(Diagonal([0.11, 0.22]));

julia> v
4-element Vector{Float64}:
 0.11
 0.0
 0.0
 0.22

julia> re([1.0, 2.0, 3.0, 4.0])
2×2 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅ 
  ⋅   4.0

This discards 2.0, 3.0 components of the new parameters. The zeros in v are structural, not accidental, so they are preserved.
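The same discarding can be seen with the LinearAlgebra standard library alone, independent of destructure. A minimal sketch: building a `Diagonal` from a full matrix keeps only the diagonal entries.

```julia
using LinearAlgebra

flat = [1.0, 2.0, 3.0, 4.0]
full = reshape(flat, 2, 2)     # column-major: [1.0 3.0; 2.0 4.0]

D = Diagonal(full)             # keeps only the diagonal entries 1.0 and 4.0

println(D.diag)                # [1.0, 4.0]
println(vec(Matrix(D)))        # [1.0, 0.0, 0.0, 4.0]
```

The off-diagonal values 2.0 and 3.0 have nowhere to go in the structured type, so they are dropped rather than stored.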

@ToucheSir
Member

I don't claim to know what the right answer is, so I posted those examples because it's not clear if they'd be considered consistent enough to pass muster. Another one is Transpose/Adjoint, which projects back to a dense array (AIUI it's a no-op) rather than the wrapper type.

On a meta level, I feel even more strongly now that the behaviour of destructure was way underspecified when it was first written. Not only is there disagreement about what counts as sufficiently "reconstructed" for various parameter types, but the proliferation of fancy array types in the ecosystem makes post-hoc specification (as we're attempting now) a significant undertaking. Here I'd be interested to know from @willtebbutt or others working on packages like ParameterHandling.jl how they handle this particular can of worms :)

@mcabbott
Member

mcabbott commented Jul 1, 2022

Ok. Adjoint should now reconstruct:

julia> destructure(rand(2)')[2]([1.0, 2.0])
1×2 adjoint(::Vector{Float64}) with eltype Float64:
 1.0  2.0

julia> destructure(transpose(rand(2,2)))[2]([1, 2, 3, 4])
2×2 transpose(::Matrix{Float64}) with eltype Float64:
 1.0  2.0
 3.0  4.0

I agree that things are a bit under-specified. Like everything else in Julia really -- it's a bit of an exploration to see what properties turn out to be useful, and how to compose them.

@ChrisRackauckas
Member Author

> I repeat that no incorrect gradients have been displayed here. Calling other features you happen to dislike in some context gradient bugs is just muddying the waters. (There are known gradient bugs, they are marked "bug" in the issues here.)

I don't disagree. There are no incorrect gradients here by the definition now in the docs. It's just an issue that only presents itself to downstream users via incorrect gradients (as demonstrated) in functions which expect to have the normal action that a generic Julia function generally has. It's a very subtle distinction. I agree it's not incorrect as documented, but it is also very hard to spot that it's happening in most cases (with demonstrations as to why).
