-
Notifications
You must be signed in to change notification settings - Fork 32
/
interface.jl
169 lines (135 loc) · 4.92 KB
/
interface.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import Base: inv, ∘
import Random: AbstractRNG
import Distributions: logpdf, rand, rand!, _rand!, _logpdf
#######################################
# AD stuff "extracted" from Turing.jl #
#######################################

"""
    ADBackend

Abstract supertype of the singleton types that tag an automatic-differentiation
backend. `ADBackend()` returns the type tagging the currently selected backend.
"""
abstract type ADBackend end
struct ForwardDiffAD <: ADBackend end
struct ReverseDiffAD <: ADBackend end
struct TrackerAD <: ADBackend end
struct ZygoteAD <: ADBackend end

# Symbol of the currently selected AD backend; changed via `setadbackend`.
const ADBACKEND = Ref(:forwarddiff)

"""
    setadbackend(backend_sym::Symbol)

Select the AD backend reported by `ADBackend()`. Supported symbols are
`:forwarddiff`, `:reversediff`, `:tracker` and `:zygote`. Returns the stored symbol.

Throws an `ArgumentError` for any other symbol and leaves the current selection
unchanged.
"""
setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
setadbackend(::Val{:forwarddiff}) = ADBACKEND[] = :forwarddiff
setadbackend(::Val{:reversediff}) = ADBACKEND[] = :reversediff
setadbackend(::Val{:tracker}) = ADBACKEND[] = :tracker
setadbackend(::Val{:zygote}) = ADBACKEND[] = :zygote
# Without this fallback, an unsupported symbol surfaces as an opaque `MethodError`
# for `setadbackend(::Val{...})`. That is inconsistent with `ADBackend(::Val)` below.
function setadbackend(::Val{S}) where {S}
    throw(ArgumentError(
        "unknown AD backend $(repr(S)); expected one of " *
        ":forwarddiff, :reversediff, :tracker, :zygote",
    ))
end

"""
    ADBackend()
    ADBackend(T::Symbol)

Return the `ADBackend` subtype (a type, not an instance) for the currently
selected backend, or for the backend named `T`. An unknown name raises an error.
"""
ADBackend() = ADBackend(ADBACKEND[])
ADBackend(T::Symbol) = ADBackend(Val(T))
ADBackend(::Val{:forwarddiff}) = ForwardDiffAD
ADBackend(::Val{:reversediff}) = ReverseDiffAD
ADBackend(::Val{:tracker}) = TrackerAD
ADBackend(::Val{:zygote}) = ZygoteAD
ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
######################
# Bijector interface #
######################

"Abstract type for a bijector."
abstract type AbstractBijector end

"Abstract type of bijectors with fixed dimensionality."
abstract type Bijector{N} <: AbstractBijector end

# The dimensionality `N` lives in the type parameter. It can therefore be read
# off either a bijector instance or a bijector type.
dimension(::Bijector{N}) where {N} = N
dimension(::Type{<:Bijector{N}}) where {N} = N

# When broadcasting, treat a bijector as a scalar (e.g. `b.(xs)` applies `b`
# to each element of `xs` instead of trying to iterate over `b`).
Base.Broadcast.broadcastable(b::Bijector) = Ref(b)
"""
    isclosedform(b::Bijector)::Bool
    isclosedform(b⁻¹::Inverse{<:Bijector})::Bool

Return `true` or `false` depending on whether or not evaluation of `b`
has a closed-form implementation.

Most bijectors have closed-form evaluations, but there are cases where
this is not the case. For example, the *inverse* evaluation of `PlanarLayer`
requires an iterative procedure to evaluate and thus is not differentiable.

The generic fallback defined here returns `true`.
"""
isclosedform(b::Bijector) = true
"""
    inv(b::Bijector)
    Inverse(b::Bijector)

A `Bijector` representing the inverse transform of `b`.

`Inverse` carries over the dimensionality `N` of the wrapped bijector. Inverting an
`Inverse` unwraps it, so `inv(inv(b))` returns `b` itself.
"""
struct Inverse{B <: Bijector, N} <: Bijector{N}
    orig::B

    Inverse(b::B) where {N, B<:Bijector{N}} = new{B, N}(b)
end

@functor Inverse

up1(b::Inverse) = Inverse(up1(b.orig))

inv(b::Bijector) = Inverse(b)
inv(ib::Inverse{<:Bijector}) = ib.orig

