-
Notifications
You must be signed in to change notification settings - Fork 24
/
model.jl
166 lines (132 loc) · 6.74 KB
/
model.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
export ProbabilisticModel
export getmodel, getreturnval, getvardict, getrandomvars, getconstantvars, getdatavars, getfactornodes
import Base: push!, show, getindex, haskey, firstindex, lastindex
import ReactiveMP: getaddons, AbstractFactorNode
import GraphPPL: ModelGenerator, getmodel, getkwargs, create_model
import Rocket: getscheduler
"A structure that holds the factor graph representation of a probabilistic model."
struct ProbabilisticModel{M}
    # The wrapped model object; its concrete type is captured by the parameter `M`.
    model::M
end
"Returns the underlying factor graph model."
function getmodel(model::ProbabilisticModel)
    # Plain field access: hand back the object that was wrapped at construction time.
    return model.model
end
"Returns the value from the `return ...` operator inside the model specification."
function getreturnval(model::ProbabilisticModel)
    # Unwrap first, then delegate to the `getreturnval` method of the wrapped model.
    inner = getmodel(model)
    return getreturnval(inner)
end
"Returns the (nested) dictionary of random variables from the model specification."
function getvardict(model::ProbabilisticModel)
    # Unwrap first, then delegate to the `getvardict` method of the wrapped model.
    inner = getmodel(model)
    return getvardict(inner)
end
"Returns the random variables from the model specification."
function getrandomvars(model::ProbabilisticModel)
    # Unwrap first, then delegate to the `getrandomvars` method of the wrapped model.
    inner = getmodel(model)
    return getrandomvars(inner)
end
"Returns the data variables from the model specification."
function getdatavars(model::ProbabilisticModel)
    # Unwrap first, then delegate to the `getdatavars` method of the wrapped model.
    inner = getmodel(model)
    return getdatavars(inner)
end
"Returns the constant variables from the model specification."
function getconstantvars(model::ProbabilisticModel)
    # Unwrap first, then delegate to the `getconstantvars` method of the wrapped model.
    inner = getmodel(model)
    return getconstantvars(inner)
end
"Returns the factor nodes from the model specification."
function getfactornodes(model::ProbabilisticModel)
    # Unwrap first, then delegate to the `getfactornodes` method of the wrapped model.
    inner = getmodel(model)
    return getfactornodes(inner)
end
# Forward a `getvarref` lookup to the wrapped model; `label` is passed through untouched.
function getvarref(model::ProbabilisticModel, label)
    inner = getmodel(model)
    return getvarref(inner, label)
end
"""
ConditionedModelGenerator(generator, conditioned_on)
Accepts a model generator and data to condition on.
The `generator` must be a `GraphPPL.ModelGenerator` object.
The `conditioned_on` must be a named tuple or a dictionary with keys corresponding to the names of the input arguments in the model.
"""
struct ConditionedModelGenerator{G, D}
    # The model generator being conditioned (stored exactly as given).
    generator::G
    # The data to condition on (stored exactly as given).
    conditioned_on::D
end

# Field accessors for `ConditionedModelGenerator`.
function getgenerator(generator::ConditionedModelGenerator)
    return generator.generator
end

function getconditioned_on(generator::ConditionedModelGenerator)
    return generator.conditioned_on
end
"""
condition_on(generator::ModelGenerator; kwargs...)
A function that creates a `ConditionedModelGenerator` object from `GraphPPL.ModelGenerator`.
The `|` operator can be used as a shorthand for this function.
```jldoctest
julia> using RxInfer
julia> @model function beta_bernoulli(y, a, b)
θ ~ Beta(a, b)
y .~ Bernoulli(θ)
end
julia> conditioned_model = beta_bernoulli(a = 1.0, b = 2.0) | (y = [ 1.0, 0.0, 1.0 ], )
beta_bernoulli(a = 1.0, b = 2.0) conditioned on:
y = [1.0, 0.0, 1.0]
julia> RxInfer.create_model(conditioned_model) isa RxInfer.ProbabilisticModel
true
```
"""
function condition_on(generator::ModelGenerator; kwargs...)
    # The keyword arguments are gathered into a `NamedTuple`, which becomes the
    # conditioning data of the resulting `ConditionedModelGenerator`.
    data = NamedTuple(kwargs)
    return ConditionedModelGenerator(generator, data)
end
# Positional variant: `data` is stored as given, with no conversion or validation here.
condition_on(generator::ModelGenerator, data) = ConditionedModelGenerator(generator, data)
"""
An alias for [`RxInfer.condition_on`](@ref).
"""
Base.:(|)(generator::ModelGenerator, data) =
    condition_on(generator, data) # `generator | data` is the positional `condition_on` call
