-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
VWActionScoresLearnerTest.java
128 lines (112 loc) · 4.04 KB
/
VWActionScoresLearnerTest.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package vowpalWabbit.learner;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import vowpalWabbit.VWTestHelper;
import vowpalWabbit.responses.ActionScores;
import java.io.IOException;
import static org.junit.Assert.assertArrayEquals;
/**
 * Tests for {@link VWActionScoresLearner}: trains VW on multiline examples and checks the
 * returned {@link ActionScores} against hard-coded expected values.
 *
 * <p>Covers cost-sensitive label-dependent features ({@code --csoaa_ldf mc}) and
 * contextual-bandit ADF ({@code --cb_adf}), the latter with and without {@code --rank_all},
 * including a save ({@code -f}) / reload ({@code -i}) round trip.
 *
 * <p>Created by jmorra on 10/2/15.
 */
public class VWActionScoresLearnerTest extends VWTestHelper {
    @Rule
    public TemporaryFolder temporaryFolder = new TemporaryFolder();

    /**
     * Trains a {@code --csoaa_ldf mc --csoaa_rank} learner for 100 passes over three multiline
     * examples and compares the action scores produced on the final pass.
     *
     * <p>The examples use different subsets of the labels 1-3, so the number of scored actions
     * differs per example (3, 2 and 2).
     */
    @Test
    public void testCSOAA() throws IOException {
        String[][] data = new String[][]{
                new String[]{
                        "1:1.0 | a_1 b_1 c_1",
                        "2:0.0 | a_2 b_2 c_2",
                        "3:2.0 | a_3 b_3 c_3"
                },
                new String[]{
                        "1:1.0 | b_1 c_1 d_1",
                        "2:0.0 | b_2 c_2 d_2"
                },
                new String[]{
                        "1:1.0 | a_1 b_1 c_1",
                        "3:2.0 | a_3 b_3 c_3"
                }
        };
        ActionScores[] pred = new ActionScores[data.length];
        // try-with-resources guarantees the learner is closed even if learn() throws mid-training.
        try (VWActionScoresLearner vw = VWLearners.create("--csoaa_ldf mc --quiet --csoaa_rank")) {
            for (int pass = 0; pass < 100; ++pass) {
                for (int i = 0; i < data.length; ++i) {
                    // Overwritten every pass: only the predictions of the last pass are checked.
                    pred[i] = vw.learn(data[i]);
                }
            }
        }
        ActionScores[] expected = new ActionScores[]{
                actionScores(
                        actionScore(1, -1.0573887f),
                        actionScore(0, -0.033036415f),
                        actionScore(2, 1.0063205f)
                ),
                actionScores(
                        actionScore(1, -1.0342788f),
                        actionScore(0, 0.9994181f)
                ),
                actionScores(
                        actionScore(0, 0.033397526f),
                        actionScore(1, 1.0227613f)
                )
        };
        assertArrayEquals(expected, pred);
    }

    /** Contextual-bandit ADF without {@code --rank_all}. */
    @Test
    public void testCBADF() throws IOException {
        testCBADF(false);
    }

    /** Contextual-bandit ADF with {@code --rank_all}; expected scores are the same as without. */
    @Test
    public void testCBADFWithRank() throws IOException {
        testCBADF(true);
    }

    /**
     * Trains a {@code --cb_adf} learner on four multiline examples, writing the model to a
     * temporary file, then reloads that model in test-only mode and predicts the first example.
     *
     * @param withRank whether to append {@code --rank_all} to the training command line
     * @throws IOException if the temporary model file cannot be created or a learner fails
     */
    private void testCBADF(boolean withRank) throws IOException {
        // Examples 2 and 4 carry a "shared" feature line; examples 1 and 3 do not.
        String[][] cbADFTrain = new String[][]{
                new String[]{"| a:1 b:0.5","0:0.1:0.75 | a:0.5 b:1 c:2"},
                new String[]{"shared | s_1 s_2","0:1.0:0.5 | a:1 b:1 c:1","| a:0.5 b:2 c:1"},
                new String[]{"| a:1 b:0.5","0:0.1:0.75 | a:0.5 b:1 c:2"},
                new String[]{"shared | s_1 s_2","0:1.0:0.5 | a:1 b:1 c:1","| a:0.5 b:2 c:1"}
        };
        // NOTE(review): the path is concatenated unquoted into the VW command line; a temp
        // directory containing whitespace would presumably break argument parsing — confirm.
        String model = temporaryFolder.newFile().getAbsolutePath();
        String cli = "--quiet --cb_adf -f " + model;
        if (withRank) {
            cli += " --rank_all";
        }

        ActionScores[] trainPreds = new ActionScores[cbADFTrain.length];
        // The training learner is closed (end of this try) before the model is reloaded with -i
        // below; presumably VW finishes writing the -f file on close.
        try (VWActionScoresLearner vw = VWLearners.create(cli)) {
            for (int i = 0; i < cbADFTrain.length; ++i) {
                trainPreds[i] = vw.learn(cbADFTrain[i]);
            }
        }
        ActionScores[] expectedTrainPreds = new ActionScores[]{
                actionScores(
                        actionScore(0, 0),
                        actionScore(1, 0)
                ),
                actionScores(
                        actionScore(0, 0.11246802f),
                        actionScore(1, 0.11246802f)
                ),
                actionScores(
                        actionScore(0, 0.3682006f),
                        actionScore(1, 0.5136312f)
                ),
                actionScores(
                        actionScore(0, 0.58848584f),
                        actionScore(1, 0.6244352f)
                )
        };
        assertArrayEquals(expectedTrainPreds, trainPreds);

        ActionScores[] testPreds;
        // -t: test-only mode; -i: load the model saved during training.
        try (VWActionScoresLearner vw = VWLearners.create("--quiet -t -i " + model)) {
            testPreds = new ActionScores[]{vw.predict(cbADFTrain[0])};
        }
        ActionScores[] expectedTestPreds = new ActionScores[]{
                actionScores(
                        actionScore(0, 0.39904374f),
                        actionScore(1, 0.49083984f)
                )
        };
        assertArrayEquals(expectedTestPreds, testPreds);
    }
}