-
Notifications
You must be signed in to change notification settings - Fork 32
/
truncated.jl
145 lines (139 loc) · 4.51 KB
/
truncated.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#######################################################
# Constrained to unconstrained distribution bijectors #
#######################################################
"""
    TruncatedBijector{N}(lb, ub)
    TruncatedBijector(lb, ub)

Bijector from the interval `[lb, ub]` to the real line. `N` is forwarded to the
`Bijector{N}` supertype; the two-argument constructor without `N` uses `N = 0`.
Non-finite bounds are treated as absent by the link functions.
"""
struct TruncatedBijector{N, T1, T2} <: Bijector{N}
    lb::T1  # lower bound (scalar or array)
    ub::T2  # upper bound (scalar or array)
end
@functor TruncatedBijector

TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2} = TruncatedBijector{N, T1, T2}(lb, ub)

function TruncatedBijector(lb, ub)
    return TruncatedBijector{0}(lb, ub)
end

# Same bounds, with the dimensionality parameter `N` raised by one.
function up1(tb::TruncatedBijector{N}) where {N}
    return TruncatedBijector{N + 1}(tb.lb, tb.ub)
end
# Two truncated bijectors compare equal when their lower and upper bounds are
# equal; the dimensionality parameter `N` is not compared.
function Base.:(==)(b1::TruncatedBijector, b2::TruncatedBijector)
    return b1.lb == b2.lb && b1.ub == b2.ub
end

# `isequal` and `hash` must agree (`isequal(x, y)` implies `hash(x) == hash(y)`),
# because Dict/Set lookups rely on that contract. Without these methods,
# `isequal` falls back to the `==` above, while the default `hash` depends on
# `N` and on the identity of array-valued bounds. Equal bijectors could then
# hash differently. Both methods below ignore `N`, matching `==`.
function Base.isequal(b1::TruncatedBijector, b2::TruncatedBijector)
    return isequal(b1.lb, b2.lb) && isequal(b1.ub, b2.ub)
end

function Base.hash(b::TruncatedBijector, h::UInt)
    return hash(b.ub, hash(b.lb, hash(:TruncatedBijector, h)))
end
# Scalar forward transform: clamp `x` into `[lb, ub]`, then map it to the real line.
function (tb::TruncatedBijector{0})(x::Real)
    lo = tb.lb
    hi = tb.ub
    xc = _clamp(x, lo, hi)
    return truncated_link(xc, lo, hi)
end
# Elementwise forward transform of an array; the bounds are broadcast against `x`.
function (tb::TruncatedBijector{0})(x::AbstractArray{<:Real})
    lo, hi = tb.lb, tb.ub
    return @. truncated_link(_clamp(x, lo, hi), lo, hi)
end
# Vector-variate forward transform (`N = 1`).
# If the lower bound is a vector, the bounds are paired with the entries of `x`
# through `maporbroadcast`; otherwise they are broadcast against `x`.
# Throws `ArgumentError` when the lower bound is a vector but the upper bound is not.
function (tb::TruncatedBijector{1})(x::AbstractVecOrMat{<:Real})
    lo, hi = tb.lb, tb.ub
    if lo isa AbstractVector
        # Validate with an explicit check: `@assert` is meant for internal
        # invariants and may be disabled, so it must not guard user input.
        hi isa AbstractVector || throw(ArgumentError(
            "upper bound must be an AbstractVector when the lower bound is one"))
        return maporbroadcast(x, lo, hi) do xi, l, u
            truncated_link(_clamp(xi, l, u), l, u)
        end
    else
        return truncated_link.(_clamp.(x, lo, hi), lo, hi)
    end