# Prints `<model>(<key> = <value>, ...)` for the wrapped generator. When conditioning
# data is present, this is followed by ` conditioned on: ` and one `<key> = <value>`
# line per key of that data.
function Base.show(io::IO, generator::ConditionedModelGenerator)
    base = getgenerator(generator)
    data = getconditioned_on(generator)

    print(io, getmodel(base), "(")
    # Each keyword argument is turned into a string first and the pieces are joined
    # before anything is written to `io`.
    arguments = (string(entry[1], " = ", entry[2]) for entry in getkwargs(base))
    rendered = join(arguments, ", ")
    print(io, rendered)
    print(io, ")")

    # Nothing more to print when there is no conditioning data.
    isnothing(data) && return nothing

    println(io, " conditioned on: ")
    for key in keys(data)
        println(io, " ", key, " = ", data[key])
    end
    return nothing
end
"""
create_model(generator::ConditionedModelGenerator)
Materializes the model specification conditioned on some data into a corresponding factor graph representation.
Returns [`ProbabilisticModel`](@ref).
"""
function create_model(generator::ConditionedModelGenerator)
    # Split the conditioned generator into its two parts; dispatch on the type of the
    # data selects the matching `__infer_create_factor_graph_model` method.
    base = getgenerator(generator)
    data = getconditioned_on(generator)
    return __infer_create_factor_graph_model(base, data)
end
# Fallback for conditioning data whose type has no dedicated method: always throws,
# reporting the type that was received.
function __infer_create_factor_graph_model(::ModelGenerator, conditioned_on)
    message = "Cannot create a factor graph model from a `ModelGenerator` object. The `data` object must be a `NamedTuple` or a `Dict`. Got `$(typeof(conditioned_on))` instead."
    return error(message)
end
# Handles static conditioning data given as a `NamedTuple` or a `Dict`.
function __infer_create_factor_graph_model(generator::ModelGenerator, conditioned_on::Union{NamedTuple, Dict})
    # A `NamedTuple` passes through unchanged; a `Dict` (unordered by default) is
    # converted so that the keys have one fixed order for the rest of this function.
    data = NamedTuple(conditioned_on)::NamedTuple
    names = keys(data)

    # The callback handed to `create_model` creates one data interface per key and
    # returns them as a `NamedTuple` with the same names.
    graph = create_model(generator) do graphmodel, context
        handles = map(name -> __infer_create_data_interface(graphmodel, context, name, data[name]), names)
        return NamedTuple{names}(handles)
    end

    return ProbabilisticModel(graph)
end
"""
An object that is used to condition on unknown data. That may be necessary to create a model from a `ModelGenerator` object
for which data is not known at the time of the model creation.
"""
struct DeferredDataHandler
    # Intentionally field-less: the type itself is the marker.
end

# Fixed textual representation. NOTE(review): the spelling "deffered" is kept as-is,
# since it is runtime output that other code or doctests may match against.
Base.show(io::IO, ::DeferredDataHandler) = print(io, "[ deffered data ]")
# Data interface for a `DeferredDataHandler` entry. The data is not known at model
# creation time, so the variable is created with an empty `GraphPPL.LazyIndex()`.
function __infer_create_data_interface(model, context, key::Symbol, ::DeferredDataHandler)
    options = GraphPPL.NodeCreationOptions(kind = :data, factorized = true)
    placeholder = GraphPPL.LazyIndex()
    return GraphPPL.getorcreate!(model, context, options, key, placeholder)
end
# Data interface for every other kind of entry. The data is known at model creation
# time, so it is wrapped in a `GraphPPL.LazyIndex(data)` for the new variable.
function __infer_create_data_interface(model, context, key::Symbol, data)
    options = GraphPPL.NodeCreationOptions(kind = :data, factorized = true)
    wrapped = GraphPPL.LazyIndex(data)
    return GraphPPL.getorcreate!(model, context, options, key, wrapped)
end
"""
    merge_data_handlers(data, newdata)

Merge two collections of data handlers, each given as a `Dict` or a `NamedTuple`.

Entries of `newdata` take precedence over entries of `data` that share a key
(the semantics of `Base.merge`). The result is a `NamedTuple` only when both
inputs are `NamedTuple`s; any combination involving a `Dict` produces a `Dict`.
"""
merge_data_handlers(data::Dict, newdata::Dict) = merge(data, newdata)
merge_data_handlers(data::NamedTuple, newdata::NamedTuple) = merge(data, newdata)
# Base defines no `convert(Dict, ::NamedTuple)` method, so the mixed cases build the
# dictionary explicitly from the `key => value` pairs of the named tuple.
merge_data_handlers(data::Dict, newdata::NamedTuple) = merge(data, Dict(pairs(newdata)))
merge_data_handlers(data::NamedTuple, newdata::Dict) = merge(Dict(pairs(data)), newdata)
# Builds a `NamedTuple` whose names are `symbols` and whose every value is a
# `DeferredDataHandler()`.
function create_deffered_data_handlers(symbols::NTuple{N, Symbol}) where {N}
    handlers = ntuple(_ -> DeferredDataHandler(), Val(N))
    return NamedTuple{symbols}(handlers)
end
# Builds a `Dict{Symbol, DeferredDataHandler}` with one `DeferredDataHandler()` entry
# per element of `symbols`.
function create_deffered_data_handlers(symbols::AbstractVector{Symbol})
    handlers = Dict{Symbol, DeferredDataHandler}()
    for name in symbols
        handlers[name] = DeferredDataHandler()
    end
    return handlers
end