generic_matmul! hit in back! because type-promotion in activation function #613

oxinabox opened this issue Feb 10, 2019 · 17 comments

@oxinabox
Member

Sometimes generic_matmul! is hit in back!
For example, adding a leaky ReLU6 unit can be done
by writing an activation function like

    leaky_relu6(x) = 0.01x + clamp(x, 0, 6)

And this is all well and good if x is a Float64.
But if x is a Float32 this will trigger a type-promotion.
Which is bad, because the user almost certainly did not intend the type promotion.
But worse,
it means rather than hitting fast BLAS, we fall back to slow generic_matmul!.
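
For illustration (added here, not part of the original report), here is the promotion in isolation; with a Float32 input the activation returns a Float64, which then forces the matmul onto the generic path:

julia> leaky_relu6(x) = 0.01x + clamp(x, 0, 6);

julia> typeof(leaky_relu6(1f0))   # Float32 in, Float64 out, because the literal 0.01 is a Float64
Float64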

Here is a MWE:

using Flux

function flux_model()
    return Chain(
        Dense(1280, 64, x->0.1f0x),    # Float32 literal: stays in Float32
#        Dense(1280, 64, x->0.1e0x),   # swap in this line (Float64 literal) to trigger the promotion
        Dense(64, 1),
    )
end

function demo_flux()
    mdl = flux_model()
    features = rand(Float32, (1280, 1000))

    Flux.train!(
        params(mdl),
        [(features,)],
        Flux.ADAM()
    ) do xs
        sum(mdl(xs))
    end
end

Time if it has to promote: @time demo_flux()

0.143774 seconds (607 allocations: 19.635 MiB)

Time normally: @time demo_flux()

0.016475 seconds (568 allocations: 13.218 MiB, 47.67% gc time)

That is a 10x time difference, and it grows as your matrix sizes scale up.

@KristofferC
Contributor

Isn't this expected though? Why aren't you using the eltype of the input to determine the type of the float constants?
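
For reference (a sketch, not part of the original comment), the activation can be written so the constant follows the input's type, e.g. with oftype:

leaky_relu6(x) = oftype(x / 1, 0.01) * x + clamp(x, 0, 6)

leaky_relu6(1f0)  # stays Float32
leaky_relu6(1.0)  # stays Float64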

@oxinabox
Member Author

Depends what you mean by "expected".
It follows the rules of the language, yes.
But the fact that a mistake like this can trash your performance is not great.

As a rule, when looking for code that might be causing a slowdown, one doesn't immediately go looking for constants whose types don't match.
This is fully type stable, it doesn't allocate etc etc.
Even looking at the profiler output it took @staticfloat and I a while to work this out.

It is of course obvious in retrospect,
but not from inspection of the code base that this is causing an order of magnitude slowdown.

We could certainly consider giving a suppressible warning if the return type of the activation does not match its inputs.
(Or even just if it switched between floating point types).

Or we could do other things.

@oxinabox
Member Author

If nothing else we can have a "performance tips" page, saying to be careful of the types of literals in your activation functions.

I also think we can probably make this faster than it is. If nothing else we can promote the matrix and use BLAS rather than the generic matmul.
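
A rough sketch of that idea (promoting_mul is an assumed name, not existing Flux code): convert both operands to a shared element type up front, which costs one extra copy, so the multiply itself dispatches to BLAS instead of the generic fallback.

function promoting_mul(A::AbstractMatrix, B::AbstractMatrix)
    T = promote_type(eltype(A), eltype(B))
    # one O(n^2) copy, but the O(n^3) multiply now hits BLAS
    return convert(AbstractMatrix{T}, A) * convert(AbstractMatrix{T}, B)
end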

@staticfloat
Contributor

I think what we should likely do is trigger a warning on our fallback operations; by default I don’t think any user ever wants to use the generic matmul, so while I like that we support using it, we should have a (silenceable) warning that is emitted the first time it is invoked, along with a backtrace to figure out where it’s coming from.

@MikeInnes
Member

At the risk of flagrant type piracy, we could just override the behaviour of x::Array{T} * y::Array{S} from Base.

It would be worth some notes in the activation functions section of the docs though. NNlib's ones are all set up to preserve input type and there's testing infrastructure for this as well; it's really a matter of following standard Julia style.

@darsnack
Member

Given that #615 added this to the docs, do we still want to address this with a warning somehow?

@oxinabox
Member Author

Yeah, I think we should, it can wreck your performance.

@darsnack
Member

I believe changing the types to hit BLAS makes things troublesome for mixed precision. I'm not an expert on the topic, but I've heard that mentioned quite a few times on orthogonal issues/PRs.

A warning would be good though. Only concern is the type-piracy. Either way I added this to the triage project so it gets discussed during the next ML community call.

@KristofferC
Contributor

Please no type piracy.

@oxinabox
Member Author

You don't need type piracy.
You put the warning in Dense (or a helper called by Dense etc, maybe even shadowing *) that checks the types before it calls Base.:*.
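
For concreteness, a sketch only (using the Dense field names W, b, σ from that era of Flux; this is not the actual implementation):

function (a::Dense)(x::AbstractVecOrMat)
    if eltype(a.W) !== eltype(x)
        @warn "Dense layer with $(eltype(a.W)) weights got $(eltype(x)) input; the multiply will not hit BLAS." maxlog=1
    end
    return a.σ.(a.W * x .+ a.b)
end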

@DhairyaLGandhi
Member

Doing these things generically, in a manner that doesn't hurt AD or runtime performance in the forward or backward pass, can be tough.

@oxinabox
Member Author

You can have a disable-able safety rails mode that compiles away.

# Safety rails default on
has_safety_rails() = true

# Function to let advanced users turn them off.
# Triggers recompilation.
safety_rails!(enable) = @eval has_safety_rails() = $enable


macro safety_rail(cond, msg::String, logargs...)
    disable = " To disable this warning, run `Flux.safety_rails!(false)`."
    # TODO this doesn't quite display logargs right.
    warning = :(@warn($msg * $disable, $(esc.(logargs)...)))
    return quote
        if has_safety_rails()
            # Keep the check out of the AD graph.
            Zygote.ignore() do
                $(esc(cond)) && $warning
            end
        end
    end
end


# Inside Flux this would shadow Base.:* for the module (a new function, not type piracy).
function *(a::T, b::S) where {T, S}
    @safety_rail(
        T !== S,
        "Mixed type multiplication encountered. This probably means you ...",
        T, S
    )

    return Base.:(*)(a, b)
end

1.0 * 2   # demo: mixed types trigger the warning

I think I stole this trick from TimerOutputs.jl, to have a debug mode that compiles away when not in use.
ChainRulesCore uses it.

@darsnack
Member

Summarizing what was discussed during triage today:

The appropriate place for a @safety_rail style check is probably NNlib instead of Flux (if we want it). People were generally uncomfortable with shadowing *, and it was suggested that if generic_matmul! is almost never wanted, then maybe the warning mechanism for it belongs in Base. Admittedly, the penalty is much greater when used in AD like Zygote, but Flux seems like the wrong place in the hierarchy to address this issue generically.

What was suggested instead is to package the promotion check into a utility function: something like performance_check (it could work on a similar mechanism to outputsize) that runs a forward and backward pass on the model and emits warnings for any performance issues. It could include a forward pointer to the performance tips docs in the warning string. I don't think there is anything like this w.r.t. performance, but other frameworks do have such utilities as sanity checks.
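
A minimal sketch of what such a utility could look like (performance_check is hypothetical, not an existing Flux function; Zygote is assumed for the backward pass):

using Flux, Zygote

function performance_check(m, x)
    y = m(x)
    if eltype(y) !== eltype(x)
        msg = "Model promotes $(eltype(x)) input to $(eltype(y)) output; " *
              "check for Float64 literals in activation functions (see the performance tips docs)."
        @warn msg
    end
    Zygote.gradient(x -> sum(m(x)), x)  # exercise the backward pass as well
    return nothing
end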

@KristofferC
Contributor

and it was suggested that if generic_matmul! is almost never wanted, then maybe the warning mechanism for it belongs in Base

It is wanted in Base all the time though so I think you will have a hard time putting such a warning there.

@darsnack
Member

Would a warning via a utility function be acceptable then @oxinabox?

@oxinabox
Member Author

I honestly don't care how it is done.
The point is to protect people who are first learning Julia and first learning Flux from footguns.
People who haven't yet read all the docs.

We should understand the context.

The thing that matters here is that it is very easy to get Float64s in your network.
Because floating point literals are Float64, and float(::Int) returns a Float64.
So if you are not careful, you can end up with one coming out of a helper function (e.g. I had to fix a few things in Distributions.jl for this recently), or even by just making a mistake and using a literal yourself.
Now, in normal Julia code, promoting to Float64 is fine.
It isn't much slower, it only uses 2x as much memory, and it is more accurate.
It is the safe and good bet.
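
(A quick illustration of those defaults, added here for reference:)

julia> typeof(0.01)       # floating point literals are Float64
Float64

julia> typeof(float(1))   # and so is float(::Int)
Float64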

But in NN code you often intentionally want Float32, because the lower precision costs you nothing and is even said to act as a regularizer.
Further, the whole process of training an NN boils down to a loop of matmuls.
We need to hit the fast matmul.

One day we might be able to do |> f32 like we do |> gpu.
That would also solve this.

@mcabbott
Member

mcabbott commented Sep 6, 2021

Times with FluxML/Zygote.jl#1044: now hardly any slowdown, but a few more allocations than the all-Float32 version:

julia> function flux_model()  # Float32
           return Chain(
               Dense(1280, 64, x->0.1f0x),    # Uncomment one of these lines
       #        Dense(1280, 64, x->0.1e0x),   # Uncomment one of these lines
               Dense(64, 1),
           )
       end
flux_model (generic function with 1 method)

julia> @time demo_flux()
  0.571037 seconds (1.88 M allocations: 110.176 MiB, 3.09% gc time, 0.97% compilation time)

julia> @time demo_flux()
  0.011878 seconds (388 allocations: 13.074 MiB)
  0.007133 seconds (388 allocations: 13.074 MiB)  # another run
  
julia> function flux_model()  # Float64
           return Chain(
       #        Dense(1280, 64, x->0.1f0x),    # Uncomment one of these lines
               Dense(1280, 64, x->0.1e0x),   # Uncomment one of these lines
               Dense(64, 1),
           )
       end
flux_model (generic function with 1 method)

julia> @time demo_flux()
  0.583360 seconds (1.91 M allocations: 113.019 MiB, 3.05% gc time, 0.88% compilation time)

julia> @time demo_flux()
  0.010863 seconds (389 allocations: 14.543 MiB)
  0.011858 seconds (389 allocations: 14.543 MiB)  # another run

Compared to tagged version without that PR, just the slow case -- 10x slower than Float32, as above:

julia> @time demo_flux()  # Float64, first run after re-defining demo_flux()
  0.655448 seconds (1.91 M allocations: 117.824 MiB, 2.61% gc time, 0.79% compilation time)

julia> @time demo_flux()
  0.097526 seconds (388 allocations: 19.495 MiB, 17.16% gc time)
  0.107293 seconds (388 allocations: 19.495 MiB)  # another run

(@v1.7) pkg> st Zygote
      Status `~/.julia/environments/v1.7/Project.toml`
  [e88e6eb3] Zygote v0.6.20

This version without the PR has some ProjectTo stuff in place, e.g. in the rule for *, but not in broadcasting, so it catches the problem a little later.
