Skip to content

Commit

Permalink
fix remove redundant helpers for reparam_with_entropy for bijector
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Jan 3, 2024
1 parent 05dbb51 commit 31db7bc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 28 deletions.
36 changes: 14 additions & 22 deletions ext/AdvancedVIBijectorsExt.jl
Expand Up @@ -11,24 +11,6 @@ else
using ..Random
end

function transform_samples_with_jacobian(unconst_samples, transform, n_samples)
unconst_iter = AdvancedVI.eachsample(unconst_samples)
unconst_init = first(unconst_iter)

samples_init, logjac_init = with_logabsdet_jacobian(transform, unconst_init)

samples_and_logjac = mapreduce(
AdvancedVI.catsamples_and_acc,
Iterators.drop(unconst_iter, 1);
init=(AdvancedVI.samples_expand_dim(samples_init), logjac_init)
) do sample
with_logabsdet_jacobian(transform, sample)
end
samples = first(samples_and_logjac)
logjac = last(samples_and_logjac)/n_samples
samples, logjac
end

function AdvancedVI.reparam_with_entropy(
rng ::Random.AbstractRNG,
q ::Bijectors.TransformedDistribution,
Expand All @@ -41,14 +23,24 @@ function AdvancedVI.reparam_with_entropy(
q_unconst_stop = q_stop.dist

# Draw samples and compute entropy of the uncontrained distribution
unconst_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
rng, q_unconst, q_unconst_stop, n_samples, ent_est
)

# Apply bijector to samples while estimating its jacobian
samples, logjac = transform_samples_with_jacobian(
unconst_samples, transform, n_samples
)
unconstr_iter = AdvancedVI.eachsample(unconstr_samples)
unconstr_init = first(unconstr_iter)
samples_init, logjac_init = with_logabsdet_jacobian(transform, unconstr_init)
samples_and_logjac = mapreduce(
AdvancedVI.catsamples_and_acc,
Iterators.drop(unconstr_iter, 1);
init=(reshape(samples_init, (:,1)), logjac_init)
) do sample
with_logabsdet_jacobian(transform, sample)
end
samples = first(samples_and_logjac)
logjac = last(samples_and_logjac)/n_samples

entropy = unconst_entropy + logjac
samples, entropy
end
Expand Down
6 changes: 0 additions & 6 deletions src/utils.jl
Expand Up @@ -23,8 +23,6 @@ end

eachsample(samples::AbstractMatrix) = eachcol(samples)

eachsample(samples::AbstractVector) = samples

function catsamples_and_acc(
state_curr::Tuple{<:AbstractArray, <:Real},
state_new ::Tuple{<:AbstractVector, <:Real}
Expand All @@ -34,7 +32,3 @@ function catsamples_and_acc(
return (x, ∑y)
end

function samples_expand_dim(x::AbstractVector)
reshape(x, (:,1))
end

0 comments on commit 31db7bc

Please sign in to comment.