Merge 47947ef into 5fd780c
devmotion committed Aug 8, 2020
2 parents 5fd780c + 47947ef commit bfa8967
Showing 3 changed files with 97 additions and 48 deletions.
79 changes: 50 additions & 29 deletions src/bijectors/planar_layer.jl
@@ -67,42 +67,63 @@ end

function (ib::Inverse{<: PlanarLayer})(y::AbstractVector{<:Real})
flow = ib.orig
u_hat = get_u_hat(flow.u, flow.w)
T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y))
TV = vectorof(T)
# Define the objective functional; implemented with reference from A.1
f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b))
# Run solver
alpha::T = find_zero(f(y), zero(T), Order16())
z_para::TV = (flow.w ./ norm(flow.w, 2)) .* alpha
return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TV
w = flow.w
b = first(flow.b)
u_hat = get_u_hat(flow.u, w)

# Find the scalar ``alpha`` from A.1.
wt_y = dot(w, y)
wt_u_hat = dot(w, u_hat)
alpha = find_alpha(y, wt_y, wt_u_hat, b)

return y .- u_hat .* tanh(alpha * norm(w, 2) + b)
end

function (ib::Inverse{<: PlanarLayer})(y::AbstractMatrix{<:Real})
flow = ib.orig
w = flow.w
b = first(flow.b)
u_hat = get_u_hat(flow.u, flow.w)
T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y))
TM = matrixof(T)
# Define the objective functional; implemented with reference from A.1
f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b))
# Run solver
alpha = mapvcat(eachcol(y)) do c
find_zero(f(c), zero(T), Order16())

# Find the scalar ``alpha`` from A.1 for each column.
wt_u_hat = dot(w, u_hat)
alphas = mapvcat(eachcol(y)) do c
find_alpha(c, dot(w, c), wt_u_hat, b)
end
z_para::TM = (flow.w ./ norm(flow.w, 2)) .* alpha'
return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TM
end

function matrixof(::Type{Vector{T}}) where {T <: Real}
return Matrix{T}
end
function matrixof(::Type{T}) where {T <: Real}
return Matrix{T}
return y .- u_hat .* tanh.(alphas' .* norm(w, 2) .+ b)
end
function vectorof(::Type{Matrix{T}}) where {T <: Real}
return Vector{T}
end
function vectorof(::Type{T}) where {T <: Real}
return Vector{T}
end

"""
find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b)
Compute an (approximate) real-valued solution ``α`` to the equation
```math
wt_y = α + wt_u_hat tanh(α + b)
```
The uniqueness of the solution is guaranteed since ``wt_u_hat ≥ -1``.
For details see appendix A.1 of the reference.
# References
D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
arXiv:1505.05770
"""
function find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b)
# Compute the initial bracket ((-Inf, 0) or (0, Inf))
f0 = wt_u_hat * tanh(b) - wt_y
zero_f0 = zero(f0)
if f0 < zero_f0
initial_bracket = (zero_f0, oftype(f0, Inf))
else
initial_bracket = (oftype(f0, -Inf), zero_f0)
end
alpha = find_zero(initial_bracket) do x
x + wt_u_hat * tanh(x + b) - wt_y
end

return alpha
end

logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac
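
For reference, a minimal standalone sketch (not part of the diff) of the root-finding step that `find_alpha` performs: it solves `wt_y = α + wt_u_hat * tanh(α + b)` with a bracketing search, assuming only that Roots.jl's `find_zero` is available. The concrete numbers are made up for illustration.

```julia
# Sketch: solve wt_y = α + wt_u_hat * tanh(α + b) for α, as in `find_alpha`.
# Assumes Roots.jl; the values below are arbitrary.
using Roots

wt_y, wt_u_hat, b = 1.7, 0.5, -0.3

# Same bracketing idea as in the diff: the sign of the residual at α = 0
# decides whether the root lies in (0, ∞) or in (-∞, 0).
f0 = wt_u_hat * tanh(b) - wt_y
bracket = f0 < 0 ? (0.0, Inf) : (-Inf, 0.0)

α = find_zero(x -> x + wt_u_hat * tanh(x + b) - wt_y, bracket)

# The defining equation should hold at the root (up to floating-point error).
@assert wt_y ≈ α + wt_u_hat * tanh(α + b)
```

Bisection over a half-infinite bracket works here because Roots.jl bisects over the floating-point representation, which appears to be what the new `find_alpha` relies on as well.
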
62 changes: 45 additions & 17 deletions src/bijectors/radial_layer.jl
@@ -70,29 +70,57 @@ end

function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real})
flow = ib.orig
T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y))
TV = vectorof(T)
z0 = flow.z_0
α = softplus(first(flow.α_)) # from A.2
β_hat = - α + softplus(first(flow.β)) # from A.2
# Define the objective functional
f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26)
# Run solver
rs::T = find_zero(f(y), zero(T), Order16())
return (flow.z_0 .+ (y .- flow.z_0) ./ (1 .+ β_hat ./ (α .+ rs)))::TV
α_plus_β_hat = softplus(first(flow.β)) # from A.2

# Compute the norm ``r`` from A.2.
y_minus_z0 = y .- z0
r = compute_r(y_minus_z0, α, α_plus_β_hat)

return z0 .+ ((α + r) / (α_plus_β_hat + r)) .* y_minus_z0
end

function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real})
flow = ib.orig
T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y))
TM = matrixof(T)
z0 = flow.z_0
α = softplus(first(flow.α_)) # from A.2
β_hat = - α + softplus(first(flow.β)) # from A.2
# Define the objective functional
f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26)
# Run solver
rs = mapvcat(eachcol(y)) do c
find_zero(f(c), zero(T), Order16())
α_plus_β_hat = softplus(first(flow.β)) # from A.2

# Compute the norm ``r`` from A.2 for each column.
y_minus_z0 = y .- z0
rs = mapvcat(eachcol(y_minus_z0)) do c
return compute_r(c, α, α_plus_β_hat)
end
return (flow.z_0 .+ (y .- flow.z_0) ./ (1 .+ β_hat ./ (α .+ rs')))::TM

return z0 .+ ((α .+ rs) ./ (α_plus_β_hat .+ rs))' .* y_minus_z0
end

"""
compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat)
Compute the unique solution ``r`` to the equation
```math
\\|y_minus_z0\\|_2 = r \\left(1 + \\frac{α_plus_β_hat - α}{α + r}\\right)
```
subject to ``r ≥ 0`` and ``r ≠ α``.
Since ``α > 0`` and ``α_plus_β_hat > 0``, the solution is unique and given by
```math
r = (\\sqrt{(α_plus_β_hat - γ)^2 + 4 α γ} - (α_plus_β_hat - γ)) / 2,
```
where ``γ = \\|y_minus_z0\\|_2``. For details see appendix A.2 of the reference.
# References
D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
arXiv:1505.05770
"""
function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat)
γ = norm(y_minus_z0)
a = α_plus_β_hat - γ
r = (sqrt(a^2 + 4 * α * γ) - a) / 2
return r
end

logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac
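
As a sanity check on the closed form used by `compute_r`, here is a short standalone sketch (not part of the diff) that spells out the quadratic in ``r`` and verifies the root; the values of `α`, `α_plus_β_hat`, and `γ` are arbitrary.

```julia
# Sketch: the defining equation  γ = r * (1 + (α_plus_β_hat - α) / (α + r))
# multiplies out to  r^2 + (α_plus_β_hat - γ) * r - α * γ = 0,
# whose nonnegative root is the formula implemented in `compute_r`.
α = 0.8                  # softplus(α_) > 0
α_plus_β_hat = 1.5       # softplus(β) > 0
γ = 2.3                  # stands in for norm(y_minus_z0)

a = α_plus_β_hat - γ
r = (sqrt(a^2 + 4 * α * γ) - a) / 2

@assert r ≥ 0
@assert γ ≈ r * (1 + (α_plus_β_hat - α) / (α + r))
```
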
4 changes: 2 additions & 2 deletions test/norm_flows.jl
@@ -27,8 +27,8 @@ end
our_method = sum(forward(flow, z).logabsdetjac)

@test our_method ≈ forward_diff
@test inv(flow)(flow(z)) ≈ z rtol=0.2
@test (inv(flow) ∘ flow)(z) ≈ z rtol=0.2
@test inv(flow)(flow(z)) ≈ z rtol=0.25
@test (inv(flow) ∘ flow)(z) ≈ z rtol=0.25
end

w = ones(10)
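
To see the relaxed tolerance in context, a hypothetical round-trip sketch (not taken from the test file): build a planar flow, push a batch through it, and invert it, as the test above does. It assumes `PlanarLayer(2)`, `inv`, and bijector composition via `∘` behave as in Bijectors.jl at the time of this commit.

```julia
# Hypothetical round-trip check mirroring the updated test (rtol = 0.25).
using Bijectors, Random

Random.seed!(1)
flow = PlanarLayer(2)          # randomly initialised planar layer
z = randn(2, 20)               # a batch of 2-dimensional samples

z_rec = (inv(flow) ∘ flow)(z)  # numerical inverse composed with the forward map
maximum(abs, z_rec .- z)       # expected to be small, within the test's rtol
```
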
