Skip to content

Commit

Permalink
fix Stacked to work with Zygote (#315)
Browse files Browse the repository at this point in the history
* fix Stacked bijector to work with Zygote

* fix formatting

* fix `Stacked` by fusing two loops into one in a Zygote-friendly way

* remove trailing whitespace
  • Loading branch information
Red-Portal committed Jun 13, 2024
1 parent fd53666 commit 9115f1c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,12 @@ end
end

function _with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
N = length(sb.bs)
yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges_in[1]])
logjac = sum(linit)
ys = mapreduce(vcat, sb.bs[2:end], sb.ranges_in[2:end]; init=yinit) do b, r
y, l = with_logabsdet_jacobian(b, x[r])
logjac += sum(l)
y
ys_and_logjacs = map(zip(sb.bs, sb.ranges_in)) do (b, r)
with_logabsdet_jacobian(b, x[r])
end
return (ys, logjac)
y = reduce(vcat, map(first, ys_and_logjacs))
logjac = sum(map(last, ys_and_logjacs))
return (y, logjac)
end

function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
Expand Down
27 changes: 27 additions & 0 deletions test/ad/stacked.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
@testset "AD for StackedBijector" begin
dist1 = Dirichlet(4, 1.0)
b1 = bijector(dist1)

dist2 = LogNormal(0.0, 1.0)
b2 = bijector(dist2)

x1 = rand(dist1)
x2 = rand(dist2)

y1 = b1(x1)
y2 = b2(x2)

b = Stacked((b1, b2), (1:4, 5:5))
binv = inverse(b)

y = vcat(y1, [y2])
x = binv(y)

test_ad(y) do x
sum(transform(b, binv(x)))
end

test_ad(y) do y
sum(transform(binv, y))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ if GROUP == "All" || GROUP == "AD"
include("ad/flows.jl")
include("ad/pd.jl")
include("ad/corr.jl")
include("ad/stacked.jl")
end

0 comments on commit 9115f1c

Please sign in to comment.