diff --git a/Project.toml b/Project.toml index d13bc7d1..b5a9ce56 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.8.7" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index e435ebf7..76e1c47b 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -34,6 +34,7 @@ using StatsFuns using LinearAlgebra using MappedArrays using Roots +using Functors using Base.Iterators: drop using LinearAlgebra: AbstractTriangular diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 36142c8c..fc49c9c8 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -89,6 +89,8 @@ struct Composed{A, N} <: Bijector{N} ts::A end +@functor Composed + Composed(bs::Tuple{Vararg{<:Bijector{N}}}) where N = Composed{typeof(bs),N}(bs) Composed(bs::AbstractArray{<:Bijector{N}}) where N = Composed{typeof(bs),N}(bs) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 088caf2a..e2cf3446 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -172,6 +172,8 @@ struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask} mask::M end +@functor Coupling + function Coupling(θ, n::Int) idx = div(n, 2) return Coupling(θ, PartitionMask(n, 1:idx)) diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index ffca00c7..3839fa30 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -11,6 +11,8 @@ struct LeakyReLU{T, N} <: Bijector{N} α::T end +@functor LeakyReLU + LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α) LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index f283fcd0..64585829 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -29,6 +29,8 @@ struct NamedBijector{names, Bs<:NamedTuple{names}} <: AbstractNamedBijector bs::Bs end +@functor NamedBijector + names_to_bijectors(b::NamedBijector) = b.bs @generated function (b::NamedBijector{names1})( @@ -70,6 +72,9 @@ See also: [`Inverse`](@ref) struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector orig::B end + +@functor NamedInverse + Base.inv(nb::AbstractNamedBijector) = NamedInverse(nb) Base.inv(ni::NamedInverse) = ni.orig @@ -93,6 +98,8 @@ struct NamedComposition{Bs} <: AbstractNamedBijector bs::Bs end +@functor NamedComposition + # Essentially just copy-paste from impl of composition for 'standard' bijectors, # with minor changes here and there. composel(bs::AbstractNamedBijector...) = NamedComposition(bs) @@ -201,6 +208,8 @@ struct NamedCoupling{target, deps, F} <: AbstractNamedBijector where {F, target} f::F end +@functor NamedCoupling + NamedCoupling(target, deps, f::F) where {F} = NamedCoupling{target, deps, F}(f) function NamedCoupling(::Val{target}, ::Val{deps}, f::F) where {target, deps, F} return NamedCoupling{target, deps, F}(f) diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 300c46e2..567c0379 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -14,6 +14,9 @@ mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1} eps :: T3 mtm :: T3 # momentum end + +@functor InvertibleBatchNorm + function Base.:(==)(b1::InvertibleBatchNorm, b2::InvertibleBatchNorm) return b1.b == b2.b && b1.logs == b2.logs && diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index fbb24f41..7c37cdd6 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -20,6 +20,9 @@ mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractV u::T1 b::T2 end + +@functor PlanarLayer + function Base.:(==)(b1::PlanarLayer, b2::PlanarLayer) return b1.w == b2.w && b1.u == b2.u && b1.b == b2.b end diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index b29ac8d2..be4dbb07 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -18,6 +18,9 @@ mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:Abstract β::T1 z_0::T2 end + +@functor RadialLayer + function Base.:(==)(b1::RadialLayer, b2::RadialLayer) return b1.α_ == b2.α_ && b1.β == b2.β && b1.z_0 == b2.z_0 end diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index 5d728a05..40142600 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -2,6 +2,8 @@ struct Scale{T, N} <: Bijector{N} a::T end +@functor Scale + Base.:(==)(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D @@ -46,4 +48,4 @@ function _logabsdetjac_scale( map(x) do x _logabsdetjac_scale(a, x, Val(2)) end -end \ No newline at end of file +end diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index ab975bb5..1d55bdef 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -5,6 +5,8 @@ struct Shift{T, N} <: Bijector{N} a::T end +@functor Shift + Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 9be9ebaf..499ddd29 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -41,6 +41,9 @@ struct Stacked{Bs, N} <: Bijector{1} return new{A, N}(bs, ranges) end end + +@functor Stacked + Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...)) Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...)) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index e403f988..2c3ad7fe 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -5,6 +5,9 @@ struct TruncatedBijector{N, T1, T2} <: Bijector{N} lb::T1 ub::T2 end + +@functor TruncatedBijector + TruncatedBijector(lb, ub) = TruncatedBijector{0}(lb, ub) function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2} return TruncatedBijector{N, T1, T2}(lb, ub) diff --git a/src/interface.jl b/src/interface.jl index 66e5b307..8825cd64 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -66,6 +66,9 @@ struct Inverse{B <: Bijector, N} <: Bijector{N} Inverse(b::B) where {N, B<:Bijector{N}} = new{B, N}(b) end + +@functor Inverse + up1(b::Inverse) = Inverse(up1(b.orig)) inv(b::Bijector) = Inverse(b)