// EnumerationAsk.java
package aima.core.probability.bayes.exact;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import aima.core.probability.CategoricalDistribution;
import aima.core.probability.RandomVariable;
import aima.core.probability.bayes.BayesInference;
import aima.core.probability.bayes.BayesianNetwork;
import aima.core.probability.bayes.FiniteNode;
import aima.core.probability.bayes.Node;
import aima.core.probability.domain.FiniteDomain;
import aima.core.probability.proposition.AssignmentProposition;
import aima.core.probability.util.ProbabilityTable;
import aima.core.util.Util;
/**
 * Artificial Intelligence A Modern Approach (3rd Edition): Figure 14.9, page
 * 525.<br>
 * <br>
 *
 * <pre>
 * function ENUMERATION-ASK(X, e, bn) returns a distribution over X
 *   inputs: X, the query variable
 *           e, observed values for variables E
 *           bn, a Bayes net with variables {X} &cup; E &cup; Y /* Y = hidden variables *&#47;
 *
 *   Q(X) &lt;- a distribution over X, initially empty
 *   for each value x<sub>i</sub> of X do
 *       Q(x<sub>i</sub>) &lt;- ENUMERATE-ALL(bn.VARS, e<sub>x<sub>i</sub></sub>)
 *          where e<sub>x<sub>i</sub></sub> is e extended with X = x<sub>i</sub>
 *   return NORMALIZE(Q(X))
 *
 * ---------------------------------------------------------------------------------------------------
 *
 * function ENUMERATE-ALL(vars, e) returns a real number
 *   if EMPTY?(vars) then return 1.0
 *   Y &lt;- FIRST(vars)
 *   if Y has value y in e
 *       then return P(y | parents(Y)) * ENUMERATE-ALL(REST(vars), e)
 *       else return &sum;<sub>y</sub> P(y | parents(Y)) * ENUMERATE-ALL(REST(vars), e<sub>y</sub>)
 *           where e<sub>y</sub> is e extended with Y = y
 * </pre>
 *
 * Figure 14.9 The enumeration algorithm for answering queries on Bayesian
 * networks. <br>
 * <br>
 * <b>Note:</b> The implementation has been extended to handle queries with
 * multiple variables. <br>
 *
 * @author Ciaran O'Reilly
 */
public class EnumerationAsk implements BayesInference {

    public EnumerationAsk() {
    }

    // function ENUMERATION-ASK(X, e, bn) returns a distribution over X
    /**
     * The ENUMERATION-ASK algorithm in Figure 14.9 evaluates expression trees
     * (Figure 14.8) using depth-first recursion.
     *
     * @param X
     *            the query variables.
     * @param observedEvidence
     *            observed values for variables E; {@code null} is treated as
     *            "no evidence".
     * @param bn
     *            a Bayes net with variables {X} &cup; E &cup; Y /* Y = hidden
     *            variables *&#47;
     * @return a distribution over the query variables.
     * @throws IllegalArgumentException
     *             if a query or evidence variable is not part of {@code bn},
     *             or if the same variable occurs more than once among the
     *             query and evidence variables.
     */
    public CategoricalDistribution enumerationAsk(final RandomVariable[] X,
            final AssignmentProposition[] observedEvidence,
            final BayesianNetwork bn) {

        // Q(X) <- a distribution over X, initially empty
        final ProbabilityTable Q = new ProbabilityTable(X);
        final ObservedEvidence e = new ObservedEvidence(X, observedEvidence, bn);

        // for each value x_i of X do
        ProbabilityTable.Iterator di = new ProbabilityTable.Iterator() {
            // Index of the entry of Q written by the next call to iterate();
            // entries are filled in the order in which the table hands out
            // its possible worlds.
            int cnt = 0;

            // Q(x_i) <- ENUMERATE-ALL(bn.VARS, e_x_i)
            // where e_x_i is e extended with X = x_i
            public void iterate(Map<RandomVariable, Object> possibleWorld,
                    double probability) {
                // Extend the evidence with this world's assignment to every
                // query variable before enumerating.
                for (int i = 0; i < X.length; i++) {
                    e.setExtendedValue(X[i], possibleWorld.get(X[i]));
                }
                Q.setValue(cnt,
                        enumerateAll(bn.getVariablesInTopologicalOrder(), e));
                cnt++;
            }
        };
        Q.iterateOverTable(di);

        // return NORMALIZE(Q(X))
        return Q.normalize();
    }

