-
Notifications
You must be signed in to change notification settings - Fork 14
/
Results.java
220 lines (177 loc) · 6.6 KB
/
Results.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
package cz.cvut.fel.ida.learning.results;
import cz.cvut.fel.ida.algebra.functions.aggregation.Sum;
import org.jetbrains.annotations.NotNull;
import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.functions.aggregation.Average;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.learning.results.metrics.HITS;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.exporting.Exportable;
import java.util.*;
import java.util.logging.Logger;
/**
 * Object carrying target values with outputs, and corresponding evaluation computations.
 * <p>
 * Concrete subclasses implement {@link #recalculate()} to (re)compute aggregate metrics from
 * {@link #evaluations}, {@link #betterThan(Results, Settings.ModelSelection)} for model selection,
 * and {@link #toString(Settings)} for settings-dependent reporting. Instances are normally obtained
 * through a {@link Factory} selected by {@link Factory#getFrom(Settings.ResultsType, Settings)}.
 * <p>
 * Created by gusta on 8.3.17.
 */
public abstract class Results implements Exportable<Results> {
    private static final Logger LOG = Logger.getLogger(Results.class.getName());

    transient Settings settings;

    @Deprecated
    public boolean evaluatedOnline = true;

    /**
     * The individual sample results (outputs paired with targets) held by this object.
     * Stored by reference (not copied) and sorted in place by {@link #printOutputs(boolean)}.
     * Left {@code null} by the {@link #Results(Value)} constructor.
     */
    public transient List<Result> evaluations;

    /**
     * How to aggregate individual errors of samples, e.g. mean for MSE or sum for SSE.
     * {@code null} if the settings request an unsupported aggregation, or when the instance was
     * built via {@link #Results(Value)}.
     */
    Aggregation aggregationFcn;

    /**
     * The error value, aggregated by {@link #aggregationFcn} over the individual sample errors.
     * Not computed by the constructor when the evaluation list is empty, so it may be {@code null}.
     */
    public Value error;

    /**
     * Creates results over the given sample evaluations.
     * <p>
     * NOTE: when {@code evaluations} is non-empty this calls the overridable {@link #recalculate()}
     * from the constructor. At that point, subclass field initialisers and subclass constructor
     * bodies have not yet run.
     *
     * @param evaluations sample results; the list is stored by reference, not copied
     * @param settings    global settings, used to select the error aggregation function
     */
    public Results(@NotNull List<Result> evaluations, Settings settings) {
        this.settings = settings;
        this.evaluations = evaluations;
        this.aggregationFcn = getAggregation(settings);
        if (!evaluations.isEmpty()) {
            this.recalculate();
        }
    }

    /**
     * Appends a single sample result. Does not trigger {@link #recalculate()}.
     *
     * @param result the sample result to add
     */
    public void addResult(Result result) {
        if (evaluations == null) {
            // instances created via Results(Value) start without an evaluation list
            evaluations = new ArrayList<>();
        }
        evaluations.add(result);
    }

    /**
     * Renders the aggregation function and the detailed error value.
     * Null-safe: either part is rendered as {@code "null"} when it has not been set.
     */
    @Override
    public String toString() {
        String errorString = error == null ? "null" : error.toDetailedString();
        // string concatenation renders a null aggregationFcn as "null" instead of throwing
        return aggregationFcn + "-error= " + errorString;
    }

    /**
     * Recomputes the aggregate metrics from {@link #evaluations}.
     *
     * @return an implementation-defined flag (NOTE(review): semantics are not visible in this
     * class; confirm in subclasses)
     */
    public abstract boolean recalculate();

    /**
     * Compares this result set against another one for model selection.
     *
     * @param other     the results to compare against
     * @param criterion the model-selection criterion to apply
     * @return presumably {@code true} if this result set is preferable under {@code criterion}
     * (TODO confirm against subclasses)
     */
    public abstract boolean betterThan(Results other, Settings.ModelSelection criterion);

    /**
     * Creates a results object carrying only a precomputed error. {@link #evaluations},
     * {@link #settings} and {@link #aggregationFcn} are left {@code null}.
     *
     * @param meanError the precomputed error value
     */
    protected Results(Value meanError) {
        this.error = meanError;
    }

    /**
     * Builds a line-per-sample listing of sample ids, outputs and (when present) targets.
     * <p>
     * Side effect: {@link #evaluations} is sorted in place, either by {@code position} or by the
     * natural ordering of {@link Result}.
     *
     * @param sortByIndex {@code true} to sort by sample position, {@code false} to use the natural
     *                    ordering of {@link Result}
     * @return the listing, starting with a newline
     */
    public StringBuilder printOutputs(boolean sortByIndex) {
        if (sortByIndex) {
            evaluations.sort(Comparator.comparingInt((Result r) -> r.position));
        } else {
            Collections.sort(evaluations); // sort by output (default comparator)
        }
        StringBuilder sb = new StringBuilder();
        sb.append("\n");
        for (Result evaluation : evaluations) {
            sb.append(evaluation.sampleId);
            sb.append(" , output: ").append(evaluation.getOutput().toDetailedString());
            if (evaluation.getTarget() != null) {
                sb.append(" , target: ").append(evaluation.getTarget());
            }
            sb.append("\n");
        }
        return sb;
    }

