From 214708974cdb3fa56d7ed3289032659bb447015e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 23:08:53 +0100 Subject: [PATCH] Added `output_length` and `output_size` (#270) * added output_length and output_size to compute output, well, leengths and sizes for transformations * added tests for size of transformed dist using VcCorrBijector * use already constructed transfrormation * TransformedDistribution should now also have correct variate form * added proper variateform handling for VecCholeskyBijector too * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added output_size impl for Reshape too * bump minor version * Apply suggestions from code review Co-authored-by: David Widmann * Update src/interface.jl * Update src/bijectors/corr.jl * reverted removal of length as we'll need it now * updated Stacked to be compat with changing sizes * forgot to commit deetion * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add testing of sizes to `test_bijector` * some more tests for stacked * Update test/bijectors/stacked.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added awful generated function to determine output ranges for Stacked with tuple because recursive implementation fail * added slightly more informative comment * format * more fixes to that damned Stacked * Update test/interface.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * specialized constructors for Stacked further * fixed bug in output_size for CorrVecBijector * Apply suggestions from code review Co-authored-by: David Widmann * Apply suggestions from code review Co-authored-by: David Widmann --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann --- Project.toml | 2 +- src/bijectors/corr.jl | 16 ++++ src/bijectors/reshape.jl | 2 + src/bijectors/stacked.jl | 143 ++++++++++++++++++++++++++------ src/interface.jl | 14 ++++ src/transformed_distribution.jl | 27 +++--- test/bijectors/corr.jl | 18 ++++ test/bijectors/stacked.jl | 37 +++++++++ test/bijectors/utils.jl | 6 ++ test/interface.jl | 20 ++--- 10 files changed, 235 insertions(+), 50 deletions(-) create mode 100644 test/bijectors/stacked.jl diff --git a/Project.toml b/Project.toml index 978d97df..dcf1ce79 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.12.8" +version = "0.13.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 9367b0cd..a4ed4740 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -232,6 +232,17 @@ function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) return _logabsdetjac_inv_corr(y) end +function output_size(::VecCorrBijector, sz::Tuple{Int,Int}) + sz[1] == sz[2] || error("sizes should be equal; received $(sz)") + n = sz[1] + return ((n * (n - 1)) รท 2,) +end + +function output_size(::Inverse{VecCorrBijector}, sz::Tuple{Int}) + n = _triu1_dim_from_length(first(sz)) + return (n, n) +end + """ VecCholeskyBijector <: Bijector @@ -317,6 +328,11 @@ function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) return _logabsdetjac_inv_chol(y) end +output_size(::VecCholeskyBijector, sz::Tuple{Int,Int}) = output_size(VecCorrBijector(), sz) +function output_size(::Inverse{<:VecCholeskyBijector}, sz::Tuple{Int}) + return output_size(inverse(VecCorrBijector()), sz) +end + """ function _link_chol_lkj(w) diff --git a/src/bijectors/reshape.jl b/src/bijectors/reshape.jl index 8a8bd1e4..4f8665cd 100644 --- a/src/bijectors/reshape.jl +++ b/src/bijectors/reshape.jl @@ -25,3 +25,5 @@ end inverse(b::Reshape) = Reshape(b.out_shape, b.in_shape) with_logabsdet_jacobian(b::Reshape, x) = reshape(x, b.out_shape), zero(eltype(x)) + +output_size(b::Reshape, in_size) = b.out_shape diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 4f0596cb..73abf51c 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -23,25 +23,85 @@ b([0.0, 1.0]) == [b1(0.0), 1.0] # => true """ struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}} <: Transform bs::Bs - ranges::Rs + ranges_in::Rs + ranges_out::Rs + length_in::Int + length_out::Int end + +function Stacked(bs::AbstractArray, ranges_in::AbstractArray) + ranges_out = determine_output_ranges(bs, ranges_in) + return Stacked{typeof(bs),typeof(ranges_in)}( + bs, ranges_in, ranges_out, sum(length, ranges_in), sum(length, ranges_out) + ) +end +function Stacked(bs::Tuple, ranges_in::Tuple) + ranges_out = determine_output_ranges(bs, ranges_in) + return Stacked{typeof(bs),typeof(ranges_in)}( + bs, ranges_in, ranges_out, sum(length, ranges_in), sum(length, ranges_out) + ) +end +Stacked(bs::AbstractArray, ranges::Tuple) = Stacked(bs, collect(ranges)) +Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs))) Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)]) Stacked(bs...) = Stacked(bs, ntuple(i -> i:i, length(bs))) +function determine_output_ranges(bs, ranges) + offset = 0 + return map(bs, ranges) do b, r + out_length = output_length(b, length(r)) + r = offset .+ (1:out_length) + offset += out_length + return r + end +end + +# NOTE: I don't like this but it seems necessary because `Stacked(...)` can occur in hot code paths. +function determine_output_ranges(bs::Tuple, ranges::Tuple) + return determine_output_ranges_generated(bs, ranges) +end +@generated function determine_output_ranges_generated(bs::Tuple, ranges::Tuple) + N = length(bs.parameters) + exprs = [] + push!(exprs, :(offset = 0)) + + rsyms = [] + for i in 1:N + rsym = Symbol("r_$i") + lengthsym = Symbol("length_$i") + push!(exprs, :($lengthsym = output_length(bs[$i], length(ranges[$i])))) + push!(exprs, :($rsym = offset .+ (1:($lengthsym)))) + push!(exprs, :(offset += $lengthsym)) + + push!(rsyms, rsym) + end + + acc_expr = Expr(:tuple, rsyms...) + + return quote + $(exprs...) + return $acc_expr + end +end + # Avoid mixing tuples and arrays. Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) Functors.@functor Stacked (bs,) -Base.show(io::IO, b::Stacked) = print(io, "Stacked($(b.bs), $(b.ranges))") +function Base.show(io::IO, b::Stacked) + return print(io, "Stacked($(b.bs), $(b.ranges_in), $(b.ranges_out))") +end function Base.:(==)(b1::Stacked, b2::Stacked) bs1, bs2 = b1.bs, b2.bs if !(bs1 isa Tuple && bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector) return false end - return all(bs1 .== bs2) && all(b1.ranges .== b2.ranges) + return all(bs1 .== bs2) && + all(b1.ranges_in .== b2.ranges_in) && + all(b1.ranges_out .== b2.ranges_out) end isclosedform(b::Stacked) = all(isclosedform, b.bs) @@ -49,7 +109,11 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs) isinvertible(b::Stacked) = all(isinvertible, b.bs) # For some reason `inverse.(sb.bs)` was unstable... This works though. -inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) +function inverse(sb::Stacked) + return Stacked( + map(inverse, sb.bs), sb.ranges_out, sb.ranges_in, sb.length_out, sb.length_in + ) +end # map is not type stable for many stacked bijectors as a large tuple # hence the generated function @generated function inverse(sb::Stacked{A}) where {A<:Tuple} @@ -57,44 +121,59 @@ inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) for i in 1:length(A.parameters) push!(exprs, :(inverse(sb.bs[$i]))) end - return :(Stacked(($(exprs...),), sb.ranges)) + return :(Stacked( + ($(exprs...),), sb.ranges_out, sb.ranges_in, sb.length_out, sb.length_in + )) end -@generated function _transform(x, rs::NTuple{N,UnitRange{Int}}, bs...) where {N} +output_size(b::Stacked, sz::Tuple{Int}) = (b.length_out,) + +@generated function _transform_stacked_recursive( + x, rs::NTuple{N,UnitRange{Int}}, bs... +) where {N} exprs = [] for i in 1:N push!(exprs, :(bs[$i](x[rs[$i]]))) end return :(vcat($(exprs...))) end -function _transform(x, rs::NTuple{1,UnitRange{Int}}, b) - @assert rs[1] == 1:length(x) +function _transform_stacked_recursive(x, rs::NTuple{1,UnitRange{Int}}, b) + rs[1] == 1:length(x) || error("range must be 1:length(x)") return b(x) end -function transform(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real}) - y = _transform(x, sb.ranges, sb.bs...) - @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" +function _transform_stacked(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real}) + y = _transform_stacked_recursive(x, sb.ranges_in, sb.bs...) return y end # The Stacked{<:AbstractArray} version is not TrackedArray friendly -function transform(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) +function _transform_stacked(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) N = length(sb.bs) - N == 1 && return sb.bs[1](x[sb.ranges[1]]) + N == 1 && return sb.bs[1](x[sb.ranges_in[1]]) y = mapvcat(1:N) do i - sb.bs[i](x[sb.ranges[i]]) + sb.bs[i](x[sb.ranges_in[i]]) + end + return y +end + +function transform(sb::Stacked, x::AbstractVector{<:Real}) + if sb.length_in != length(x) + error("input length mismatch ($(sb.length_in) != $(length(x)))") + end + y = _transform_stacked(sb, x) + if sb.length_out != length(y) + error("output length mismatch ($(sb.length_out) != $(length(y)))") end - @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" return y end function logabsdetjac(b::Stacked, x::AbstractVector{<:Real}) N = length(b.bs) - init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) + init = sum(logabsdetjac(b.bs[1], x[b.ranges_in[1]])) return if N > 1 init + sum(2:N) do i - sum(logabsdetjac(b.bs[i], x[b.ranges[i]])) + sum(logabsdetjac(b.bs[i], x[b.ranges_in[i]])) end else init @@ -104,13 +183,13 @@ end function logabsdetjac( b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector{<:Real} ) where {N} - init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) + init = sum(logabsdetjac(b.bs[1], x[b.ranges_in[1]])) return if N == 1 init else init + sum(2:N) do i - sum(logabsdetjac(b.bs[i], x[b.ranges[i]])) + sum(logabsdetjac(b.bs[i], x[b.ranges_in[i]])) end end end @@ -124,13 +203,15 @@ end # logjac += sum(_logjac) # return (vcat(y_1, y_2), logjac) # end -@generated function with_logabsdet_jacobian( +@generated function _with_logabsdet_jacobian( b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector ) where {N} expr = Expr(:block) y_names = [] - push!(expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges[1]]))) + push!( + expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges_in[1]])) + ) # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac = sum(_logjac))) push!(y_names, :y_1) @@ -138,7 +219,7 @@ end y_name = Symbol("y_$i") push!( expr.args, - :(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges[$i]])), + :(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges_in[$i]])), ) # TODO: drop the `sum` when we have dimensionality @@ -151,14 +232,26 @@ end return expr end -function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) +function _with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) N = length(sb.bs) - yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges[1]]) + yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges_in[1]]) logjac = sum(linit) - ys = mapreduce(vcat, sb.bs[2:end], sb.ranges[2:end]; init=yinit) do b, r + ys = mapreduce(vcat, sb.bs[2:end], sb.ranges_in[2:end]; init=yinit) do b, r y, l = with_logabsdet_jacobian(b, x[r]) logjac += sum(l) y end return (ys, logjac) end + +function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) + if sb.length_in != length(x) + error("input length mismatch ($(sb.length_in) != $(length(x)))") + end + y, logjac = _with_logabsdet_jacobian(sb, x) + if output_length(sb, length(x)) != length(y) + error("output length mismatch ($(output_length(sb, length(x))) != $(length(y)))") + end + + return (y, logjac) +end diff --git a/src/interface.jl b/src/interface.jl index 91c9a961..099df1bb 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -35,6 +35,20 @@ function logabsdetjac(f::Columnwise, x::AbstractMatrix) end with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x)) +""" + output_size(f, sz) + +Returns the output size of `f` given the input size `sz`. +""" +output_size(f, sz) = sz + +""" + output_length(f, len::Int) + +Returns the output length of `f` given the input length `len`. +""" +output_length(f, len::Int) = only(output_size(f, (len,))) + ###################### # Bijector interface # ###################### diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index eccbb64c..04c3a559 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -1,20 +1,19 @@ +function variateform(d::Distribution, b) + sz_in = size(d) + sz_out = output_size(b, sz_in) + return ArrayLikeVariate{length(sz_out)} +end + +variateform(::MultivariateDistribution, ::Inverse{VecCholeskyBijector}) = CholeskyVariate + # Transformed distributions struct TransformedDistribution{D,B,V} <: - Distribution{V,Continuous} where {D<:Distribution{V,Continuous},B} + Distribution{V,Continuous} where {D<:ContinuousDistribution,B} dist::D transform::B - function TransformedDistribution(d::UnivariateDistribution, b) - return new{typeof(d),typeof(b),Univariate}(d, b) - end - function TransformedDistribution(d::MultivariateDistribution, b) - return new{typeof(d),typeof(b),Multivariate}(d, b) - end - function TransformedDistribution(d::MatrixDistribution, b) - return new{typeof(d),typeof(b),Matrixvariate}(d, b) - end - function TransformedDistribution(d::Distribution{CholeskyVariate}, b) - return new{typeof(d),typeof(b),CholeskyVariate}(d, b) + function TransformedDistribution(d::ContinuousDistribution, b) + return new{typeof(d),typeof(b),variateform(d, b)}(d, b) end end @@ -101,8 +100,8 @@ end ############################## # size -Base.length(td::Transformed) = length(td.dist) -Base.size(td::Transformed) = size(td.dist) +Base.length(td::Transformed) = prod(output_size(td.transform, size(td.dist))) +Base.size(td::Transformed) = output_size(td.transform, size(td.dist)) function logpdf(td::UnivariateTransformed, y::Real) x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 71bfd8d7..8a423bc3 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -33,6 +33,15 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) test_ad(x -> sum(bvec(bvecinv(x))), yvec) + + # Check that output sizes are computed correctly. + tdist = transformed(dist) + @test length(tdist) == length(yvec) + @test tdist isa MultivariateDistribution + + dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(bvec)) + @test size(dist_unconstrained) == size(x) + @test dist_unconstrained isa MatrixDistribution end end @@ -60,6 +69,15 @@ end # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) # test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) + + # Check that output sizes are computed correctly. + tdist = transformed(dist) + @test length(tdist) == length(y) + @test tdist isa MultivariateDistribution + + dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(b)) + @test size(dist_unconstrained) == size(x) + @test dist_unconstrained isa Distribution{CholeskyVariate,Continuous} end end end diff --git a/test/bijectors/stacked.jl b/test/bijectors/stacked.jl new file mode 100644 index 00000000..655b63eb --- /dev/null +++ b/test/bijectors/stacked.jl @@ -0,0 +1,37 @@ +struct ProjectionBijector <: Bijectors.Bijector end + +Bijectors.output_size(::ProjectionBijector, sz::Tuple{Int}) = (sz[1] - 1,) +Bijectors.output_size(::Inverse{ProjectionBijector}, sz::Int) = (sz[1] + 1,) + +function Bijectors.with_logabsdet_jacobian(::ProjectionBijector, x::AbstractVector) + return x[1:(end - 1)], 0 +end +function Bijectors.with_logabsdet_jacobian(::Inverse{ProjectionBijector}, x::AbstractVector) + return vcat(x, 0), 0 +end + +@testset "Stacked with differing input and output size" begin + bs = [ + Stacked((elementwise(exp), ProjectionBijector()), (1:1, 2:3)), + Stacked([elementwise(exp), ProjectionBijector()], [1:1, 2:3]), + Stacked([elementwise(exp), ProjectionBijector()], (1:1, 2:3)), + Stacked((elementwise(exp), ProjectionBijector()), [1:1, 2:3]), + ] + @testset "$b" for b in bs + binv = inverse(b) + x = [1.0, 2.0, 3.0] + y = b(x) + x_ = binv(y) + + # Are the values of correct size? + @test size(y) == (2,) + @test size(x_) == (3,) + # Can we determine the sizes correctly? + @test Bijectors.output_size(b, size(x)) == (2,) + @test Bijectors.output_size(binv, size(y)) == (3,) + + # Are values correct? + @test y == [exp(1.0), 2.0] + @test binv(y) == [1.0, 2.0, 0.0] + end +end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index b9ab1242..8c31ccec 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -14,6 +14,7 @@ function test_bijector( test_types=false, changes_of_variables_test=true, inverse_functions_test=true, + test_sizes=true, compare=isapprox, kwargs..., ) @@ -31,6 +32,11 @@ function test_bijector( @inferred(with_logabsdet_jacobian(inverse(b), y_test)) end + if test_sizes + @test Bijectors.output_size(b, size(x)) == size(y_test) + @test Bijectors.output_size(ib, size(y_test)) == size(x) + end + # ChangesOfVariables.jl # For non-bijective transformations, these tests always fail since determinant of # the Jacobian is zero. Hence we allow the caller to disable them if necessary. diff --git a/test/interface.jl b/test/interface.jl index 7d989abf..c316ed09 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -267,7 +267,7 @@ end @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @test res[1] == [exp(x[1]), log(x[2]), x[3] + 5.0] @test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:3]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:3]) @test res[2] == logabsdetjac(sb, x) # TODO: change when we have dimensionality in the type @@ -278,11 +278,11 @@ end @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) # Array-version sb = Stacked([elementwise(exp), SimplexBijector()], [1:1, 2:3]) @@ -292,11 +292,11 @@ end @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) # Mixed versions # Tuple, Array @@ -307,11 +307,11 @@ end @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) # Array, Tuple sb = Stacked((elementwise(exp), SimplexBijector()), [1:1, 2:3]) @@ -321,11 +321,11 @@ end @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) @testset "Stacked: ADVI with MvNormal" begin # MvNormal test @@ -369,7 +369,7 @@ end # check that wrong ranges fails sb = Stacked(ibs) x = rand(d) - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) # Stacked{<:Tuple} bs = bijector.(tuple(dists...))