-
-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use the cache less often #39
Conversation
I thought about only considering leaves for caching, but one hiccup is when a mutable struct is used as a shared non-leaf node. Think Flux.BatchNorm. We could say that non-leaf nodes will always be untied, but that needs to be a) decided and b) documented with a prominent warning banner. |
It seems that by making your type non-leaf, you have opted in to having Functors traverse into it. What's gained by not doing so? Can you clarify what problem traversing it will cause? My Since |
Traversing it won't cause any issues, the question is how it should be re-assembled post traversal. Specifically if you run into the same mutable non-leaf multiple times, whether it should be untied as part of the reassembly process (vs fetched from the cache on subsequent occurrences). The question is whether we're ok with breaking the following behaviour: given
My thought for this was to lean on function usecache(x::AbstractArray)
p = parent(x)
p !== x && return usecache(p) # alt. typeof(p) !== typeof(x), etc. if we wanted type stability or to avoid potentially expensive `===` methods.
return ismutable(x) # fallback
end
usecache(x::Array) = true
... |
I guess " any nodes in x that are === should also be so in x′." is a simple-to-explain policy. Would be nice if the policy for how often It seems a bit fragile to depend on the right parent methods existing (in one case) and not existing (in the other), since nothing requires them. |
What wrapper types don't expose a parent method? Worst case they are considered not cacheable due to the fallback in https://github.com/JuliaLang/julia/blob/master/base/abstractarray.jl#L1398.
Aye, this is much easier in a purer functional language where you're only traversing over trees. Sometimes I'm tempted to try representing the object graph as an actual digraph for this reason, but that's a little far off the deep end :P |
fe0fe80
to
a1050b2
Compare
Now updated with a different rule:
Weird cases:
|
src/functor.jl
Outdated
# function _anymutable(x::T) where {T} | ||
# ismutable(x) && return true | ||
# fs = fieldnames(T) | ||
# isempty(fs) && return false | ||
# return any(f -> anymutable(getfield(x, f)), fs) | ||
# end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this fail to constant fold sometimes?
Otherwise LGTM. About the weird cases, we could argue it's more conservative to not cache in both. A false positive seems much worse than a false negative here IMO. Asking uses of a higher-level isleaf
to take on additional responsibility for caching is also fine. Incidentally, this is why I think extracting out caching from Functors and making callbacks handle memoization themselves would be nice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It fails to be instant on surprisingly simple functions, I didn't try to dig into why:
julia> @btime fun_anymutable((x=(1,2), y=3))
min 36.500 ns, mean 38.715 ns (1 allocation, 32 bytes)
false
julia> @btime gen_anymutable((x=(1,2), y=3))
min 0.001 ns, mean 0.014 ns (0 allocations)
false
Perhaps more surprisingly, the generated one is also not free e.g. here:
julia> @btime fun_anymutable($(Metalhead.ResNet()))
min 275.685 ns, mean 323.217 ns (9 allocations, 320 bytes)
true
julia> @btime gen_anymutable($(Metalhead.ResNet()))
min 147.536 ns, mean 161.010 ns (1 allocation, 32 bytes)
true
That contains Chain([...])
which... should just stop the recursion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Smaller example, in which the number of layers seems to matter:
julia> model = Chain(
Conv((3, 3), 1 => 16), # 160 parameters
Conv((3, 3), 16 => 16), # 2_320 parameters
Conv((3, 3), 16 => 32), # 4_640 parameters
Conv((3, 3), 32 => 32), # 9_248 parameters
Conv((3, 3), 32 => 64), # 18_496 parameters
Conv((3, 3), 64 => 64), # 36_928 parameters
Dense(16384 => 10), # 163_850 parameters
);
julia> @btime fun_anymutable($model)
min 327.851 ns, mean 448.404 ns (10 allocations, 3.17 KiB)
true
julia> @btime gen_anymutable($model)
min 215.463 ns, mean 238.700 ns (8 allocations, 608 bytes)
true
julia> model = Chain(
Conv((3, 3), 1 => 16), # 160 parameters
Conv((3, 3), 16 => 16), # 2_320 parameters
# Conv((3, 3), 16 => 32), # 4_640 parameters
# Conv((3, 3), 32 => 32), # 9_248 parameters
# Conv((3, 3), 32 => 64), # 18_496 parameters
Conv((3, 3), 64 => 64), # 36_928 parameters
Dense(16384 => 10), # 163_850 parameters
);
julia> @btime fun_anymutable($model)
min 344.818 ns, mean 391.967 ns (10 allocations, 1.75 KiB)
true
julia> @btime gen_anymutable($model)
min 0.001 ns, mean 0.014 ns (0 allocations)
true
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the Metalhead example at least, that one allocation is coming from https://github.com/FluxML/Metalhead.jl/blob/7827ca6ec4ef7c5e07d04cd6d84a1a3b11289dc0/src/convnets/resnets/resnet.jl#L17. For the longer Chain, Cthulhu tells me
┌ Info: Inference didn't cache this call information because of imprecise analysis due to recursion:
└ Cthulhu nevertheless is trying to descend into it for further inspection.
If I add a guard against the possible missing
from any
in gen_anymutable
and assert the return value like so:
@generated function gen_anymutable(x::T) where {T}
ismutabletype(T) && return true
fs = fieldnames(T)
isempty(fs) && return false
subs = [:(gen_anymutable(getfield(x, $f))) for f in QuoteNode.(fs)]
return :(coalesce(|($(subs...)), false)::Bool)
end
That eliminates all but 6 of the allocations. I believe these correspond to the 6 Conv layers because the check on the Dense layer appears to be fully const folded (why only the Dense? Not sure).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh nice. Just the ::Bool
seems to be enough, and should be safe I think.
Weirdly it is instant for 5 and 7 conv layers, only exactly 6 causes it to fail & take 100ns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is absolutely bizarre. It also works for 6 Conv layers if I remove the final Dense and up to at least 32 with/without. Granted, this is on nightly—I couldn't get close to your timings on 1.8.2 IIRC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, less bizarre. Trying to simplify things a bit, it looks like the first call is always taking a hit here, but subsequent calls are fine.
using BenchmarkTools
struct Conv{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
groups::Int
end
struct Dense{F, M, B}
weight::M
bias::B
σ::F
end
@generated function anymutable(x::T) where {T}
ismutabletype(T) && return true
fs = fieldnames(T)
isempty(fs) && return false
subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fs)]
return :(|($(subs...))::Bool)
end
function test()
for N in (5, 6, 7)
@info N
layers = ntuple(_ -> Conv(identity, ones(1), ones(1), (1,), (1,), (1,), 1), N)
layers = (layers..., Dense(ones(1), ones(1), identity))
@btime anymutable($layers)
end
end
test()
Perhaps that has something to do with the generated function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's weird.
I don't suggest doing this, but this does seem to compile away:
julia> Base.@assume_effects :total function fun_anymutable3(x::T) where {T}
ismutable(x) && return true
fs = fieldnames(T)
isempty(fs) && return false
return any(f -> fun_anymutable3(getfield(x, f)), fs)::Bool
end
fun_anymutable3 (generic function with 1 method)
julia> function test_3()
for N in (5, 6, 7)
@info N
layers = ntuple(_ -> Conv(identity, ones(1), ones(1), (1,), (1,), (1,), 1), N)
layers = (layers..., Dense(ones(1), ones(1), identity))
@btime fun_anymutable3($layers)
end
end
test_3 (generic function with 1 method)
julia> test_3()
[ Info: 5
min 0.083 ns, mean 0.185 ns (0 allocations)
[ Info: 6
min 0.083 ns, mean 0.208 ns (0 allocations)
[ Info: 7
min 0.083 ns, mean 0.229 ns (0 allocations)
julia> VERSION
v"1.9.0-DEV.1528"
(Edit -- inserted results)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, that still allocates (more, as a matter of fact) for me on nightly and 1.8. At least performance is consistent though.
[ Info: 5
176.931 ns (7 allocations: 1.05 KiB)
[ Info: 6
177.480 ns (7 allocations: 1.23 KiB)
[ Info: 7
188.679 ns (7 allocations: 1.34 KiB)
julia> VERSION
v"1.9.0-DEV.1547"
Am going to merge this so that master has the new behaviour, but won't rush to tag it. The code here will be changed by #43, but the tests may (I think) survive. |
One of the things #32 proposes is to disable the use of the cache for some types, so that e.g. the number
4
appearing at two nodes is not regarded as a form of parameter sharing, just a co-incidence. This PR wants to make that change alone.But what exactly is the right rule here? If
(x = [1,2,3], y = 4)
appears twice, then the shared[1,2,3]
should be cached, but I think the4
should still not be.This PR thinks that the right test is to only use cache on leaf nodes. #32 tested instead
!isbits(x) && ismutable(x)
, which will also work on these examples. Where they differ is on an immutable container enclosing a mutable array:Right now this uses the
exclude
keyword not the fixedisleaf
. I think that makes sense but haven't thought too hard.