/
state_space_restrictions.jl
142 lines (118 loc) · 3.92 KB
/
state_space_restrictions.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#==============================================================================#
#
# Structs responsible for determining the restrictions imposed on the
# state space of a diffusion process
#
#==============================================================================#
"""
    UnboundedStateSpace <: DiffusionStateSpace

State space with no restrictions imposed on it, i.e. the whole of ℝᵈ.
"""
struct UnboundedStateSpace <: DiffusionStateSpace
end
"""
    LowerBoundedStateSpace{T,S,N} <: DiffusionStateSpace

Lower bounds imposed on the state space of a diffusion process. `T` is a tuple
listing the indices of the coordinates that have lower-bound restrictions, `S`
is a tuple holding the corresponding lower-bound values, and `N` is the total
number of coordinates with lower-bound restrictions.

    LowerBoundedStateSpace(coords, bounds)

Build the restriction from the integer indices `coords` and the numeric bounds
`bounds`. Both may be tuples or any other iterable collection; they are stored
as tuples in the type parameters.

# Throws
- `ArgumentError` if `coords` and `bounds` differ in length, if an entry of
  `coords` is not an `Integer`, or if an entry of `bounds` is not a `Number`.
"""
struct LowerBoundedStateSpace{T,S,N} <: DiffusionStateSpace
    # Fast path: equal-length tuples whose element kinds are already enforced
    # by the signature need no further validation.
    function LowerBoundedStateSpace(
        coords::NTuple{N,Integer},
        bounds::NTuple{N,Number},
    ) where {N}
        return new{coords,bounds,N}()
    end

    # Generic path: validate arbitrary collections, then convert them to
    # tuples so they can be stored as type parameters.
    function LowerBoundedStateSpace(coords, bounds)
        N = length(bounds)
        if length(coords) != N
            throw(ArgumentError(
                "`coords` and `bounds` must have equal lengths, got " *
                "$(length(coords)) and $N",
            ))
        end
        all(c -> c isa Integer, coords) ||
            throw(ArgumentError("every entry of `coords` must be an Integer"))
        all(b -> b isa Number, bounds) ||
            throw(ArgumentError("every entry of `bounds` must be a Number"))
        return new{Tuple(coords),Tuple(bounds),N}()
    end
end
"""
    bound_info(::LowerBoundedStateSpace{T,S,N}) where {T,S,N}

Return the tuple `(T, S, N)`: the restricted coordinate indices, their lower
bounds, and the number of restricted coordinates.
"""
function bound_info(::LowerBoundedStateSpace{T,S,N}) where {T,S,N}
    return (T, S, N)
end
"""
    UpperBoundedStateSpace{T,S,N} <: DiffusionStateSpace

Upper bounds imposed on the state space of a diffusion process. `T` is a tuple
listing the indices of the coordinates that have upper-bound restrictions, `S`
is a tuple holding the corresponding upper-bound values, and `N` is the total
number of coordinates with upper-bound restrictions.

    UpperBoundedStateSpace(coords, bounds)

Accepts the same arguments, and performs the same validation and conversion to
tuples, as `LowerBoundedStateSpace(coords, bounds)`.
"""
struct UpperBoundedStateSpace{T,S,N} <: DiffusionStateSpace
    function UpperBoundedStateSpace(coords, bounds)
        # Reuse the lower-bound constructor purely for validation and tuple
        # conversion; only its type parameters are kept.
        T, S, N = bound_info(LowerBoundedStateSpace(coords, bounds))
        return new{T,S,N}()
    end
end