    //
    // START-BayesInference
    /**
     * Answers the query by delegating to
     * {@link #enumerationAsk(RandomVariable[], AssignmentProposition[], BayesianNetwork)}.
     *
     * @param X
     *            the query variables.
     * @param observedEvidence
     *            observed values for variables E.
     * @param bn
     *            the Bayesian network to query.
     * @return a distribution over the query variables.
     */
    public CategoricalDistribution ask(final RandomVariable[] X,
            final AssignmentProposition[] observedEvidence,
            final BayesianNetwork bn) {
        return this.enumerationAsk(X, observedEvidence, bn);
    }

    // END-BayesInference
    //

    //
    // PROTECTED METHODS
    //
    // function ENUMERATE-ALL(vars, e) returns a real number
    /**
     * Recursively multiplies conditional probabilities along {@code vars},
     * summing over every value of each variable that has no value in
     * {@code e}.
     *
     * @param vars
     *            the variables still to be processed; the caller in this class
     *            passes the network's variables in topological order.
     * @param e
     *            the evidence, extended in place with values for the query and
     *            hidden variables as the recursion proceeds.
     * @return the resulting (un-normalized) probability; 1 when {@code vars}
     *         is empty.
     */
    protected double enumerateAll(List<RandomVariable> vars, ObservedEvidence e) {
        // if EMPTY?(vars) then return 1.0
        if (0 == vars.size()) {
            return 1;
        }
        // Y <- FIRST(vars)
        RandomVariable Y = Util.first(vars);
        // if Y has value y in e
        if (e.containsValue(Y)) {
            // then return P(y | parents(Y)) * ENUMERATE-ALL(REST(vars), e)
            return e.posteriorForParents(Y) * enumerateAll(Util.rest(vars), e);
        }

        // else return SUM_y P(y | parents(Y)) * ENUMERATE-ALL(REST(vars), e_y)
        // where e_y is e extended with Y = y
        //
        // REST(vars) does not depend on y, so it is computed once instead of
        // once per value of Y.
        final List<RandomVariable> rest = Util.rest(vars);
        double sum = 0;
        for (Object y : ((FiniteDomain) Y.getDomain()).getPossibleValues()) {
            // Re-assigning Y also marks every later hidden variable as
            // unassigned again (see ObservedEvidence.setExtendedValue).
            e.setExtendedValue(Y, y);
            sum += e.posteriorForParents(Y) * enumerateAll(rest, e);
        }

        return sum;
    }

    /**
     * Mutable evidence store used during enumeration. Every variable of the
     * network gets one slot: query variables first, then the initially
     * observed evidence variables, then the remaining (hidden) variables in
     * the order returned by
     * {@code BayesianNetwork.getVariablesInTopologicalOrder()}.
     */
    protected class ObservedEvidence {
        private BayesianNetwork bn = null;
        // Current value of each variable, indexed by the variable's slot.
        private Object[] extendedValues = null;
        // Slot index of the first hidden variable.
        private int hiddenStart = 0;
        // Highest slot index currently regarded as holding a value.
        private int extendedIdx = 0;
        private RandomVariable[] var = null;
        private Map<RandomVariable, Integer> varIdxs = new HashMap<RandomVariable, Integer>();