end
# Matrix-variate forward transform (`N = 2`).
# If the lower bound is a matrix, the bounds are paired with the entries of `x`
# through `maporbroadcast`; otherwise they are broadcast against `x`.
# Throws `ArgumentError` when the lower bound is a matrix but the upper bound is not.
function (tb::TruncatedBijector{2})(x::AbstractMatrix{<:Real})
    lo, hi = tb.lb, tb.ub
    if lo isa AbstractMatrix
        # Explicit check instead of `@assert`, which may be disabled.
        hi isa AbstractMatrix || throw(ArgumentError(
            "upper bound must be an AbstractMatrix when the lower bound is one"))
        return maporbroadcast(x, lo, hi) do xi, l, u
            truncated_link(_clamp(xi, l, u), l, u)
        end
    else
        return truncated_link.(_clamp.(x, lo, hi), lo, hi)
    end
end
# Batched matrix-variate transform: apply the bijector to each matrix in `x`.
function (tb::TruncatedBijector{2})(x::AbstractArray{<:AbstractMatrix{<:Real}})
    return map(tb, x)
end
"""
    truncated_link(x, a, b)

Map `x` from the interval `[a, b]` to the real line. The branch depends on which
bounds are finite:
- both finite: `logit((x - a) / (b - a))`
- only `a` finite: `log(x - a)`
- only `b` finite: `log(b - x)`
- neither finite: `x`, returned unchanged.
"""
function truncated_link(x::Real, a, b)
    has_lower = isfinite(a)
    has_upper = isfinite(b)
    if !has_lower
        return has_upper ? log(b - x) : x
    end
    return has_upper ? StatsFuns.logit((x - a) / (b - a)) : log(x - a)
end
# Scalar inverse transform: map `y` back into `[lb, ub]`. The result is clamped
# to the bounds, presumably to absorb floating-point round-off.
function (ib::Inverse{<:TruncatedBijector{0}})(y::Real)
    fwd = ib.orig
    lo, hi = fwd.lb, fwd.ub
    x = truncated_invlink(y, lo, hi)
    return _clamp(x, lo, hi)
end
# Elementwise inverse transform of an array, clamped into `[lb, ub]`.
function (ib::Inverse{<:TruncatedBijector{0}})(y::AbstractArray{<:Real})
    fwd = ib.orig
    lo, hi = fwd.lb, fwd.ub
    return @. _clamp(truncated_invlink(y, lo, hi), lo, hi)
end
# Vector-variate inverse transform (`N = 1`), clamped into `[lb, ub]`.
# If the lower bound is a vector, the bounds are paired with the entries of `y`
# through `maporbroadcast`; otherwise they are broadcast against `y`.
# Throws `ArgumentError` when the lower bound is a vector but the upper bound is not.
function (ib::Inverse{<:TruncatedBijector{1}})(y::AbstractVecOrMat{<:Real})
    lo, hi = ib.orig.lb, ib.orig.ub
    if lo isa AbstractVector
        # Explicit check instead of `@assert`, which may be disabled.
        hi isa AbstractVector || throw(ArgumentError(
            "upper bound must be an AbstractVector when the lower bound is one"))
        return maporbroadcast(y, lo, hi) do yi, l, u
            _clamp(truncated_invlink(yi, l, u), l, u)
        end
    else
        return _clamp.(truncated_invlink.(y, lo, hi), lo, hi)
    end
end
# Matrix-variate inverse transform (`N = 2`), clamped into `[lb, ub]`.
# If the lower bound is a matrix, the bounds are paired with the entries of `y`
# through `maporbroadcast`; otherwise they are broadcast against `y`.
# Throws `ArgumentError` when the lower bound is a matrix but the upper bound is not.
function (ib::Inverse{<:TruncatedBijector{2}})(y::AbstractMatrix{<:Real})
    lo, hi = ib.orig.lb, ib.orig.ub
    if lo isa AbstractMatrix
        # Explicit check instead of `@assert`, which may be disabled.
        hi isa AbstractMatrix || throw(ArgumentError(
            "upper bound must be an AbstractMatrix when the lower bound is one"))
        return maporbroadcast(y, lo, hi) do yi, l, u
            _clamp(truncated_invlink(yi, l, u), l, u)
        end
    else
        return _clamp.(truncated_invlink.(y, lo, hi), lo, hi)
    end
