diff --git a/Project.toml b/Project.toml index cffa220e..afd1d44c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.8.3" +version = "0.8.4" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index d5b3faf6..9be9ebaf 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -104,6 +104,10 @@ function logabsdetjac( end end +function logabsdetjac(b::Stacked{<:Any, 1}, x::AbstractVector{<:Real}) + return sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) +end + function logabsdetjac(b::Stacked, x::AbstractMatrix{<:Real}) return map(eachcol(x)) do c logabsdetjac(b, c) diff --git a/test/interface.jl b/test/interface.jl index f87b2bca..1fa6d0fe 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -708,6 +708,14 @@ end # AD verification @test log(abs(det(ForwardDiff.jacobian(sb, x)))) ≈ logabsdetjac(sb, x) @test log(abs(det(ForwardDiff.jacobian(isb, y)))) ≈ logabsdetjac(isb, y) + + # Ensure `Stacked` works for a single bijector + d = (MvNormal(2, 1.0),) + sb = Stacked(bijector.(d), [1:2]) + x = [.5, 1.] + @test sb(x) == x + @test logabsdetjac(sb, x) == 0 + @test forward(sb, x) == (rv = x, logabsdetjac = zero(eltype(x))) end end