# Two inverses are equal exactly when the bijectors they wrap are equal. `hash` is
# defined to agree with this `==` (which `isequal` falls back to), so equal
# inverses map to the same `Dict`/`Set` slot. The `:Inverse` tag keeps
# `hash(inv(b))` distinct from `hash(b)`.
Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig
Base.hash(ib::Inverse{<:Bijector}, h::UInt) = hash(ib.orig, hash(:Inverse, h))
"""
    logabsdetjac(b::Bijector, x)
    logabsdetjac(ib::Inverse{<:Bijector}, y)

Compute `log(abs(det(J_b(x))))`, where `J_b(x)` is the Jacobian of the transform
`b` evaluated at `x`. The same applies to the inverse transform.

The default implementation for `Inverse{<:Bijector}` uses the inverse function
theorem. It returns `-logabsdetjac(ib.orig, ib(y))`, i.e. the negated
log-abs-det-Jacobian of the original bijector evaluated at `x = ib(y)`.
"""
logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y))
"""
    forward(b::Bijector, x)

Compute both the transform and `logabsdetjac` in one forward pass, and return
the named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`.

This default simply makes the two calls above. Often, however, work can be
shared between the forward pass and the `logabsdetjac` computation. A
specialised `forward` method lets a bijector take advantage of such
efficiencies, if they exist.
"""
forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x))
"""
    logabsdetjacinv(b::Bijector, y)

Shorthand for `logabsdetjac(inv(b), y)`: the log-abs-det-Jacobian of the
inverse of `b`, evaluated at `y`.
"""
function logabsdetjacinv(b::Bijector, y)
    binv = inv(b)
    return logabsdetjac(binv, y)
end
##############################
# Example bijector: Identity #
##############################

"""
    Identity{N}()

The identity bijector on `N`-dimensional inputs. Evaluation returns a copy of the
input, and the log-abs-det-Jacobian is always zero.
"""
struct Identity{N} <: Bijector{N} end

(::Identity)(x) = copy(x)

inv(b::Identity) = b

up1(::Identity{N}) where {N} = Identity{N + 1}()

logabsdetjac(::Identity{0}, x::Real) = zero(eltype(x))

# `N1` and `N2` are static type parameters, so the compiler folds this branch for
# each method instance. A plain function is as inferrable as the former
# `@generated` version, without its compile-time cost and restrictions.
function logabsdetjac(b::Identity{N1}, x::AbstractArray{T2, N2}) where {N1, T2, N2}
    if N1 == N2
        # A single input: one scalar log-abs-det-Jacobian.
        return zero(eltype(x))
    elseif N1 + 1 == N2
        # A batch of inputs stacked along the last dimension: one zero per entry.
        return zeros(eltype(x), size(x, N2))
    else
        throw(MethodError(logabsdetjac, (b, x)))
    end
end

# A batch of matrices stored as an array of matrices. The element type comes from
# the first entry when there is one. It falls back to the container's declared
# element type for an empty batch, which previously threw a `BoundsError` via `x[1]`.
function logabsdetjac(::Identity{2}, x::AbstractArray{<:AbstractMatrix})
    T = isempty(x) ? eltype(eltype(x)) : eltype(first(x))
    return zeros(T, size(x))
end
########################
# Convenient constants #
########################

# Matches any bijector whose dimensionality parameter `N` is 0 or 1
# (`Union` is order-insensitive, so this is the same type as `Union{Bijector{0}, Bijector{1}}`).
const ZeroOrOneDimBijector = Union{Bijector{1}, Bijector{0}}
######################
# Bijectors includes #
######################
# General
include("bijectors/adbijector.jl")
include("bijectors/composed.jl")
include("bijectors/stacked.jl")
# Specific
include("bijectors/exp_log.jl")
include("bijectors/logit.jl")
include("bijectors/scale.jl")
include("bijectors/shift.jl")
include("bijectors/permute.jl")
include("bijectors/simplex.jl")
include("bijectors/pd.jl")
include("bijectors/corr.jl")
include("bijectors/truncated.jl")
include("bijectors/named_bijector.jl")
# Normalizing flow related
include("bijectors/planar_layer.jl")
include("bijectors/radial_layer.jl")
include("bijectors/leaky_relu.jl")
include("bijectors/coupling.jl")
include("bijectors/normalise.jl")
##################
# Other includes #
##################
include("transformed_distribution.jl")