Skip to content

Commit

Permalink
Merge ec1e844 into dfca913
Browse files Browse the repository at this point in the history
  • Loading branch information
vandenman committed Sep 14, 2020
2 parents dfca913 + ec1e844 commit 5502d62
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.8.3"
version = "0.8.4"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
4 changes: 4 additions & 0 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ function logabsdetjac(
end
end

function logabsdetjac(b::Stacked{<:Any, 1}, x::AbstractVector{<:Real})
return sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))
end

function logabsdetjac(b::Stacked, x::AbstractMatrix{<:Real})
return map(eachcol(x)) do c
logabsdetjac(b, c)
Expand Down
8 changes: 8 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,14 @@ end
# AD verification
@test log(abs(det(ForwardDiff.jacobian(sb, x)))) logabsdetjac(sb, x)
@test log(abs(det(ForwardDiff.jacobian(isb, y)))) logabsdetjac(isb, y)

# Ensure `Stacked` works for a single bijector
d = (MvNormal(2, 1.0),)
sb = Stacked(bijector.(d), [1:2])
x = [.5, 1.]
@test sb(x) == x
@test logabsdetjac(sb, x) == 0
@test forward(sb, x) == (rv = x, logabsdetjac = zero(eltype(x)))
end
end

Expand Down

0 comments on commit 5502d62

Please sign in to comment.