Skip to content

Commit

Permalink
add tests, release note
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 5, 2022
1 parent 46e06c7 commit 310b71e
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 10 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Flux Release Notes

## v0.13.7
* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)

Expand Down
10 changes: 0 additions & 10 deletions src/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,16 +261,6 @@ function _underscoredepth(ex::Expr)
end
# Fallback for non-Expr input: depth is 1 for a bare underscore placeholder, 0 otherwise.
_underscoredepth(ex) = ex == :_ ? 1 : 0

#=
@autosize (3,) Chain(one = Dense(_ => 4), two = softmax) # needs kw
@autosize (3, 45) Maxout(() -> Dense(_ => 6, tanh), 2) # needs ->, block
# here Parallel gets two inputs, no problem:
@autosize (3,) Chain(SkipConnection(Dense(_ => 4), Parallel(vcat, Dense(_ => 5), Dense(_ => 6))), Flux.Scale(_))
=#

function _makefun(ex)
T = Meta.isexpr(ex, :call) ? ex.args[1] : Type
@gensym x s
Expand Down
65 changes: 65 additions & 0 deletions test/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,81 @@ end
# LayerNorm: shape-preserving; a 3-arg call with padbatch=true appends a batch dim of 1.
m = LayerNorm(32)
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)
# Multi-dim LayerNorm: outputsize agrees with an actual forward pass,
# and fills in trailing dims from the layer's own size when the input tuple is shorter.
m2 = LayerNorm(3, 2)
@test outputsize(m2, (3, 2)) == (3, 2) == size(m2(randn(3, 2)))
@test outputsize(m2, (3,)) == (3, 2) == size(m2(randn(3, 2)))

# BatchNorm(3): shape-preserving on matching channels; a channel mismatch (5 ≠ 3)
# errors both in a real forward pass and, as DimensionMismatch, in outputsize.
m = BatchNorm(3)
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)
@test_throws Exception m(randn(Float32, 32, 32, 5, 1))
@test_throws DimensionMismatch outputsize(m, (32, 32, 5, 1))

# InstanceNorm: same contract as BatchNorm above.
m = InstanceNorm(3)
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)
@test_throws Exception m(randn(Float32, 32, 32, 5, 1))
@test_throws DimensionMismatch outputsize(m, (32, 32, 5, 1))

# GroupNorm(16, 4): 16 channels in 4 groups; 15 channels is rejected the same way.
m = GroupNorm(16, 4)
@test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
@test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
@test_throws Exception m(randn(Float32, 32, 32, 15, 4))
@test_throws DimensionMismatch outputsize(m, (32, 32, 15, 4))
end

# Tests for @autosize: the macro replaces each `_` in layer constructors with the
# size inferred by propagating the given input-size tuple through the model.
@testset "autosize macro" begin
# Basic case: `_` becomes the input width 3.
m = @autosize (3,) Dense(_ => 4)
@test randn(3) |> m |> size == (4,)

# `_` is resolved per layer as sizes flow through the Chain; batch dim (1 here)
# does not constrain later use with a different batch size (5).
m = @autosize (3, 1) Chain(Dense(_ => 4), Dense(4 => 10), softmax)
@test randn(3, 5) |> m |> size == (10, 5)

m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)

# `_` may appear inside an arbitrary expression: div(9, 2) == 4 output features.
m = @autosize (9,) Dense(_ => div(_,2))
@test randn(9) |> m |> size == (4,)

m = @autosize (3,) Chain(one = Dense(_ => 4), two = softmax) # needs kw
@test randn(3) |> m |> size == (4,)

m = @autosize (3, 45) Maxout(() -> Dense(_ => 6, tanh), 2) # needs ->, block
@test randn(3, 45) |> m |> size == (6, 45)

# here Parallel gets two inputs, no problem:
# SkipConnection concatenates 4+5+6 = ... vcat gives 11; final Scale(_) keeps 11.
m = @autosize (3,) Chain(SkipConnection(Dense(_ => 4), Parallel(vcat, Dense(_ => 5), Dense(_ => 6))), Flux.Scale(_))
@test randn(3) |> m |> size == (11,)

# like Dense, LayerNorm goes by the first dimension:
m = @autosize (3, 4, 5) LayerNorm(_)
@test rand(3, 6, 7) |> m |> size == (3, 6, 7)

m = @autosize (3, 3, 10) LayerNorm(_, _) # does not check that sizes match
@test rand(3, 3, 10) |> m |> size == (3, 3, 10)

# Bilinear's `_ => out` shorthand: both inputs get the inferred size.
m = @autosize (3,) Flux.Bilinear(_ => 10)
@test randn(3) |> m |> size == (10,)

m = @autosize (3, 1) Flux.Bilinear(_ => 10)
@test randn(3, 4) |> m |> size == (10, 4)

# A `_` buried inside a tuple argument is not supported; macro expansion errors,
# hence @eval so the error is raised at (deferred) macro-expansion time.
@test_throws Exception @eval @autosize (3,) Flux.Bilinear((_,3) => 10)

# first docstring example
m = @autosize (3, 1) Chain(Dense(_ => 2, sigmoid), BatchNorm(_, affine=false))
@test randn(3, 4) |> m |> size == (2, 4)

# evil docstring example
img = [28, 28];
m = @autosize (img..., 1, 32) Chain( # size is only needed at runtime
Chain(c = Conv((3,3), _ => 5; stride=2, pad=SamePad()),
p = MeanPool((3,3)),
b = BatchNorm(_),
f = Flux.flatten),
Dense(_ => _÷4, relu, init=Flux.rand32), # can calculate output size _÷4
SkipConnection(Dense(_ => _, relu), +),
Dense(_ => 10),
) |> gpu # moves to GPU after initialisation
@test randn(Float32, img..., 1, 32) |> gpu |> m |> size == (10, 32)
end

0 comments on commit 310b71e

Please sign in to comment.