    /**
     * Maps the configured error aggregation to an {@link Aggregation} instance.
     *
     * @return {@link Average} for AVG, {@link Sum} for SUM; otherwise logs a severe message and
     * returns {@code null}
     */
    private static Aggregation getAggregation(Settings settings) {
        if (settings.errorAggregationFcn == Settings.CombinationFcn.AVG) {
            return new Average();
        }
        if (settings.errorAggregationFcn == Settings.CombinationFcn.SUM) {
            return new Sum();
        }
        LOG.severe("Unsupported errorAggregationFcn.");
        return null;
    }

    /**
     * Settings-dependent textual report of these results.
     *
     * @param settings settings controlling the report
     * @return the report text
     */
    public abstract String toString(Settings settings);

    /**
     * @return {@code true} if there are no sample evaluations (including when the evaluation list
     * was never set)
     */
    public boolean isEmpty() {
        return evaluations == null || evaluations.isEmpty();
    }

    /**
     * Creates {@link Results} of a specific kind from lists of sample results.
     *
     * @param <R> the concrete results type produced
     */
    public static abstract class Factory<R extends Results> {
        Settings settings;

        public Factory(Settings settings) {
            this.settings = settings;
        }

        /**
         * Selects the factory implementation for the requested results type.
         * <p>
         * The raw return type is kept for compatibility with existing callers.
         *
         * @param type     which kind of results to produce
         * @param settings settings passed on to the factory and the results it creates
         * @return a new factory for {@code type}
         * @throws RuntimeException if {@code type} is not handled here
         */
        public static Factory getFrom(Settings.ResultsType type, Settings settings) {
            switch (type) {
                case KBC:
                    return new KBCFactory(settings);
                case REGRESSION:
                    return new RegressionFactory(settings);
                case CLASSIFICATION:
                    return new ClassificationFactory(settings);
                case DETAILEDCLASSIFICATION:
                    return new DetailedClassificationFactory(settings);
                default:
                    throw new RuntimeException("Unknown ResultsType required: " + type);
            }
        }

        /**
         * Wraps the given sample results into a results object.
         *
         * @param outputs sample results
         * @return the created results
         */
        public abstract R createFrom(List<Result> outputs);

        /**
         * Possibly store some precalculated structures from the results (no-op by default).
         *
         * @param trainingResults results whose reusable structures may be cached by this factory
         */
        public void cacheForReuse(R trainingResults) {
        }
    }

    /**
     * Factory for {@code VoidResults}.
     * NOTE(review): not returned by {@link Factory#getFrom}, and not otherwise referenced in this
     * file.
     */
    private static class VoidFactory extends Factory<Results> {
        public VoidFactory(Settings settings) {
            super(settings);
        }

        @Override
        public Results createFrom(List<Result> outputs) {
            return new VoidResults(outputs, settings);
        }
    }

    private static class RegressionFactory extends Factory<RegressionResults> {
        public RegressionFactory(Settings settings) {
            super(settings);
        }

        @Override
        public RegressionResults createFrom(List<Result> outputs) {
            return new RegressionResults(outputs, settings);
        }
    }

    private static class ClassificationFactory extends Factory<ClassificationResults> {
        public ClassificationFactory(Settings settings) {
            super(settings);
        }

        @Override
        public ClassificationResults createFrom(List<Result> outputs) {
            return new ClassificationResults(outputs, settings);
        }
    }

    private static class DetailedClassificationFactory extends Factory<DetailedClassificationResults> {
        public DetailedClassificationFactory(Settings settings) {
            super(settings);
        }

        @Override
        public DetailedClassificationResults createFrom(List<Result> outputs) {
            return new DetailedClassificationResults(outputs, settings);
        }
    }

    /**
     * Factory for {@link KBCResults}. It keeps one accumulated {@link HITS} structure, which is
     * passed to every created result, and merges into it the HITS of each result passed to
     * {@link #cacheForReuse(KBCResults)}.
     */
    private static class KBCFactory extends Factory<KBCResults> {
        /** The accumulated HITS structure; {@code null} until the first non-null HITS is seen. */
        HITS hits;
        /**
         * HITS instances already absorbed into {@link #hits}. Tracked by identity: {@link #hits}
         * itself is mutated by {@code mergeWith} while it is a member, and the duplicate check
         * below is reference-based.
         */
        private final Set<HITS> consumed = Collections.newSetFromMap(new IdentityHashMap<>());

        public KBCFactory(Settings settings) {
            super(settings);
        }

        @Override
        public KBCResults createFrom(List<Result> outputs) {
            KBCResults kbcResults = new KBCResults(outputs, settings, hits);
            cacheForReuse(kbcResults);
            return kbcResults;
        }

        @Override
        public void cacheForReuse(KBCResults results) {
            HITS incoming = results.hits;
            if (incoming == null) {
                return; // nothing to adopt or merge; avoids mergeWith(null)
            }
            if (hits == null) {
                hits = incoming;
                consumed.add(incoming);
            } else if (hits != incoming && !consumed.contains(incoming)) {
                hits.mergeWith(incoming);
                consumed.add(incoming);
            }
        }
    }
}