Skip to content

Commit

Permalink
Merge 9dd1869 into 3387a40
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 23, 2020
2 parents 3387a40 + 9dd1869 commit 05f13cf
Show file tree
Hide file tree
Showing 14 changed files with 40 additions and 1 deletion.
1 change: 1 addition & 0 deletions Project.toml
Expand Up @@ -6,6 +6,7 @@ version = "0.8.7"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Expand Down
1 change: 1 addition & 0 deletions src/Bijectors.jl
Expand Up @@ -34,6 +34,7 @@ using StatsFuns
using LinearAlgebra
using MappedArrays
using Roots
using Functors
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular

Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/composed.jl
Expand Up @@ -89,6 +89,8 @@ struct Composed{A, N} <: Bijector{N}
ts::A
end

@functor Composed

Composed(bs::Tuple{Vararg{<:Bijector{N}}}) where N = Composed{typeof(bs),N}(bs)
Composed(bs::AbstractArray{<:Bijector{N}}) where N = Composed{typeof(bs),N}(bs)

Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/coupling.jl
Expand Up @@ -172,6 +172,8 @@ struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask}
mask::M
end

@functor Coupling

function Coupling(θ, n::Int)
idx = div(n, 2)
return Coupling(θ, PartitionMask(n, 1:idx))
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/leaky_relu.jl
Expand Up @@ -11,6 +11,8 @@ struct LeakyReLU{T, N} <: Bijector{N}
α::T
end

@functor LeakyReLU

LeakyReLU::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α)
LeakyReLU::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α)

Expand Down
9 changes: 9 additions & 0 deletions src/bijectors/named_bijector.jl
Expand Up @@ -29,6 +29,8 @@ struct NamedBijector{names, Bs<:NamedTuple{names}} <: AbstractNamedBijector
bs::Bs
end

@functor NamedBijector

names_to_bijectors(b::NamedBijector) = b.bs

@generated function (b::NamedBijector{names1})(
Expand Down Expand Up @@ -70,6 +72,9 @@ See also: [`Inverse`](@ref)
struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector
orig::B
end

@functor NamedInverse

Base.inv(nb::AbstractNamedBijector) = NamedInverse(nb)
Base.inv(ni::NamedInverse) = ni.orig

Expand All @@ -93,6 +98,8 @@ struct NamedComposition{Bs} <: AbstractNamedBijector
bs::Bs
end

@functor NamedComposition

# Essentially just copy-paste from impl of composition for 'standard' bijectors,
# with minor changes here and there.
composel(bs::AbstractNamedBijector...) = NamedComposition(bs)
Expand Down Expand Up @@ -201,6 +208,8 @@ struct NamedCoupling{target, deps, F} <: AbstractNamedBijector where {F, target}
f::F
end

@functor NamedCoupling

NamedCoupling(target, deps, f::F) where {F} = NamedCoupling{target, deps, F}(f)
function NamedCoupling(::Val{target}, ::Val{deps}, f::F) where {target, deps, F}
return NamedCoupling{target, deps, F}(f)
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/normalise.jl
Expand Up @@ -14,6 +14,9 @@ mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1}
eps :: T3
mtm :: T3 # momentum
end

@functor InvertibleBatchNorm

function Base.:(==)(b1::InvertibleBatchNorm, b2::InvertibleBatchNorm)
return b1.b == b2.b &&
b1.logs == b2.logs &&
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/planar_layer.jl
Expand Up @@ -20,6 +20,9 @@ mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractV
u::T1
b::T2
end

@functor PlanarLayer

function Base.:(==)(b1::PlanarLayer, b2::PlanarLayer)
return b1.w == b2.w && b1.u == b2.u && b1.b == b2.b
end
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/radial_layer.jl
Expand Up @@ -18,6 +18,9 @@ mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:Abstract
β::T1
z_0::T2
end

@functor RadialLayer

function Base.:(==)(b1::RadialLayer, b2::RadialLayer)
return b1.α_ == b2.α_ && b1.β == b2.β && b1.z_0 == b2.z_0
end
Expand Down
4 changes: 3 additions & 1 deletion src/bijectors/scale.jl
Expand Up @@ -2,6 +2,8 @@ struct Scale{T, N} <: Bijector{N}
a::T
end

@functor Scale

Base.:(==)(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a

function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
Expand Down Expand Up @@ -46,4 +48,4 @@ function _logabsdetjac_scale(
map(x) do x
_logabsdetjac_scale(a, x, Val(2))
end
end
end
2 changes: 2 additions & 0 deletions src/bijectors/shift.jl
Expand Up @@ -5,6 +5,8 @@ struct Shift{T, N} <: Bijector{N}
a::T
end

@functor Shift

Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a

function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/stacked.jl
Expand Up @@ -41,6 +41,9 @@ struct Stacked{Bs, N} <: Bijector{1}
return new{A, N}(bs, ranges)
end
end

@functor Stacked

Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...))
Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...))

Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/truncated.jl
Expand Up @@ -5,6 +5,9 @@ struct TruncatedBijector{N, T1, T2} <: Bijector{N}
lb::T1
ub::T2
end

@functor TruncatedBijector

TruncatedBijector(lb, ub) = TruncatedBijector{0}(lb, ub)
function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2}
return TruncatedBijector{N, T1, T2}(lb, ub)
Expand Down
3 changes: 3 additions & 0 deletions src/interface.jl
Expand Up @@ -66,6 +66,9 @@ struct Inverse{B <: Bijector, N} <: Bijector{N}

Inverse(b::B) where {N, B<:Bijector{N}} = new{B, N}(b)
end

@functor Inverse

up1(b::Inverse) = Inverse(up1(b.orig))

inv(b::Bijector) = Inverse(b)
Expand Down

0 comments on commit 05f13cf

Please sign in to comment.