
Type Promotion often Unwieldy and day Ruining #1026

Closed
caseykneale opened this issue Feb 7, 2020 · 13 comments

@caseykneale

So I've blown two days tracking down pervasive errors while making basic toy models for a project at work. The errors are UndefRefErrors, somewhere in Zygote on an @adjoint function.
I believe the issue is that a reshape in a chain or a model (or something) explodes the world. Below is a minimum not-working example:

using Flux
function permute_data( batch::Array )
    s = size(batch)
    reshape( batch, (s[2],1,1,s[1]) )
end

Flux.@nograd permute_data
squish(z) = reshape(z, :, size(z, 4))

Flux.@nograd squish
squishgrad(z) = reshape(z, ( prod( size(z)[1:3]),  size(z, 4) ) )

stop_gradient(x) = x
Flux.@nograd stop_gradient
#Let's make a model....
obs, vars       = ( 10, 200 )
proxydata       = randn(obs,vars)
proxydata       = permute_data( proxydata )
ToEmbeddingA    = Flux.Conv( ( 5, 1 ),  1 => 16, relu)
ToEmbeddingB    = Flux.Conv( ( 5, 1 ),  16 => 24, relu)

freeze_forelayer(x) = ToEmbeddingA(x) |> ToEmbeddingB
Flux.@nograd freeze_forelayer
newproxy        = freeze_forelayer(proxydata)
newproxysize    = size(reshape(newproxy, :, size(newproxy, 4)))

EmbeddingA      = Flux.Dense( newproxysize[1], 32,  relu )
EmbeddingB      = Flux.Dense( 32, 32,  relu )

regression_head = Flux.Dense( 32, 1, relu )
siamese_head    = Flux.Dense( 32, 1, relu )

regr_model = Chain(  freeze_forelayer, squishgrad,
                    EmbeddingA, EmbeddingB,
                    regression_head )


siam_model = Chain(  ToEmbeddingA, ToEmbeddingB, squishgrad,
                    EmbeddingA, EmbeddingB,
                    regression_head )

SGD = Descent( 10.0 )
mseloss( x1, y1 ) = sum( regr_model( x1 ) .- y1 )
smseloss( x1, y1 ) = sum( siam_model( x1 ) .- y1 )

ex = proxydata

yeye =  Matrix(transpose(randn(10)))

println(siam_model.layers[1].weight[1,1,1,1])
Flux.train!( mseloss, Flux.params( regr_model ), [ (ex, yeye) ], SGD )
Flux.train!( smseloss, Flux.params( siam_model ), [ (ex, yeye) ], SGD )
println(siam_model.layers[1].weight[1,1,1,1])
@mkschleg
Contributor

mkschleg commented Feb 7, 2020

More minimal + usable w/ Revise:

using Flux

flatten(x) = reshape(x, :, size(x, 4))

stop_gradient(x) = x
Flux.@nograd stop_gradient

function main()
    
    #Let's make a model....
    obs, vars       = ( 10, 200 )
    proxydata = randn(vars, 1, 1, obs)
    
    model = Chain(Conv( ( 5, 1 ),  1 => 16, relu),
                  Conv( ( 5, 1 ),  16 => 24, relu),
                  # stop_gradient,
                  flatten,
                  Dense( 4608, 32,  relu ),   # 4608 = 192 * 1 * 24 after the two Conv layers
                  Dense( 32, 32,  relu ),
                  Dense( 32, 1, relu ))

    println(model(proxydata))
    mseloss( x1, y1 ) = sum( model( x1 ) .- y1 )
    Flux.gradient(()->mseloss(proxydata, randn(1,10)), Flux.params(model))

end

This is a weird bug, as I've been using Zygote + flatten for a while now on the GPU and haven't run into this at all. It could be CPU-specific (which again is really odd). Something strange I found while minimizing this: it works with init=Flux.zeros for the conv layers (which you obviously can't use in practice), and it only rarely works with random initialization (glorot_uniform, glorot_normal, rand, randn, etc.). I don't think it has anything to do with reshaping, but I don't know Zygote well enough to be sure.
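
To see the sporadic behaviour concretely, here is a small diagnostic sketch (hypothetical, not from the thread) that rebuilds the minimized model a few times and records the element type of the gradient with respect to the Float64 input; on affected setups it may also trip the reported UndefRefError itself.

using Flux

flatten(x) = reshape(x, :, size(x, 4))

# Hypothetical diagnostic: rebuild the model several times with fresh random
# weights and record the eltype of the gradient w.r.t. the Float64 input data.
for trial in 1:5
    model = Chain(Conv((5, 1), 1 => 16, relu),
                  Conv((5, 1), 16 => 24, relu),
                  flatten,
                  Dense(4608, 32, relu),
                  Dense(32, 32, relu),
                  Dense(32, 1, relu))
    x = randn(200, 1, 1, 10)        # Float64 data against Float32 weights
    g = Flux.gradient(v -> sum(model(v)), x)[1]
    println("trial $trial: eltype of input gradient = ", eltype(g))
end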

@caseykneale
Author

Interesting that it seems associated with the CPU? Unfortunately, for my use case I must be able to backprop on the CPU.

Yes, the error is sporadic, which makes it even harder to debug. Maybe 1 in 20 times the code will actually run, making you believe you have solved the problem, only to be met with an UndefRefError on the second go-around.

Thank you for cleaning up the code. What I posted was a hot mess from ripping apart a somewhat large but lightweight architecture (I've been around the block) and presuming Flux 0.10 worked the way previous versions did. Once this is fixed I have another fundamental bug to report, but this one takes serious priority for me and likely many others.

@mcabbott
Member

mcabbott commented Feb 8, 2020

If you insert this function in place of stop_gradient, then you get typeof(dx) = Array{AbstractFloat,4} or Array{Float32,4} or Array{Float64,4} on different runs. Only with Float64 does it sometimes run without error; with Float32 the gradient is always zero, and it always errors.

using Zygote
show_gradient(x) = x
Zygote.@adjoint show_gradient(x) = x, dx -> ((@show extrema(dx) typeof(dx); dx),)

If you further change the input data to be proxydata = randn(Float32, vars, 1, 1, obs), then it is either typeof(dx) = Array{Any,4}, which always errors, or typeof(dx) = Array{Float32,4}, which always runs, although the gradient is then zero.

And with mseloss(proxydata, randn(Float32, 1,10)), it runs without error!

With hindsight this ought to have been obvious from the stacktrace, which contains a mixture of AbstractFloat, Float64, and Float32. Would be helpful if Flux could catch that, though.
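
For reference, this is roughly how that probe slots into the minimized model above (a sketch; the position is illustrative, and it assumes the flatten and show_gradient definitions from the comments above):

# Drop show_gradient between layers to see the eltype of the incoming gradient
# at that point in the backward pass.
model = Chain(Conv((5, 1), 1 => 16, relu),
              Conv((5, 1), 16 => 24, relu),
              show_gradient,   # prints extrema(dx) and typeof(dx) while differentiating
              flatten,
              Dense(4608, 32, relu),
              Dense(32, 32, relu),
              Dense(32, 1, relu))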

@caseykneale
Author

Okay, hang on... So Flux is changing the data from Float64 to Float32, and then when it compares them at the loss function it cannot do the conversion (because it would break the graph)? Oof, this reminds me of PyTorch-style juggling of types and tensors.

Why does the type conversion happen if there is a reshape, but nowhere else? So the solution may be to degrade the y values from double to single precision? I'll try this, because I can accept single precision for this application, but weird! Awesome debugging.
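
If single precision is acceptable, a minimal sketch of that conversion, using the names from the original post, is to cast both the inputs and the targets before training so that everything entering the loss matches Flux's default Float32 parameters:

# Sketch: convert the Float64 data to Float32 up front.
ex32   = Float32.(proxydata)
yeye32 = Float32.(yeye)

Flux.train!(mseloss, Flux.params(regr_model), [(ex32, yeye32)], SGD)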

@mcabbott
Member

mcabbott commented Feb 8, 2020

Everything in Flux is by default Float32, because GPUs like that. But sometimes things get promoted, e.g. FluxML/NNlib.jl#149, and there is some logic to auto-convert things, #815 (comment). If you ask me, the default should be to error any time an array of floats has a gradient of a different precision; perhaps Chain could enforce this.

It's extra weird that in your example it varies from case to case; that would be worth tracking down.
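
A minimal sketch of the kind of check suggested above (a hypothetical helper, not an existing Flux API): after computing gradients, compare each parameter's eltype with its gradient's eltype and warn on any mismatch.

using Flux

# Hypothetical helper: warn whenever a parameter's gradient comes back in a
# different floating-point precision than the parameter itself.
function check_gradient_eltypes(ps, gs)
    for p in ps
        g = gs[p]
        g === nothing && continue
        if eltype(g) != eltype(p)
            @warn "gradient precision differs from parameter precision" param_eltype=eltype(p) grad_eltype=eltype(g) param_size=size(p)
        end
    end
end

# Usage sketch:
# ps = Flux.params(model)
# gs = Flux.gradient(() -> loss(x, y), ps)
# check_gradient_eltypes(ps, gs)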

@caseykneale
Author

caseykneale commented Feb 8, 2020

Okay, I can confirm that changing both the X and Y to Float32s prevents the UndefRefError, but the gradient does not update. Any thoughts on why the gradient is all zeros?

I do agree that a helpful user message could prevent serious hardship for someone else.

@haampie

haampie commented Feb 11, 2020

I'm running into this reshape issue as well; I'm using a reshape in the loss function. A forward pass preserves the Float32s; the loss function takes Float32s and outputs a Float32 as well. The stacktrace also shows a bunch of Float64s and AbstractFloats.

For what it's worth, I think my issue happens just in the loss function where I compute a weighted crossentropy like this:

function weighted_crossentropy(ŷ::AbstractArray{T,5}, y::AbstractArray{T,5}) where {T}
    total_pixels = size(y, 1) * size(y, 2) * size(y, 3)

    # Give every class some weight
    α = 0.05f0

    # Rows are pixels, columns are classes
    classes_per_pixel_y = reshape(y, total_pixels, :)
    classes_per_pixel_ŷ = reshape(ŷ, total_pixels, :)

    # Compute weights relative to area (more area = smaller weight);
    # `area` is an Int defined outside this snippet
    weights = α .+ area .- sum(classes_per_pixel_y, dims = 1)
    weights_normalized = weights ./ sum(weights)

    # Finally compute the weighted cross entropy
    β = 1.0f-45
    return -sum(classes_per_pixel_y .* log.(classes_per_pixel_ŷ .+ β) .* weights_normalized) / area
end

If I add @mcabbott's show_gradient to the end of the Chain, the gradient is already an Array{Float64,5}.

@mcabbott
Member

For completeness, the answer here is:

julia> using Zygote, DiffRules

julia> 1f0 / 5  # division by an integer preserves Float32
0.2f0

julia> gradient(x -> x/5 * x, Float32(1))[1] isa Float64 # but not in a gradient
true

julia> DiffRules.diffrule(:Base, :/, :x, :y) # because of this rule
(:(inv(y)), :(-((x / y) / y)))

julia> 1f0 * inv(5) # and inv(Int) is Float64
0.2

Ref. JuliaDiff/DiffRules.jl#26 perhaps.
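
Until a fix lands, one workaround is to keep the divisor in the working precision so inv is never called on an Int; a minimal sketch:

julia> using Zygote

julia> gradient(x -> x / 5f0 * x, 1f0)[1]  # inv(5f0) is Float32, so no promotion
0.4f0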

@haampie

haampie commented Feb 11, 2020

Yes, I can confirm: the issue in my code is not reshape (that's a red herring); it is the division by area that causes the promotion to Float64 in the gradient.
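
A sketch of a fix under that diagnosis, with `area` passed in explicitly as an assumption (in the original code it is defined elsewhere): convert it to the element type T so the final division, and the gradient rule behind it, stays in the working precision.

# Sketch only: `area` is assumed to be an Int pixel count; converting it to T
# prevents the division from promoting the gradient to Float64.
function weighted_crossentropy(ŷ::AbstractArray{T,5}, y::AbstractArray{T,5}, area::Integer) where {T}
    total_pixels = size(y, 1) * size(y, 2) * size(y, 3)
    α = T(0.05)
    a = T(area)

    classes_per_pixel_y = reshape(y, total_pixels, :)
    classes_per_pixel_ŷ = reshape(ŷ, total_pixels, :)

    weights = α .+ a .- sum(classes_per_pixel_y, dims = 1)
    weights_normalized = weights ./ sum(weights)

    β = T(1e-45)
    return -sum(classes_per_pixel_y .* log.(classes_per_pixel_ŷ .+ β) .* weights_normalized) / a
end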

caseykneale changed the title from "Conv -> Reshape -> Dense broken..." to "Type Promotion often Unwieldy and day Ruining" on Feb 11, 2020
@caseykneale
Author

Thank you for getting further along with this. It's not just reshape; it appears to be type promotion in "unexpected" places, such as .^ 2 rather than abs2.(), and other weird spots.
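
A quick way to check a suspect broadcast (a sketch; it makes no claim about which way the result comes out) is to compare the gradient eltypes of the two formulations directly:

using Zygote

x = Float32[1, 2, 3]

# If either of these returns Float64, that formulation is promoting somewhere
# in its gradient rule.
eltype(gradient(v -> sum(v .^ 2), x)[1])
eltype(gradient(v -> sum(abs2.(v)), x)[1])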

@mcabbott
Member

Great! With that PR, dividing by an integer seems safe again:

julia> DiffRules.diffrule(:Base, :/, :x, :y)
(:(one(x) / y), :(-((x / y) / y)))

julia> gradient(x -> x/5 * x, Float32(1))[1]
0.4f0

@caseykneale
Author

Awesome work @haampie! @mcabbott, do you feel this issue has made a large enough stink that it can be closed now? I feel that with division-by-integer safety back in place, most of the other issues will naturally resolve.

I think exponentiation (elementwise) also had type instability, but am not in a position to check this right now.

@mcabbott
Member

Your call, but I'd close it. I don't see obvious problems with exp.(x), but if you find some, do open another issue:

julia> gradient(x -> sum(identity, exp.(x ./ 2)), [1f-1, 2f-2])
(Float32[0.52563554, 0.5050251],)