end
# Batched matrix-variate inverse: apply the inverse to each matrix in `y`.
function (ib::Inverse{<:TruncatedBijector{2}})(y::AbstractArray{<:AbstractMatrix{<:Real}})
    return map(ib, y)
end
"""
    truncated_invlink(y, a, b)

Inverse of `truncated_link`: map `y` from the real line back towards `[a, b]`.
- both bounds finite: `(b - a) * logistic(y) + a`
- only `a` finite: `exp(y) + a`
- only `b` finite: `b - exp(y)`
- neither finite: `y`, returned unchanged.
"""
function truncated_invlink(y, a, b)
    has_lower, has_upper = isfinite(a), isfinite(b)
    has_lower && has_upper && return (b - a) * StatsFuns.logistic(y) + a
    has_lower && return exp(y) + a
    has_upper && return b - exp(y)
    return y
end
# Log-abs-determinant of the scalar forward transform at `x`, with `x` first
# clamped into `[lb, ub]`.
function logabsdetjac(tb::TruncatedBijector{0}, x::Real)
    lo = tb.lb
    hi = tb.ub
    xc = _clamp(x, lo, hi)
    return truncated_logabsdetjac(xc, lo, hi)
end
# Elementwise log-abs-determinant for the univariate transform applied to an
# array; returns one value per element of `x`.
function logabsdetjac(tb::TruncatedBijector{0}, x::AbstractArray{<:Real})
    lo, hi = tb.lb, tb.ub
    return @. truncated_logabsdetjac(_clamp(x, lo, hi), lo, hi)
end
# Vector-variate log-abs-determinant: sum of the per-entry terms.
function logabsdetjac(tb::TruncatedBijector{1}, x::AbstractVector{<:Real})
    lo, hi = tb.lb, tb.ub
    terms = @. truncated_logabsdetjac(_clamp(x, lo, hi), lo, hi)
    return sum(terms)
end
# Batched vector-variate log-abs-determinant: each column of `x` is one input,
# and the result holds one summed value per column.
function logabsdetjac(tb::TruncatedBijector{1}, x::AbstractMatrix{<:Real})
    lo, hi = tb.lb, tb.ub
    terms = @. truncated_logabsdetjac(_clamp(x, lo, hi), lo, hi)
    return vec(sum(terms; dims = 1))
end
# Matrix-variate log-abs-determinant: sum of the per-entry terms over all of `x`.
function logabsdetjac(tb::TruncatedBijector{2}, x::AbstractMatrix{<:Real})
    lo, hi = tb.lb, tb.ub
    terms = @. truncated_logabsdetjac(_clamp(x, lo, hi), lo, hi)
    return sum(terms)
end
# Batched matrix-variate log-abs-determinant: one value per matrix in `x`.
function logabsdetjac(tb::TruncatedBijector{2}, x::AbstractArray{<:AbstractMatrix{<:Real}})
    return map(xi -> logabsdetjac(tb, xi), x)
end
"""
    truncated_logabsdetjac(x, a, b)

Return `log|d/dx truncated_link(x, a, b)|`, branching on which bounds are finite
in the same way as `truncated_link`:
- both finite: `log(b - a) - log(x - a) - log(b - x)`
- only `a` finite: `-log(x - a)`
- only `b` finite: `-log(b - x)`
- neither finite: `zero(x)`.
"""
function truncated_logabsdetjac(x, a, b)
    lowerbounded, upperbounded = isfinite(a), isfinite(b)
    if lowerbounded && upperbounded
        # Sum of logs rather than `-log((x - a) * (b - x) / (b - a))`. The
        # product overflows to Inf for very wide intervals (e.g. ±1e200) and
        # underflows to 0 for very narrow ones. Either way the result would be
        # ±Inf although the true value is finite. At x == a or x == b both
        # forms give +Inf.
        return log(b - a) - log(x - a) - log(b - x)
    elseif lowerbounded
        return -log(x - a)
    elseif upperbounded
        return -log(b - x)
    else
        return zero(x)
    end
end