diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 3ff7f88f..b40166c7 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -15,7 +15,7 @@ using Roots # for inverse # TODO: add docstring -mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Real} <: Bijector{1} +mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector{1} w::T1 u::T1 b::T2 @@ -33,7 +33,7 @@ end function PlanarLayer(dims::Int, wrapper=identity) w = wrapper(randn(dims)) u = wrapper(randn(dims)) - b = wrapper(randn()) + b = wrapper(randn(1)) return PlanarLayer(w, u, b) end @@ -42,7 +42,7 @@ planar_flow_m(x) = -1 + softplus(x) # for planar flow from A.1 # An internal version of transform that returns intermediate variables function _transform(flow::PlanarLayer, z::AbstractVecOrMat) - _planar_transform(flow.u, flow.w, flow.b, z) + return _planar_transform(flow.u, flow.w, first(flow.b), z) end function _planar_transform(u, w, b, z) u_hat = get_u_hat(u, w) @@ -55,7 +55,7 @@ end function forward(flow::PlanarLayer, z::AbstractVecOrMat) transformed, u_hat = _transform(flow, z) # Compute log_det_jacobian - psi = ψ(z, flow.w, flow.b) .+ zero(eltype(u_hat)) + psi = ψ(z, flow.w, first(flow.b)) .+ zero(eltype(u_hat)) if psi isa AbstractVector T = eltype(psi) else @@ -71,11 +71,11 @@ function (ib::Inverse{<: PlanarLayer})(y::AbstractVector{<:Real}) T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y)) TV = vectorof(T) # Define the objective functional; implemented with reference from A.1 - f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + flow.b) + f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b)) # Run solver alpha::T = find_zero(f(y), zero(T), Order16()) z_para::TV = (flow.w ./ norm(flow.w, 2)) .* alpha - return (y .- u_hat .* tanh.(flow.w' * z_para .+ flow.b))::TV + return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TV end function (ib::Inverse{<: PlanarLayer})(y::AbstractMatrix{<:Real}) flow = ib.orig @@ -83,13 +83,13 @@ function (ib::Inverse{<: PlanarLayer})(y::AbstractMatrix{<:Real}) T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y)) TM = matrixof(T) # Define the objective functional; implemented with reference from A.1 - f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + flow.b) + f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b)) # Run solver alpha = mapvcat(eachcol(y)) do c find_zero(f(c), zero(T), Order16()) end z_para::TM = (flow.w ./ norm(flow.w, 2)) .* alpha' - return (y .- u_hat .* tanh.(flow.w' * z_para .+ flow.b))::TM + return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TM end function matrixof(::Type{Vector{T}}) where {T <: Real} diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 330c4f70..43e7c26f 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -13,7 +13,7 @@ using Roots # for inverse # RadialLayer # ############### -mutable struct RadialLayer{T1 <: Real, T2 <: AbstractVector{<:Real}} <: Bijector{1} +mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:AbstractVector{<:Real}} <: Bijector{1} α_::T1 β::T1 z_0::T2 @@ -23,8 +23,8 @@ function Base.:(==)(b1::RadialLayer, b2::RadialLayer) end function RadialLayer(dims::Int, wrapper=identity) - α_ = wrapper(randn()) - β = wrapper(randn()) + α_ = wrapper(randn(1)) + β = wrapper(randn(1)) z_0 = wrapper(randn(dims)) return RadialLayer(α_, β, z_0) end @@ -34,7 +34,7 @@ h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) # An internal version of transform that returns intermediate variables function _transform(flow::RadialLayer, z::AbstractVecOrMat) - return _radial_transform(flow.α_, flow.β, flow.z_0, z) + return _radial_transform(first(flow.α_), first(flow.β), flow.z_0, z) end function _radial_transform(α_, β, z_0, z) α = softplus(α_) # from A.2 @@ -72,8 +72,8 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) flow = ib.orig T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y)) TV = vectorof(T) - α = softplus(flow.α_) # from A.2 - β_hat = - α + softplus(flow.β) # from A.2 + α = softplus(first(flow.α_)) # from A.2 + β_hat = - α + softplus(first(flow.β)) # from A.2 # Define the objective functional f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26) # Run solver @@ -84,8 +84,8 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real}) flow = ib.orig T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y)) TM = matrixof(T) - α = softplus(flow.α_) # from A.2 - β_hat = - α + softplus(flow.β) # from A.2 + α = softplus(first(flow.α_)) # from A.2 + β_hat = - α + softplus(first(flow.β)) # from A.2 # Define the objective functional f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26) # Run solver