Skip to content

Commit

Permalink
relaxed the type-requirement on planar and radial
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Jul 7, 2020
1 parent 9756425 commit 7175cb2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
16 changes: 8 additions & 8 deletions src/bijectors/planar_layer.jl
Expand Up @@ -15,7 +15,7 @@ using Roots # for inverse

# TODO: add docstring

mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Real} <: Bijector{1}
mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector{1}
w::T1
u::T1
b::T2
Expand All @@ -33,7 +33,7 @@ end
function PlanarLayer(dims::Int, wrapper=identity)
w = wrapper(randn(dims))
u = wrapper(randn(dims))
b = wrapper(randn())
b = wrapper(randn(1))
return PlanarLayer(w, u, b)
end

Expand All @@ -42,7 +42,7 @@ planar_flow_m(x) = -1 + softplus(x) # for planar flow from A.1

# An internal version of transform that returns intermediate variables
function _transform(flow::PlanarLayer, z::AbstractVecOrMat)
_planar_transform(flow.u, flow.w, flow.b, z)
return _planar_transform(flow.u, flow.w, first(flow.b), z)
end
function _planar_transform(u, w, b, z)
u_hat = get_u_hat(u, w)
Expand All @@ -55,7 +55,7 @@ end
function forward(flow::PlanarLayer, z::AbstractVecOrMat)
transformed, u_hat = _transform(flow, z)
# Compute log_det_jacobian
psi = ψ(z, flow.w, flow.b) .+ zero(eltype(u_hat))
psi = ψ(z, flow.w, first(flow.b)) .+ zero(eltype(u_hat))
if psi isa AbstractVector
T = eltype(psi)
else
Expand All @@ -71,25 +71,25 @@ function (ib::Inverse{<: PlanarLayer})(y::AbstractVector{<:Real})
T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y))
TV = vectorof(T)
# Define the objective functional; implemented with reference from A.1
f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + flow.b)
f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b))
# Run solver
alpha::T = find_zero(f(y), zero(T), Order16())
z_para::TV = (flow.w ./ norm(flow.w, 2)) .* alpha
return (y .- u_hat .* tanh.(flow.w' * z_para .+ flow.b))::TV
return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TV
end
function (ib::Inverse{<: PlanarLayer})(y::AbstractMatrix{<:Real})
flow = ib.orig
u_hat = get_u_hat(flow.u, flow.w)
T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y))
TM = matrixof(T)
# Define the objective functional; implemented with reference from A.1
f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + flow.b)
f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b))
# Run solver
alpha = mapvcat(eachcol(y)) do c
find_zero(f(c), zero(T), Order16())
end
z_para::TM = (flow.w ./ norm(flow.w, 2)) .* alpha'
return (y .- u_hat .* tanh.(flow.w' * z_para .+ flow.b))::TM
return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TM
end

function matrixof(::Type{Vector{T}}) where {T <: Real}
Expand Down
16 changes: 8 additions & 8 deletions src/bijectors/radial_layer.jl
Expand Up @@ -13,7 +13,7 @@ using Roots # for inverse
# RadialLayer #
###############

mutable struct RadialLayer{T1 <: Real, T2 <: AbstractVector{<:Real}} <: Bijector{1}
mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:AbstractVector{<:Real}} <: Bijector{1}
α_::T1
β::T1
z_0::T2
Expand All @@ -23,8 +23,8 @@ function Base.:(==)(b1::RadialLayer, b2::RadialLayer)
end

function RadialLayer(dims::Int, wrapper=identity)
α_ = wrapper(randn())
β = wrapper(randn())
α_ = wrapper(randn(1))
β = wrapper(randn(1))
z_0 = wrapper(randn(dims))
return RadialLayer(α_, β, z_0)
end
Expand All @@ -34,7 +34,7 @@ h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14)

# An internal version of transform that returns intermediate variables
function _transform(flow::RadialLayer, z::AbstractVecOrMat)
return _radial_transform(flow.α_, flow.β, flow.z_0, z)
return _radial_transform(first(flow.α_), first(flow.β), flow.z_0, z)
end
function _radial_transform(α_, β, z_0, z)
α = softplus(α_) # from A.2
Expand Down Expand Up @@ -72,8 +72,8 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real})
flow = ib.orig
T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y))
TV = vectorof(T)
α = softplus(flow.α_) # from A.2
β_hat = - α + softplus(flow.β) # from A.2
α = softplus(first(flow.α_)) # from A.2
β_hat = - α + softplus(first(flow.β)) # from A.2
# Define the objective functional
f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat /+ r)) # from eq(26)
# Run solver
Expand All @@ -84,8 +84,8 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real})
flow = ib.orig
T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y))
TM = matrixof(T)
α = softplus(flow.α_) # from A.2
β_hat = - α + softplus(flow.β) # from A.2
α = softplus(first(flow.α_)) # from A.2
β_hat = - α + softplus(first(flow.β)) # from A.2
# Define the objective functional
f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat /+ r)) # from eq(26)
# Run solver
Expand Down

0 comments on commit 7175cb2

Please sign in to comment.