/
Boss.java
149 lines (127 loc) · 5.06 KB
/
Boss.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package edu.cmu.tetrad.algcomparison.algorithm.oracle.cpdag;
import edu.cmu.tetrad.algcomparison.algorithm.Algorithm;
import edu.cmu.tetrad.algcomparison.algorithm.ReturnsBootstrapGraphs;
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
import edu.cmu.tetrad.algcomparison.utils.HasKnowledge;
import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper;
import edu.cmu.tetrad.annotation.AlgType;
import edu.cmu.tetrad.annotation.Bootstrapping;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.search.PermutationSearch;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.search.utils.TsUtils;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
import java.util.ArrayList;
import java.util.List;
/**
 * BOSS (Best Order Score Search).
 * <p>
 * Algorithm-comparison wrapper that builds a {@link edu.cmu.tetrad.search.Boss} from the configured
 * {@link ScoreWrapper}, runs it through a {@link PermutationSearch}, and optionally repeats the search
 * over resampled data via {@link GeneralResamplingTest}.
 *
 * @author bryanandrews
 * @author josephramsey
 */
@edu.cmu.tetrad.annotation.Algorithm(
        name = "BOSS",
        command = "boss",
        algoType = AlgType.forbid_latent_common_causes
)
@Bootstrapping
public class Boss implements Algorithm, UsesScoreWrapper, HasKnowledge,
        ReturnsBootstrapGraphs {
    private static final long serialVersionUID = 23L;

    /** Wrapper used to build the score for each search; unset until supplied by constructor or setter. */
    private ScoreWrapper score;

    /** Knowledge handed to the search; replaced by the lagged data's knowledge when a time lag is used. */
    private Knowledge knowledge = new Knowledge();

    /** Graphs from the most recent resampling run that was asked to save them. */
    private List<Graph> bootstrapGraphs = new ArrayList<>();

    /**
     * Creates an instance with no score wrapper; one must be supplied through
     * {@link #setScoreWrapper(ScoreWrapper)} before searching.
     */
    public Boss() {
        // Used in reflection; do not delete.
    }

    /**
     * Creates an instance that scores with the given wrapper.
     *
     * @param score the score wrapper to use.
     */
    public Boss(ScoreWrapper score) {
        this.score = score;
    }

    /**
     * Runs BOSS on the given data. When {@code Params.NUMBER_RESAMPLING} is below 1 a single search is
     * run directly; otherwise the search is delegated to {@link GeneralResamplingTest}.
     *
     * @param dataModel  the data to search over.
     * @param parameters the search parameters.
     * @return the graph produced by the search.
     * @throws IllegalArgumentException if a {@link DataSet} is required (time lagging or resampling)
     *                                  and {@code dataModel} is not one.
     */
    @Override
    public Graph search(DataModel dataModel, Parameters parameters) {
        if (parameters.getInt(Params.NUMBER_RESAMPLING) < 1) {
            return searchSingle(dataModel, parameters);
        }

        return searchResampled(dataModel, parameters);
    }

    /**
     * Runs one BOSS permutation search, lagging the data first when {@code Params.TIME_LAG} is positive.
     */
    private Graph searchSingle(DataModel dataModel, Parameters parameters) {
        int timeLag = parameters.getInt(Params.TIME_LAG);

        if (timeLag > 0) {
            DataSet dataSet = requireDataSet(dataModel, "time lagging");
            DataSet timeSeries = TsUtils.createLagData(dataSet, timeLag);

            if (dataSet.getName() != null) {
                timeSeries.setName(dataSet.getName());
            }

            dataModel = timeSeries;

            // The lagged data set carries its own knowledge; it replaces whatever knowledge was
            // previously set on this instance.
            this.knowledge = timeSeries.getKnowledge();
        }

        Score builtScore = this.score.getScore(dataModel, parameters);

        edu.cmu.tetrad.search.Boss boss = new edu.cmu.tetrad.search.Boss(builtScore);
        boss.setUseBes(parameters.getBoolean(Params.USE_BES));
        boss.setNumStarts(parameters.getInt(Params.NUM_STARTS));
        boss.setNumThreads(parameters.getInt(Params.NUM_THREADS));
        boss.setUseDataOrder(parameters.getBoolean(Params.USE_DATA_ORDER));
        boss.setVerbose(parameters.getBoolean(Params.VERBOSE));

        PermutationSearch permutationSearch = new PermutationSearch(boss);
        permutationSearch.setKnowledge(this.knowledge);

        Graph graph = permutationSearch.search();
        LogUtilsSearch.stampWithScores(graph, dataModel, builtScore);
        return graph;
    }

    /**
     * Runs the search through {@link GeneralResamplingTest} using a fresh {@code Boss} that shares this
     * instance's score wrapper, keeping the individual graphs when
     * {@code Params.SAVE_BOOTSTRAP_GRAPHS} is set.
     */
    private Graph searchResampled(DataModel dataModel, Parameters parameters) {
        Boss algorithm = new Boss(this.score);
        DataSet data = requireDataSet(dataModel, "resampling");

        GeneralResamplingTest search = new GeneralResamplingTest(
                data,
                algorithm,
                parameters.getInt(Params.NUMBER_RESAMPLING),
                parameters.getDouble(Params.PERCENT_RESAMPLE_SIZE),
                parameters.getBoolean(Params.RESAMPLING_WITH_REPLACEMENT),
                parameters.getInt(Params.RESAMPLING_ENSEMBLE),
                parameters.getBoolean(Params.ADD_ORIGINAL_DATASET));
        search.setKnowledge(this.knowledge);
        search.setParameters(parameters);

        Graph graph = search.search();

        // Graphs from an earlier run are kept when saving is not requested this time.
        if (parameters.getBoolean(Params.SAVE_BOOTSTRAP_GRAPHS)) {
            this.bootstrapGraphs = search.getGraphs();
        }

        return graph;
    }

    /**
     * Returns {@code dataModel} as a {@link DataSet}, failing with a descriptive message instead of a
     * bare {@link ClassCastException} when it is some other kind of data model (or null).
     */
    private static DataSet requireDataSet(DataModel dataModel, String purpose) {
        if (!(dataModel instanceof DataSet)) {
            String actual = dataModel == null ? "null" : dataModel.getClass().getName();
            throw new IllegalArgumentException(
                    "Expecting a data set for " + purpose + " but got " + actual + ".");
        }

        return (DataSet) dataModel;
    }

    /**
     * Returns a new {@link EdgeListGraph} built from the given graph.
     *
     * @param graph the graph to copy.
     * @return the comparison graph.
     */
    @Override
    public Graph getComparisonGraph(Graph graph) {
        return new EdgeListGraph(graph);
    }

    /**
     * Returns a short description that includes the score wrapper's own description. The score wrapper
     * must have been set.
     *
     * @return the description.
     */
    @Override
    public String getDescription() {
        return "BOSS (Best Order Score Search) using " + this.score.getDescription();
    }

    /**
     * Returns the data type reported by the score wrapper. The score wrapper must have been set.
     *
     * @return the data type.
     */
    @Override
    public DataType getDataType() {
        return this.score.getDataType();
    }

    /**
     * Returns the names of the parameters read directly by the single-run search.
     *
     * @return a new, mutable list of parameter names.
     */
    @Override
    public List<String> getParameters() {
        List<String> params = new ArrayList<>();

        // Parameters
        params.add(Params.USE_BES);
        params.add(Params.NUM_STARTS);
        params.add(Params.TIME_LAG);
        params.add(Params.NUM_THREADS);
        params.add(Params.USE_DATA_ORDER);
        params.add(Params.VERBOSE);

        return params;
    }

    /**
     * Returns the score wrapper currently in use.
     *
     * @return the score wrapper, or null if none has been set.
     */
    @Override
    public ScoreWrapper getScoreWrapper() {
        return this.score;
    }

    /**
     * Sets the score wrapper used to build the score for subsequent searches.
     *
     * @param score the score wrapper.
     */
    @Override
    public void setScoreWrapper(ScoreWrapper score) {
        this.score = score;
    }

    /**
     * Returns the knowledge currently held by this instance. After a time-lagged search this is the
     * knowledge taken from the lagged data set.
     *
     * @return the knowledge.
     */
    @Override
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /**
     * Sets the knowledge passed to subsequent searches.
     *
     * @param knowledge the knowledge.
     */
    @Override
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = knowledge;
    }

    /**
     * Returns the stored bootstrap graphs (the internal list itself, not a copy).
     *
     * @return the bootstrap graphs; empty until a resampling run has saved some.
     */
    @Override
    public List<Graph> getBootstrapGraphs() {
        return this.bootstrapGraphs;
    }
}