"""
    bound_info(::UpperBoundedStateSpace{T,S,N}) where {T,S,N}

Return the tuple `(T, S, N)`: the restricted coordinate indices, their upper
bounds, and the number of restricted coordinates.
"""
bound_info(::UpperBoundedStateSpace{T,S,N}) where {T,S,N} = T, S, N
"""
    BoundedStateSpace{L,U} <: DiffusionStateSpace

Both lower and upper bounds imposed on the state space of a diffusion process.
`L` is a `LowerBoundedStateSpace` instance holding the lower bounds and `U` is
an `UpperBoundedStateSpace` instance holding the upper bounds.

    BoundedStateSpace((coords_lower, bounds_lower), (coords_upper, bounds_upper))
    BoundedStateSpace(L::LowerBoundedStateSpace, U::UpperBoundedStateSpace)

Construct either from a `(coords, bounds)` pair for each side, or directly from
already-built lower and upper restrictions.
"""
struct BoundedStateSpace{L,U} <: DiffusionStateSpace
    function BoundedStateSpace(
        L::LowerBoundedStateSpace,
        U::UpperBoundedStateSpace,
    )
        return new{L,U}()
    end

    function BoundedStateSpace(
        (coords_lower, bounds_lower),
        (coords_upper, bounds_upper),
    )
        # Build the lower side first, then the upper side, and hand both to
        # the instance-based constructor.
        return BoundedStateSpace(
            LowerBoundedStateSpace(coords_lower, bounds_lower),
            UpperBoundedStateSpace(coords_upper, bounds_upper),
        )
    end
end
"""
    _bound_satisfied(::UnboundedStateSpace, x)

Return `true` for any `x`: an unbounded state space imposes no restrictions.
"""
@inline function _bound_satisfied(::UnboundedStateSpace, x)
    return true
end
"""
    _bound_satisfied(::LowerBoundedStateSpace{T,S,N}, x) where {T,S,N}

Return `true` when every restricted coordinate of `x` lies strictly above its
lower bound, i.e. `x[T[i]] > S[i]` for all `i` in `1:N`, and `false` otherwise.

The check is generated at compile time as a chain of nested ternaries. It
therefore short-circuits and contains no runtime loop. The restriction with
index `N` is tested first, working back towards index `1`.
"""
@generated function _bound_satisfied(
    ::LowerBoundedStateSpace{T,S,N},
    x,
) where {T,S,N}
    body = :(true)
    for k in 1:N
        # Each new check wraps everything built so far, so the check for
        # `k = N` ends up outermost and is evaluated first.
        check = :(x[T[$k]] > S[$k])
        body = :($check ? $body : false)
    end
    return body
end
"""
    _bound_satisfied(::UpperBoundedStateSpace{T,S,N}, x) where {T,S,N}

Return `true` when every restricted coordinate of `x` lies strictly below its
upper bound, i.e. `x[T[i]] < S[i]` for all `i` in `1:N`, and `false` otherwise.

The generated code is a chain of nested `cond ? rest : false` expressions. It
short-circuits, and the restriction with index `N` is tested first.
"""
@generated function _bound_satisfied(
    ::UpperBoundedStateSpace{T,S,N},
    x,
) where {T,S,N}
    expr = :(true)
    i = 1
    while i <= N
        # `Expr(:if, c, a, b)` is the AST form of the ternary `c ? a : b`.
        expr = Expr(:if, :(x[T[$i]] < S[$i]), expr, false)
        i += 1
    end
    return expr
end
"""
    _bound_satisfied(::BoundedStateSpace{L,U}, x) where {L,U}

Return `true` when `x` satisfies both the lower-bound restrictions `L` and the
upper-bound restrictions `U`. The upper bounds are only checked when the lower
bounds hold.
"""
function _bound_satisfied(::BoundedStateSpace{L,U}, x) where {L,U}
    _bound_satisfied(L, x) || return false
    return _bound_satisfied(U, x)
end
"""
    bound_info(::DiffusionProcess{T,DP,DW,SS}) where {T,DP,DW,SS}

Forward to `bound_info` of the process's state-space restriction `SS`.
"""
bound_info(::DiffusionProcess{T,DP,DW,SS}) where {T,DP,DW,SS} = bound_info(SS)
"""
    bound_satisfied(P::DiffusionProcess, x)

Check whether the state `x` satisfies the state-space restrictions of the
diffusion process `P`.
"""
function bound_satisfied(P::DiffusionProcess, x)
    return _bound_satisfied(P, x)
end
# Unwrap the state-space restriction `SS` stored in the process's type
# parameters and check `x` against it.
_bound_satisfied(::DiffusionProcess{T,DP,DW,SS}, x) where {T,DP,DW,SS} =
    _bound_satisfied(SS, x)