Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added functors to make working with parameters of bijectors easier #143

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,11 +1,12 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.8.7"
version = "0.8.8"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Expand All @@ -21,6 +22,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
ArgCheck = "1, 2"
Compat = "3"
Distributions = "0.23.3"
Functors = "0.1"
MappedArrays = "0.2.2, 0.3"
NNlib = "0.6, 0.7"
Reexport = "0.2"
Expand Down
1 change: 1 addition & 0 deletions src/Bijectors.jl
Expand Up @@ -34,6 +34,7 @@ using StatsFuns
using LinearAlgebra
using MappedArrays
using Roots
using Functors
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular

Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/composed.jl
Expand Up @@ -89,6 +89,8 @@ struct Composed{A, N} <: Bijector{N}
ts::A
end

@functor Composed

Composed(bs::Tuple{Vararg{<:Bijector{N}}}) where N = Composed{typeof(bs),N}(bs)
Composed(bs::AbstractArray{<:Bijector{N}}) where N = Composed{typeof(bs),N}(bs)

Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/coupling.jl
Expand Up @@ -172,6 +172,8 @@ struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask}
mask::M
end

@functor Coupling

function Coupling(θ, n::Int)
idx = div(n, 2)
return Coupling(θ, PartitionMask(n, 1:idx))
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/leaky_relu.jl
Expand Up @@ -11,6 +11,8 @@ struct LeakyReLU{T, N} <: Bijector{N}
α::T
end

@functor LeakyReLU

LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α)
LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α)

Expand Down
9 changes: 9 additions & 0 deletions src/bijectors/named_bijector.jl
Expand Up @@ -29,6 +29,8 @@ struct NamedBijector{names, Bs<:NamedTuple{names}} <: AbstractNamedBijector
bs::Bs
end

@functor NamedBijector

names_to_bijectors(b::NamedBijector) = b.bs

@generated function (b::NamedBijector{names1})(
Expand Down Expand Up @@ -70,6 +72,9 @@ See also: [`Inverse`](@ref)
struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector
orig::B
end

@functor NamedInverse

Base.inv(nb::AbstractNamedBijector) = NamedInverse(nb)
Base.inv(ni::NamedInverse) = ni.orig

Expand All @@ -93,6 +98,8 @@ struct NamedComposition{Bs} <: AbstractNamedBijector
bs::Bs
end

@functor NamedComposition

# Essentially just copy-paste from impl of composition for 'standard' bijectors,
# with minor changes here and there.
composel(bs::AbstractNamedBijector...) = NamedComposition(bs)
Expand Down Expand Up @@ -201,6 +208,8 @@ struct NamedCoupling{target, deps, F} <: AbstractNamedBijector where {F, target}
f::F
end

@functor NamedCoupling

NamedCoupling(target, deps, f::F) where {F} = NamedCoupling{target, deps, F}(f)
function NamedCoupling(::Val{target}, ::Val{deps}, f::F) where {target, deps, F}
return NamedCoupling{target, deps, F}(f)
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/normalise.jl
Expand Up @@ -14,6 +14,9 @@ mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1}
eps :: T3
mtm :: T3 # momentum
end

@functor InvertibleBatchNorm

function Base.:(==)(b1::InvertibleBatchNorm, b2::InvertibleBatchNorm)
return b1.b == b2.b &&
b1.logs == b2.logs &&
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/planar_layer.jl
Expand Up @@ -20,6 +20,9 @@ mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractV
u::T1
b::T2
end

@functor PlanarLayer

function Base.:(==)(b1::PlanarLayer, b2::PlanarLayer)
return b1.w == b2.w && b1.u == b2.u && b1.b == b2.b
end
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/radial_layer.jl
Expand Up @@ -18,6 +18,9 @@ mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:Abstract
β::T1
z_0::T2
end

@functor RadialLayer

function Base.:(==)(b1::RadialLayer, b2::RadialLayer)
return b1.α_ == b2.α_ && b1.β == b2.β && b1.z_0 == b2.z_0
end
Expand Down
4 changes: 3 additions & 1 deletion src/bijectors/scale.jl
Expand Up @@ -2,6 +2,8 @@ struct Scale{T, N} <: Bijector{N}
a::T
end

@functor Scale

Base.:(==)(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a

function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
Expand Down Expand Up @@ -46,4 +48,4 @@ function _logabsdetjac_scale(
map(x) do x
_logabsdetjac_scale(a, x, Val(2))
end
end
end
2 changes: 2 additions & 0 deletions src/bijectors/shift.jl
Expand Up @@ -5,6 +5,8 @@ struct Shift{T, N} <: Bijector{N}
a::T
end

@functor Shift

Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a

function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/stacked.jl
Expand Up @@ -41,6 +41,9 @@ struct Stacked{Bs, N} <: Bijector{1}
return new{A, N}(bs, ranges)
end
end

@functor Stacked

Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...))
Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...))

Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/truncated.jl
Expand Up @@ -5,6 +5,9 @@ struct TruncatedBijector{N, T1, T2} <: Bijector{N}
lb::T1
ub::T2
end

@functor TruncatedBijector

TruncatedBijector(lb, ub) = TruncatedBijector{0}(lb, ub)
function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2}
return TruncatedBijector{N, T1, T2}(lb, ub)
Expand Down
3 changes: 3 additions & 0 deletions src/interface.jl
Expand Up @@ -66,6 +66,9 @@ struct Inverse{B <: Bijector, N} <: Bijector{N}

Inverse(b::B) where {N, B<:Bijector{N}} = new{B, N}(b)
end

@functor Inverse

up1(b::Inverse) = Inverse(up1(b.orig))

inv(b::Bijector) = Inverse(b)
Expand Down