Add @autosize
#2078
@@ -142,16 +142,81 @@ end

```julia
m = LayerNorm(32)
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)
m2 = LayerNorm(3, 2)
@test outputsize(m2, (3, 2)) == (3, 2) == size(m2(randn(3, 2)))
@test outputsize(m2, (3,)) == (3, 2) == size(m2(randn(3, 2)))

m = BatchNorm(3)
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)
@test_throws Exception m(randn(Float32, 32, 32, 5, 1))
@test_throws DimensionMismatch outputsize(m, (32, 32, 5, 1))

m = InstanceNorm(3)
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)
@test_throws Exception m(randn(Float32, 32, 32, 5, 1))
@test_throws DimensionMismatch outputsize(m, (32, 32, 5, 1))

m = GroupNorm(16, 4)
@test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
@test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
@test_throws Exception m(randn(Float32, 32, 32, 15, 4))
@test_throws DimensionMismatch outputsize(m, (32, 32, 15, 4))
end
```
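For context on what these assertions exercise: `outputsize` propagates an input shape through a model without materialising real data, so the shape checks above run cheaply. A minimal sketch, using an invented model that is not part of this diff:

```julia
using Flux

# Shape inference without real computation: the Conv turns (32, 32, 3) into
# (30, 30, 16), flatten gives 30*30*16 = 14400 features, and padbatch=true
# appends a trailing batch dimension of 1.
m = Chain(Conv((3, 3), 3 => 16), Flux.flatten, Dense(14400 => 10))
Flux.outputsize(m, (32, 32, 3); padbatch=true)  # == (10, 1)
```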
```julia
@testset "autosize macro" begin
m = @autosize (3,) Dense(_ => 4)
@test randn(3) |> m |> size == (4,)

m = @autosize (3, 1) Chain(Dense(_ => 4), Dense(4 => 10), softmax)
@test randn(3, 5) |> m |> size == (10, 5)

m = @autosize (2, 3, 4, 5) Dense(_ => 10)  # goes by first dim, not 2nd-last
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)

m = @autosize (9,) Dense(_ => div(_, 2))
@test randn(9) |> m |> size == (4,)

m = @autosize (3,) Chain(one = Dense(_ => 4), two = softmax)  # needs kw
@test randn(3) |> m |> size == (4,)

m = @autosize (3, 45) Maxout(() -> Dense(_ => 6, tanh), 2)  # needs ->, block
@test randn(3, 45) |> m |> size == (6, 45)

# here Parallel gets two inputs, no problem:
m = @autosize (3,) Chain(SkipConnection(Dense(_ => 4), Parallel(vcat, Dense(_ => 5), Dense(_ => 6))), Flux.Scale(_))
@test randn(3) |> m |> size == (11,)

# like Dense, LayerNorm goes by the first dimension:
m = @autosize (3, 4, 5) LayerNorm(_)
@test rand(3, 6, 7) |> m |> size == (3, 6, 7)

m = @autosize (3, 3, 10) LayerNorm(_, _)  # does not check that sizes match
@test rand(3, 3, 10) |> m |> size == (3, 3, 10)

m = @autosize (3,) Flux.Bilinear(_ => 10)
@test randn(3) |> m |> size == (10,)

m = @autosize (3, 1) Flux.Bilinear(_ => 10)
@test randn(3, 4) |> m |> size == (10, 4)

@test_throws Exception @eval @autosize (3,) Flux.Bilinear((_, 3) => 10)

# first docstring example
m = @autosize (3, 1) Chain(Dense(_ => 2, sigmoid), BatchNorm(_, affine=false))
@test randn(3, 4) |> m |> size == (2, 4)

# evil docstring example
img = [28, 28];
m = @autosize (img..., 1, 32) Chain(  # size is only needed at runtime
    Chain(c = Conv((3,3), _ => 5; stride=2, pad=SamePad()),
          p = MeanPool((3,3)),
          b = BatchNorm(_),
          f = Flux.flatten),
    Dense(_ => _÷4, relu, init=Flux.rand32),  # can calculate output size _÷4
    SkipConnection(Dense(_ => _, relu), +),
    Dense(_ => 10),
) |> gpu  # moves to GPU after initialisation
```
Review thread on the `|> gpu` line:

@mcabbott I missed this in the review, but GPU tests are failing because the …

Oh no I thought I checked this, sorry. It's not working but the binding is what I expected: …

It seems the …

My scheme was I think aimed at … So we should just make it an error. And remove this use from the docs. Call …

Sounds good to me.
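A sketch of the safer pattern this thread points towards: keep `|> gpu` outside the expression given to `@autosize`, so all `_` placeholders are materialised before the model is moved. The model below is illustrative, not the docstring example:

```julia
using Flux

img = [28, 28]
# Materialise all `_` sizes on the CPU first...
m = @autosize (img..., 1, 32) Chain(
    Conv((3, 3), _ => 5; stride=2, pad=SamePad()),
    Flux.flatten,
    Dense(_ => 10),
)
m = gpu(m)  # ...then move to GPU as a separate step
```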
```julia
@test randn(Float32, img..., 1, 32) |> gpu |> m |> size == (10, 32)
end
```
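The thread below debates whether lazily constructed layers may survive inside a model. A hypothetical mock of the idea, to make the trade-off concrete (the field names and call behaviour are assumptions, not Flux's actual `LazyLayer`):

```julia
using Flux

# Mock only: builds its wrapped layer on first call, then reuses it.
mutable struct MockLazyLayer
    make::Function   # input -> layer, e.g. x -> Dense(size(x, 1) => 10)
    layer::Any       # `nothing` until the first call
end

function (l::MockLazyLayer)(x)
    if l.layer === nothing   # the `if` & type instability discussed below
        l.layer = l.make(x)
    end
    return l.layer(x)
end

m = MockLazyLayer(x -> Dense(size(x, 1) => 10), nothing)
size(m(randn(Float32, 3, 5)))  # (10, 5); the Dense(3 => 10) is built here
```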
Review thread:

Could we force users to call `recursive_striplazy(model, input_size)` or something before using an incrementally constructed network like this? Maybe define an `rrule` which throws an error?

`striplazy` should be fully recursive. We could make a function that calls this after `outputsize` & returns the model. And indeed an `rrule` would be one way to ensure the model is stripped before you use it for real.

I suppose the other policy would just be to allow these things to survive in the model. As long as you never change it, and don't care about the cost of the `if` & type instability, it should work?

But any use outside of `@autosize` probably needs another macro... writing `Flux.LazyLayer("", x -> Dense(size(x,1) => 10), nothing)` seems sufficiently obscure that perhaps it's OK to say that's obviously at one's own risk, for now? `@autosize` can be the only API until we decide if we want more.
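One way the error-throwing `rrule` suggestion could look, as a hedged sketch: it assumes the internal `Flux.LazyLayer` is callable with a single input, which is not confirmed here.

```julia
using ChainRulesCore, Flux

# Sketch: make any gradient call through an un-stripped lazy layer fail
# loudly, instead of silently differentiating the mutable wrapper.
function ChainRulesCore.rrule(l::Flux.LazyLayer, x)
    error("LazyLayer was never materialised; run the model through " *
          "`Flux.outputsize` / `@autosize` before taking gradients.")
end
```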