
Silence some warnings #383

Merged (2 commits) on Feb 22, 2022

Conversation

@mcabbott (Member) commented on Feb 5, 2022:

This avoids:

┌ Warning: Slow fallback implementation invoked for conv!  You probably don't want this; check your datatypes.
│   yT = ForwardDiff.Dual{Nothing, Float32, 8}
│   T1 = ForwardDiff.Dual{Nothing, Float32, 8}
│   T2 = Float32
└ @ NNlib ~/.julia/packages/NNlib/JQe1Z/src/conv.jl:288

You will still get the warning if you mix Float32 & Float64, but not for dual numbers.

Xref #349. I'm not sure that's closed, as they still take the slow path. But it's not by mistake.
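
Roughly, the idea is something like the following sketch (not the literal diff; the function name here is just for illustration):

```julia
# Hypothetical sketch of the gating idea, not NNlib's actual source.
# yT, T1, T2 are the output and input eltypes seen by the fallback conv!.
function maybe_warn_slow_conv(yT::Type, T1::Type, T2::Type)
    if yT <: AbstractFloat && !(T1 == T2 == yT)
        # A Float32/Float64 mix still warns; Dual numbers and other
        # non-AbstractFloat eltypes are assumed to be deliberate and stay quiet.
        @warn "Slow fallback implementation invoked for conv!  You probably don't want this; check your datatypes." yT T1 T2
    end
end
```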

@ToucheSir (Member):

I wonder if we should have an info message instead for non-AbstractFloat outputs. WDYT?

@mcabbott (Member, Author):

I guess my take here is that these warnings are there in case you messed up your float eltypes, and things will become slow by mistake. If you call ForwardDiff.gradient on something, then you really did want to put dual numbers in there. It will be slow, but so will e.g. matrix multiplication. It's not a mistake though, and so I think a noisy warning is unhelpful.
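
For example, something like this (a sketch, not code from this PR) is a deliberate Dual-valued call that currently triggers the warning:

```julia
using ForwardDiff, NNlib

x = rand(Float32, 5, 1, 1)   # (width, channels, batch) for a 1-D conv
w = rand(Float32, 3, 1, 1)   # (kernel width, channels in, channels out)

# Deliberately pushes Dual numbers through conv: slow, but intentional,
# so the "slow fallback" warning is just noise here.
g = ForwardDiff.gradient(w -> sum(conv(x, w)), w)
```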

@ToucheSir (Member):

I guess what I'm looking for is a heads-up for users doing nested AD so they're not blindsided by models being much slower or flat out not working on GPU. This would be especially relevant for those using higher level functions like Zygote.hessian_* or AbstractDiff's API where one might be unaware of ForwardDiff's presence. Currently a rather obscure and (seemingly) unrelated error is thrown from deep within the bowels of NNlib.
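
Roughly the kind of thing I mean (a sketch; the shapes and closure here are just illustrative):

```julia
using Zygote, NNlib

x = rand(Float32, 5, 1, 1)
w = rand(Float32, 3, 1, 1)

# Forward-over-reverse: Zygote.hessian seeds ForwardDiff.Dual numbers through
# the reverse pass, so conv! sees Dual eltypes even though the user never
# loaded ForwardDiff themselves, and things can get much slower (or fail on GPU).
H = Zygote.hessian(v -> sum(abs2, conv(x, reshape(v, size(w)))), vec(w))
```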

@darsnack (Member):

Other comment: this should check for subtypes of G, which is what actually determines whether the fast or the slow path is invoked.

@mcabbott (Member, Author):

The PR's take is that the warning shouldn't be for all slow paths, only for slow paths which are clearly a mistake.

I agree there's an argument for telling people that Dual numbers are slow. But they are slow in many things, e.g. matmul, and may still be what you want. Nested AD is a bit of a performance minefield and you may want to time different approaches; if Dual happens to be best for you, then I don't think we should make it slower and noisier.

I'm not sure I follow re G. What case will this warning not catch which should be caught, or the reverse?

@darsnack (Member) commented on Feb 18, 2022:

I meant that the check should be all(T -> T <: G, (yT, T1, T2)) && !(T1 == T2 == yT) which is the actual slow case: you have types that could be fast if only they matched. Any other case is "slow" but there isn't a faster option. Currently, the PR implies that any dispatch where the output subtypes AbstractFloat could be faster.
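
Spelled out as a sketch (the exact definition of G is in src/gemm.jl; here I'm assuming it covers roughly the BLAS element types):

```julia
using ForwardDiff

# Assumption: G ≈ the GEMM-supported element types from src/gemm.jl.
const G = Union{Float32, Float64, ComplexF32, ComplexF64}

should_warn(yT, T1, T2) = all(T -> T <: G, (yT, T1, T2)) && !(T1 == T2 == yT)

should_warn(Float64, Float32, Float32)  # true: GEMM would apply if the types matched
should_warn(ForwardDiff.Dual{Nothing,Float32,8},
            ForwardDiff.Dual{Nothing,Float32,8}, Float32)  # false: no faster path exists
```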

@mcabbott (Member, Author):

Ok. But can this come later?

This PR wants to disable some warnings, on the assumption that there are use cases for non-floats which will take the slow path (as they will for matmul!), and if you want that, it's rude for the package to shout at you. The motivating example is that Flux's tests push Dual numbers through here.

You could disable some more with a more complicated rule, but perhaps that can be a follow-up PR? I'm not sure what cases this logic is trying to catch. If I have a mix of floats, should the warning depend on whether one or all of them is individually acceptable? Does that make it more likely to be a mistake? If I'm doing exotic floats like BigFloat, then perhaps no warnings is ideal? But maybe we can wait until someone wants that to argue about the precise rule? This PR makes no changes there.

@darsnack (Member) left a review comment:

FWIW, I don't feel strongly enough about this to block the PR, but I've left a suggestion to make it more concrete.

> This PR makes no changes there.

That's my point: NNlib already has a precise rule, and this PR doesn't follow that definition. We have a for-loop based convolution and a GEMM-based one. The latter is called when the destination and source types match and they are all subtypes of G, which is the union of supported GEMM types (defined in src/gemm.jl). I would think that the only time the "You probably don't want this" warning applies is if you almost passed the GEMM check but had mismatched types. So, if one of your arguments is BigFloat, then the current version of the PR prints a warning that's misleading: with BigFloat there is no way you could be faster.
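
Using the should_warn sketch from above (still assuming G covers only the GEMM types), the two cases come out differently:

```julia
should_warn(BigFloat, BigFloat, Float32)  # false: no GEMM path exists for BigFloat anyway
should_warn(Float64, Float64, Float32)    # true: matching the types would enable GEMM
```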

Suggested change on src/conv.jl (resolved).
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>