From 15f06c9a2eb69c7f3ebb477b6aec427ba9607989 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 8 Mar 2022 13:57:24 -0500 Subject: [PATCH] add dims kw to stack, squeeze, etc --- test/outputsize.jl | 2 +- test/utils.jl | 33 +++++++++++++++------------------ 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/test/outputsize.jl b/test/outputsize.jl index 2c90811dcb..8304d84b56 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -23,7 +23,7 @@ m = flatten @test outputsize(m, (5, 5, 3, 10)) == (75, 10) - m = Flux.unsqueeze(3) + m = Flux.unsqueeze(dims=3) @test outputsize(m, (5, 7, 13)) == (5, 7, 1, 13) m = Flux.Bilinear(10, 10, 7) diff --git a/test/utils.jl b/test/utils.jl index f240c56882..c75780e360 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -8,14 +8,6 @@ using Statistics, LinearAlgebra using Random using Test -@testset "unsqueeze" begin - x = randn(2, 3, 2) - @test @inferred(unsqueeze(x, 1)) == reshape(x, 1, 2, 3, 2) - @test @inferred(unsqueeze(x, 2)) == reshape(x, 2, 1, 3, 2) - @test @inferred(unsqueeze(x, 3)) == reshape(x, 2, 3, 1, 2) - @test @inferred(unsqueeze(x, 4)) == reshape(x, 2, 3, 2, 1) -end - @testset "Throttle" begin @testset "default behaviour" begin a = [] @@ -244,12 +236,6 @@ end @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)] end -@testset "Basic Stacking" begin - x = randn(3,3) - stacked = stack([x, x], 2) - @test size(stacked) == (3,2,3) -end - @testset "Precision" begin m = Chain(Dense(10, 5, relu), Dense(5, 2)) x64 = rand(Float64, 10) @@ -297,15 +283,26 @@ end end end +@testset "unsqueeze" begin + x = randn(2, 3, 2) + @test @inferred(unsqueeze(x, dims=1)) == reshape(x, 1, 2, 3, 2) + @test @inferred(unsqueeze(x, dims=2)) == reshape(x, 2, 1, 3, 2) + @test @inferred(unsqueeze(x, dims=3)) == reshape(x, 2, 3, 1, 2) + @test @inferred(unsqueeze(x, dims=4)) == reshape(x, 2, 3, 2, 1) +end + @testset "Stacking" begin + x = randn(3,3) + stacked = stack([x, x], dims=2) + @test size(stacked) == (3,2,3) + stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ] unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]] - @test unstack(stacked_array, 2) == unstacked_array - @test stack(unstacked_array, 2) == stacked_array - @test stack(unstack(stacked_array, 1), 1) == stacked_array + @test unstack(stacked_array, dims=2) == unstacked_array + @test stack(unstacked_array, dims=2) == stacked_array + @test stack(unstack(stacked_array, dims=1), dims=1) == stacked_array end - @testset "Batching" begin stacked_array=[ 8 9 3 5 9 6 6 9