Skip to content

Commit

Permalink
Simplify inverse of radial layer
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Aug 8, 2020
1 parent 273dbe6 commit 8b0d2e1
Showing 1 changed file with 45 additions and 17 deletions.
62 changes: 45 additions & 17 deletions src/bijectors/radial_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,29 +70,57 @@ end

# Inverse of the radial layer for a single point `y`.
#
# With ``r`` the norm returned by `compute_r` for `y - z0`, the pre-image is
#     z = z0 + (α + r) / (α_plus_β_hat + r) * (y - z0).
# `r` is obtained in closed form, so no iterative root finder is needed.
function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real})
    flow = ib.orig
    z0 = flow.z_0
    α = softplus(first(flow.α_))  # from A.2
    # softplus(β) is α + β_hat in the A.2 parameterization (β_hat = -α + softplus(β)).
    α_plus_β_hat = softplus(first(flow.β))  # from A.2

    # Compute the norm ``r`` from A.2.
    y_minus_z0 = y .- z0
    r = compute_r(y_minus_z0, α, α_plus_β_hat)

    # Scalar rescaling of the offset from z0; the scale factor is shared by all entries.
    return z0 .+ ((α + r) / (α_plus_β_hat + r)) .* y_minus_z0
end

# Inverse of the radial layer applied column-wise to the matrix `y`.
#
# Every column is treated as one point: the norm ``r`` is computed per column with
# `compute_r`, and each column's offset from `z0` is rescaled by
# (α + r) / (α_plus_β_hat + r).
function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real})
    flow = ib.orig
    z0 = flow.z_0
    α = softplus(first(flow.α_))  # from A.2
    # softplus(β) is α + β_hat in the A.2 parameterization (β_hat = -α + softplus(β)).
    α_plus_β_hat = softplus(first(flow.β))  # from A.2

    # Compute the norm ``r`` from A.2 for each column.
    # NOTE(review): `mapvcat` is defined elsewhere; presumably it collects the
    # per-column scalars into a vector — confirm.
    y_minus_z0 = y .- z0
    rs = mapvcat(eachcol(y_minus_z0)) do c
        return compute_r(c, α, α_plus_β_hat)
    end

    # The adjoint turns the per-column factors into a row so that broadcasting
    # scales column `j` of `y_minus_z0` by the factor computed from `rs[j]`.
    return z0 .+ ((α .+ rs) ./ (α_plus_β_hat .+ rs))' .* y_minus_z0
end

"""
compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat)
Compute the unique solution ``r`` to the equation
```math
\\|y_minus_z0\\|_2 = r \\left(1 + \\frac{α_plus_β_hat - α}{α + r}\\right)
```
subject to ``r ≥ 0`` and ``r ≠ α``.
Since ``α > 0`` and ``α_plus_β_hat > 0``, the solution is unique and given by
```math
r = (\\sqrt{(α_plus_β_hat - γ)^2 + 4 α γ} - (α_plus_β_hat - γ)) / 2,
```
where ``γ = \\|y_minus_z0\\|_2``. For details see appendix A.2 of the reference.
# References
D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
arXiv:1505.05770
"""
function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat)
γ = norm(y_minus_z0)
a = α_plus_β_hat - γ
r = (sqrt(a^2 + 4 * α * γ) - a) / 2
return r
end

# Log-absolute-determinant of the Jacobian of `flow` at `x`, read from the
# `logabsdetjac` field of the result of `forward(flow, x)`.
function logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat)
    result = forward(flow, x)
    return result.logabsdetjac
end
Expand Down

0 comments on commit 8b0d2e1

Please sign in to comment.