        /**
         * Lays out one slot per network variable and records the initially
         * observed values.
         *
         * @param queryVariables
         *            the query variables.
         * @param e
         *            the initially observed assignments; {@code null} is
         *            treated as "no evidence".
         * @param bn
         *            the Bayesian network being queried.
         * @throws IllegalArgumentException
         *             if a query or evidence variable is not part of
         *             {@code bn}, or occurs more than once among the query
         *             and evidence variables. Either case would otherwise
         *             overrun the slot arrays.
         */
        public ObservedEvidence(RandomVariable[] queryVariables,
                AssignmentProposition[] e, BayesianNetwork bn) {
            this.bn = bn;

            final List<RandomVariable> networkVars = bn
                    .getVariablesInTopologicalOrder();
            final AssignmentProposition[] evidence = (null == e) ? new AssignmentProposition[0]
                    : e;

            int maxSize = networkVars.size();
            extendedValues = new Object[maxSize];
            var = new RandomVariable[maxSize];

            // query variables go first
            int idx = 0;
            for (int i = 0; i < queryVariables.length; i++) {
                register(queryVariables[i], idx, networkVars);
                idx++;
            }
            // initial evidence variables go next
            for (int i = 0; i < evidence.length; i++) {
                register(evidence[i].getTermVariable(), idx, networkVars);
                extendedValues[idx] = evidence[i].getValue();
                idx++;
            }
            // Everything up to and including the last evidence slot counts as
            // having a value; hidden variables start right after it.
            extendedIdx = idx - 1;
            hiddenStart = idx;
            // the remaining slots are left open for the hidden variables
            for (RandomVariable rv : networkVars) {
                if (!varIdxs.containsKey(rv)) {
                    var[idx] = rv;
                    varIdxs.put(rv, idx);
                    idx++;
                }
            }
        }

        // Assigns slot idx to a query or evidence variable after checking that
        // it belongs to the network and has not been given a slot already.
        private void register(RandomVariable rv, int idx,
                List<RandomVariable> networkVars) {
            if (!networkVars.contains(rv)) {
                throw new IllegalArgumentException("Variable " + rv
                        + " is not part of the Bayesian network.");
            }
            if (varIdxs.containsKey(rv)) {
                throw new IllegalArgumentException("Variable " + rv
                        + " occurs more than once among the query and evidence variables.");
            }
            var[idx] = rv;
            varIdxs.put(rv, idx);
        }

        /**
         * Stores {@code value} for {@code rv}. Setting a hidden variable makes
         * it the last variable regarded as assigned; setting a query or
         * evidence variable marks all hidden variables as unassigned.
         *
         * @param rv
         *            the variable to assign.
         * @param value
         *            the value to record for it.
         */
        public void setExtendedValue(RandomVariable rv, Object value) {
            int idx = varIdxs.get(rv);
            extendedValues[idx] = value;
            if (idx >= hiddenStart) {
                extendedIdx = idx;
            } else {
                extendedIdx = hiddenStart - 1;
            }
        }

        /**
         * @param rv
         *            the variable to test.
         * @return true if the slot of {@code rv} lies at or before the last
         *         slot currently regarded as assigned.
         */
        public boolean containsValue(RandomVariable rv) {
            return varIdxs.get(rv) <= extendedIdx;
        }

        /**
         * Looks up the entry of {@code rv}'s conditional probability table
         * selected by the current values of its parents followed by its own
         * current value.
         *
         * @param rv
         *            the variable whose conditional probability is wanted.
         * @return the value returned by the node's CPT for those values.
         * @throws IllegalArgumentException
         *             if the node of {@code rv} is not a {@link FiniteNode}.
         */
        public double posteriorForParents(RandomVariable rv) {
            Node n = bn.getNode(rv);
            if (!(n instanceof FiniteNode)) {
                throw new IllegalArgumentException(
                        "Enumeration-Ask only works with finite Nodes.");
            }
            FiniteNode fn = (FiniteNode) n;
            // Parent values first (in the node's parent iteration order), the
            // variable's own value last.
            Object[] vals = new Object[1 + fn.getParents().size()];
            int idx = 0;
            for (Node pn : n.getParents()) {
                vals[idx] = extendedValues[varIdxs.get(pn.getRandomVariable())];
                idx++;
            }
            vals[idx] = extendedValues[varIdxs.get(rv)];
            return fn.getCPT().getValue(vals);
        }
    }

    //
    // PRIVATE METHODS
    //
}