Skip to content

Commit

Permalink
Added output_length and output_size (#270)
Browse files Browse the repository at this point in the history
* added output_length and output_size to compute output, well, lengths
and sizes for transformations

* added tests for size of transformed dist using VecCorrBijector

* use already constructed transformation

* TransformedDistribution should now also have correct variate form

* added proper variateform handling for VecCholeskyBijector too

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added output_size impl for Reshape too

* bump minor version

* Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Update src/interface.jl

* Update src/bijectors/corr.jl

* reverted removal of length as we'll need it now

* updated Stacked to be compat with changing sizes

* forgot to commit deletion

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add testing of sizes to `test_bijector`

* some more tests for stacked

* Update test/bijectors/stacked.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added awful generated function to determine output ranges for Stacked
with tuple because recursive implementation fails

* added slightly more informative comment

* format

* more fixes to that damned Stacked

* Update test/interface.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* specialized constructors for Stacked further

* fixed bug in output_size for CorrVecBijector

* Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
3 people committed Jun 18, 2023
1 parent 3fe36ac commit 2147089
Show file tree
Hide file tree
Showing 10 changed files with 235 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.12.8"
version = "0.13.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
16 changes: 16 additions & 0 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_inv_corr(y)
end

# Output size of `VecCorrBijector` for a square `(n, n)` input: the 1-tuple
# `(n * (n - 1) ÷ 2,)`. Errors when the two input dimensions differ.
function output_size(::VecCorrBijector, sz::Tuple{Int,Int})
    nrows, ncols = sz
    if nrows != ncols
        error("sizes should be equal; received $(sz)")
    end
    return (div(nrows * (nrows - 1), 2),)
end

# Output size of the inverse of `VecCorrBijector`: a vector size `(k,)` maps to the
# square size `(n, n)` with `n = _triu1_dim_from_length(k)`.
function output_size(::Inverse{VecCorrBijector}, sz::Tuple{Int})
    dim = _triu1_dim_from_length(sz[1])
    return (dim, dim)
end

"""
VecCholeskyBijector <: Bijector
Expand Down Expand Up @@ -317,6 +328,11 @@ function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_inv_chol(y)
end

# `VecCholeskyBijector` delegates its size computation to `VecCorrBijector`,
# both for the forward direction and for the inverse.
function output_size(::VecCholeskyBijector, sz::Tuple{Int,Int})
    return output_size(VecCorrBijector(), sz)
end
function output_size(::Inverse{<:VecCholeskyBijector}, sz::Tuple{Int})
    corr_inverse = inverse(VecCorrBijector())
    return output_size(corr_inverse, sz)
end

"""
function _link_chol_lkj(w)
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/reshape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ end
inverse(b::Reshape) = Reshape(b.out_shape, b.in_shape)

with_logabsdet_jacobian(b::Reshape, x) = reshape(x, b.out_shape), zero(eltype(x))

# The output size of a `Reshape` is its stored `out_shape`; `in_size` is not consulted.
function output_size(b::Reshape, in_size)
    return b.out_shape
end
143 changes: 118 additions & 25 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,78 +23,157 @@ b([0.0, 1.0]) == [b1(0.0), 1.0] # => true
"""
struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}} <: Transform
bs::Bs
ranges::Rs
ranges_in::Rs
ranges_out::Rs
length_in::Int
length_out::Int
end

# Construct a `Stacked` from array-valued bijectors and input ranges. The output
# ranges are derived with `determine_output_ranges`, and the total input/output
# lengths are the sums of the respective range lengths.
function Stacked(bs::AbstractArray, ranges_in::AbstractArray)
    ranges_out = determine_output_ranges(bs, ranges_in)
    length_in = sum(length, ranges_in)
    length_out = sum(length, ranges_out)
    return Stacked{typeof(bs),typeof(ranges_in)}(
        bs, ranges_in, ranges_out, length_in, length_out
    )
end
# Construct a `Stacked` from tuple-valued bijectors and input ranges. The output
# ranges are derived with `determine_output_ranges`, and the total input/output
# lengths are the sums of the respective range lengths.
function Stacked(bs::Tuple, ranges_in::Tuple)
    ranges_out = determine_output_ranges(bs, ranges_in)
    total_in = sum(length, ranges_in)
    total_out = sum(length, ranges_out)
    StackedT = Stacked{typeof(bs),typeof(ranges_in)}
    return StackedT(bs, ranges_in, ranges_out, total_in, total_out)
end
# Mixed tuple/array arguments are converted to arrays with `collect`.
function Stacked(bs::AbstractArray, ranges::Tuple)
    return Stacked(bs, collect(ranges))
end
function Stacked(bs::Tuple, ranges::AbstractArray)
    return Stacked(collect(bs), ranges)
end

# Without explicit ranges, the i-th bijector is assigned the single-index range `i:i`.
function Stacked(bs::Tuple)
    return Stacked(bs, ntuple(i -> i:i, length(bs)))
end
function Stacked(bs::AbstractArray)
    return Stacked(bs, [i:i for i in 1:length(bs)])
end
function Stacked(bs...)
    return Stacked(bs, ntuple(i -> i:i, length(bs)))
end

# Compute the output range of every bijector in `bs`, given its input range in `ranges`.
#
# The i-th output length is `output_length(bs[i], length(ranges[i]))`. Output ranges are
# laid out back to back: each one starts right after the previous one ends, with the
# first starting at index 1.
#
# The running offset is kept in a `Ref{Int}` instead of a plain local that the closure
# reassigns: a reassigned captured variable is boxed by Julia, which makes the closure
# type-unstable. The closure also no longer rebinds its own argument to a value of a
# different type.
function determine_output_ranges(bs, ranges)
    offset = Ref(0)
    return map(bs, ranges) do b, r_in
        out_length = output_length(b, length(r_in))
        r_out = offset[] .+ (1:out_length)
        offset[] += out_length
        return r_out
    end
end

# NOTE: Tuple inputs are routed to a `@generated` implementation. Not pretty, but it
# seems necessary because `Stacked(...)` can occur in hot code paths.
function determine_output_ranges(bs::Tuple, ranges::Tuple)
    ranges_out = determine_output_ranges_generated(bs, ranges)
    return ranges_out
end
# Generated, fully unrolled version of `determine_output_ranges` for tuples.
#
# For each position `i` the emitted code computes
#     length_i = output_length(bs[i], length(ranges[i]))
#     r_i      = offset .+ (1:length_i)
#     offset  += length_i
# and finally returns the tuple `(r_1, ..., r_N)`, where `N` is the number of type
# parameters of the tuple type of `bs`.
@generated function determine_output_ranges_generated(bs::Tuple, ranges::Tuple)
    nblocks = length(bs.parameters)

    body = Expr(:block, :(offset = 0))
    ranges_tuple = Expr(:tuple)
    for i in 1:nblocks
        len_i = Symbol("length_$i")
        range_i = Symbol("r_$i")

        push!(body.args, :($len_i = output_length(bs[$i], length(ranges[$i]))))
        push!(body.args, :($range_i = offset .+ (1:($len_i))))
        push!(body.args, :(offset += $len_i))

        push!(ranges_tuple.args, range_i)
    end
    push!(body.args, :(return $ranges_tuple))

    return body
end

# Avoid mixing tuples and arrays.
Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges)

Functors.@functor Stacked (bs,)

Base.show(io::IO, b::Stacked) = print(io, "Stacked($(b.bs), $(b.ranges))")
function Base.show(io::IO, b::Stacked)
return print(io, "Stacked($(b.bs), $(b.ranges_in), $(b.ranges_out))")
end

function Base.:(==)(b1::Stacked, b2::Stacked)
bs1, bs2 = b1.bs, b2.bs
if !(bs1 isa Tuple && bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector)
return false
end
return all(bs1 .== bs2) && all(b1.ranges .== b2.ranges)
return all(bs1 .== bs2) &&
all(b1.ranges_in .== b2.ranges_in) &&
all(b1.ranges_out .== b2.ranges_out)
end

isclosedform(b::Stacked) = all(isclosedform, b.bs)

isinvertible(b::Stacked) = all(isinvertible, b.bs)

# For some reason `inverse.(sb.bs)` was unstable... This works though.
inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges)
function inverse(sb::Stacked)
return Stacked(
map(inverse, sb.bs), sb.ranges_out, sb.ranges_in, sb.length_out, sb.length_in
)
end
# map is not type stable for many stacked bijectors as a large tuple
# hence the generated function
@generated function inverse(sb::Stacked{A}) where {A<:Tuple}
exprs = []
for i in 1:length(A.parameters)
push!(exprs, :(inverse(sb.bs[$i])))
end
return :(Stacked(($(exprs...),), sb.ranges))
return :(Stacked(
($(exprs...),), sb.ranges_out, sb.ranges_in, sb.length_out, sb.length_in
))
end

@generated function _transform(x, rs::NTuple{N,UnitRange{Int}}, bs...) where {N}
# The output size of a `Stacked` is the 1-tuple holding its stored total output
# length; the input size `sz` is not consulted.
function output_size(b::Stacked, sz::Tuple{Int})
    return (b.length_out,)
end

# Generated, unrolled application of `bs[i]` to the slice `x[rs[i]]` for `i = 1:N`.
# The emitted code is a single `vcat(...)` call over the `N` transformed slices.
@generated function _transform_stacked_recursive(
    x, rs::NTuple{N,UnitRange{Int}}, bs...
) where {N}
    call = Expr(:call, :vcat)
    for i in 1:N
        push!(call.args, :(bs[$i](x[rs[$i]])))
    end
    return call
end
function _transform(x, rs::NTuple{1,UnitRange{Int}}, b)
@assert rs[1] == 1:length(x)
# Single-bijector case: the only range must cover all of `x`, and `b` is then applied
# to `x` directly without slicing.
function _transform_stacked_recursive(x, rs::NTuple{1,UnitRange{Int}}, b)
    if rs[1] != 1:length(x)
        error("range must be 1:length(x)")
    end
    return b(x)
end
function transform(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real})
y = _transform(x, sb.ranges, sb.bs...)
@assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))"
# Tuple-based `Stacked`: splat the bijectors into the unrolled implementation, which
# applies each of them to its input range of `x`.
function _transform_stacked(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real})
    return _transform_stacked_recursive(x, sb.ranges_in, sb.bs...)
end
# The Stacked{<:AbstractArray} version is not TrackedArray friendly
function transform(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real})
function _transform_stacked(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real})
N = length(sb.bs)
N == 1 && return sb.bs[1](x[sb.ranges[1]])
N == 1 && return sb.bs[1](x[sb.ranges_in[1]])

y = mapvcat(1:N) do i
sb.bs[i](x[sb.ranges[i]])
sb.bs[i](x[sb.ranges_in[i]])
end
return y
end

# Apply the stacked transformation `sb` to the vector `x`.
#
# The input length is checked against `sb.length_in` before transforming, and the length
# of the result against `sb.length_out` afterwards; either mismatch raises an error.
#
# The output is deliberately NOT required to have the same size as the input:
# `length_in` and `length_out` are tracked separately, so the stacked bijectors may
# change the length of their inputs.
function transform(sb::Stacked, x::AbstractVector{<:Real})
    if sb.length_in != length(x)
        error("input length mismatch ($(sb.length_in) != $(length(x)))")
    end
    y = _transform_stacked(sb, x)
    if sb.length_out != length(y)
        error("output length mismatch ($(sb.length_out) != $(length(y)))")
    end
    return y
end

function logabsdetjac(b::Stacked, x::AbstractVector{<:Real})
N = length(b.bs)
init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))
init = sum(logabsdetjac(b.bs[1], x[b.ranges_in[1]]))

return if N > 1
init + sum(2:N) do i
sum(logabsdetjac(b.bs[i], x[b.ranges[i]]))
sum(logabsdetjac(b.bs[i], x[b.ranges_in[i]]))
end
else
init
Expand All @@ -104,13 +183,13 @@ end
function logabsdetjac(
b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector{<:Real}
) where {N}
init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))
init = sum(logabsdetjac(b.bs[1], x[b.ranges_in[1]]))

return if N == 1
init
else
init + sum(2:N) do i
sum(logabsdetjac(b.bs[i], x[b.ranges[i]]))
sum(logabsdetjac(b.bs[i], x[b.ranges_in[i]]))
end
end
end
Expand All @@ -124,21 +203,23 @@ end
# logjac += sum(_logjac)
# return (vcat(y_1, y_2), logjac)
# end
@generated function with_logabsdet_jacobian(
@generated function _with_logabsdet_jacobian(
b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector
) where {N}
expr = Expr(:block)
y_names = []

push!(expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges[1]])))
push!(
expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges_in[1]]))
)
# TODO: drop the `sum` when we have dimensionality
push!(expr.args, :(logjac = sum(_logjac)))
push!(y_names, :y_1)
for i in 2:N
y_name = Symbol("y_$i")
push!(
expr.args,
:(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges[$i]])),
:(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges_in[$i]])),
)

# TODO: drop the `sum` when we have dimensionality
Expand All @@ -151,14 +232,26 @@ end
return expr
end

function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
function _with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
N = length(sb.bs)
yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges[1]])
yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges_in[1]])
logjac = sum(linit)
ys = mapreduce(vcat, sb.bs[2:end], sb.ranges[2:end]; init=yinit) do b, r
ys = mapreduce(vcat, sb.bs[2:end], sb.ranges_in[2:end]; init=yinit) do b, r
y, l = with_logabsdet_jacobian(b, x[r])
logjac += sum(l)
y
end
return (ys, logjac)
end

# Transform `x` with the stacked bijector `sb` and return `(y, logjac)`.
#
# Errors if `length(x)` differs from `sb.length_in`, or if the length of the result
# differs from `output_length(sb, length(x))`. The actual work is delegated to
# `_with_logabsdet_jacobian`.
function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
    sb.length_in == length(x) ||
        error("input length mismatch ($(sb.length_in) != $(length(x)))")

    result = _with_logabsdet_jacobian(sb, x)
    y, logjac = result

    output_length(sb, length(x)) == length(y) ||
        error("output length mismatch ($(output_length(sb, length(x))) != $(length(y)))")

    return (y, logjac)
end
14 changes: 14 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ function logabsdetjac(f::Columnwise, x::AbstractMatrix)
end
with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x))

"""
    output_size(f, sz)

Return the output size of `f` given the input size `sz`.

This generic fallback returns `sz` unchanged.
"""
function output_size(f, sz)
    return sz
end

"""
    output_length(f, len::Int)

Return the output length of `f` given the input length `len`.

Computed by calling `output_size(f, (len,))` and unwrapping the result with `only`, so
the resulting size must have exactly one entry.
"""
function output_length(f, len::Int)
    out_size = output_size(f, (len,))
    return only(out_size)
end

######################
# Bijector interface #
######################
Expand Down
27 changes: 13 additions & 14 deletions src/transformed_distribution.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
# Variate form of the distribution `d` after transformation by `b`: an
# `ArrayLikeVariate` whose dimensionality is the number of entries in
# `output_size(b, size(d))`.
function variateform(d::Distribution, b)
    out_size = output_size(b, size(d))
    return ArrayLikeVariate{length(out_size)}
end

# Special case: a multivariate distribution pushed through the inverse of
# `VecCholeskyBijector` has variate form `CholeskyVariate`.
function variateform(::MultivariateDistribution, ::Inverse{VecCholeskyBijector})
    return CholeskyVariate
end

# Transformed distributions
struct TransformedDistribution{D,B,V} <:
Distribution{V,Continuous} where {D<:Distribution{V,Continuous},B}
Distribution{V,Continuous} where {D<:ContinuousDistribution,B}
dist::D
transform::B

function TransformedDistribution(d::UnivariateDistribution, b)
return new{typeof(d),typeof(b),Univariate}(d, b)
end
function TransformedDistribution(d::MultivariateDistribution, b)
return new{typeof(d),typeof(b),Multivariate}(d, b)
end
function TransformedDistribution(d::MatrixDistribution, b)
return new{typeof(d),typeof(b),Matrixvariate}(d, b)
end
function TransformedDistribution(d::Distribution{CholeskyVariate}, b)
return new{typeof(d),typeof(b),CholeskyVariate}(d, b)
function TransformedDistribution(d::ContinuousDistribution, b)
return new{typeof(d),typeof(b),variateform(d, b)}(d, b)
end
end

Expand Down Expand Up @@ -101,8 +100,8 @@ end
##############################

# size
Base.length(td::Transformed) = length(td.dist)
Base.size(td::Transformed) = size(td.dist)
# Size and length of a transformed distribution are those of the transform's output,
# obtained from `output_size` applied to the size of the underlying distribution.
function Base.length(td::Transformed)
    out_size = output_size(td.transform, size(td.dist))
    return prod(out_size)
end
function Base.size(td::Transformed)
    return output_size(td.transform, size(td.dist))
end

function logpdf(td::UnivariateTransformed, y::Real)
x, logjac = with_logabsdet_jacobian(inverse(td.transform), y)
Expand Down
18 changes: 18 additions & 0 deletions test/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false)

test_ad(x -> sum(bvec(bvecinv(x))), yvec)

# Check that output sizes are computed correctly.
tdist = transformed(dist)
@test length(tdist) == length(yvec)
@test tdist isa MultivariateDistribution

dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(bvec))
@test size(dist_unconstrained) == size(x)
@test dist_unconstrained isa MatrixDistribution
end
end

Expand Down Expand Up @@ -60,6 +69,15 @@ end
# test_bijector is commented out for now,
# as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky)
# test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false)

# Check that output sizes are computed correctly.
tdist = transformed(dist)
@test length(tdist) == length(y)
@test tdist isa MultivariateDistribution

dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(b))
@test size(dist_unconstrained) == size(x)
@test dist_unconstrained isa Distribution{CholeskyVariate,Continuous}
end
end
end
Loading

1 comment on commit 2147089

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't release until #263 and #271 have gone through!

Please sign in to comment.