-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
test_search.cc
129 lines (106 loc) · 3.99 KB
/
test_search.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <stdio.h>
#include <stdlib.h> // for system
#include "../vowpalwabbit/vw.h"
#include "../vowpalwabbit/ezexample.h"
#include "../vowpalwabbit/search_sequencetask.h"
#include "libsearch.h"
// One token of a tagging example: the word text paired with its gold tag id.
struct wt
{ std::string word;  // word text; used as the single feature of the token's example
  uint32_t tag;      // gold tag id; passed to the search oracle during learning

  // Builds a (word, tag) pair.
  //   w: word text. Taken by const reference so the only copy made is the one
  //      into `word` (the by-value form copied lvalue arguments twice).
  //   t: tag id.
  wt(const std::string& w, uint32_t t) : word(w), tag(t) {}
};
// Sequence labeler expressed as a SearchTask: the input is a list of
// (word, tag) pairs and the output is the list of predicted tag ids.
class SequenceLabelerTask : public SearchTask< vector<wt>, vector<uint32_t> >
{
public:
  // The parent constructor has to run first. Afterwards the search options
  // are set and the num_actions value stored in the hook task data is
  // reported on stderr.
  SequenceLabelerTask(vw& vw_obj)
    : SearchTask< vector<wt>, vector<uint32_t> >(vw_obj)
  { sch.set_options( Search::AUTO_HAMMING_LOSS | Search::AUTO_CONDITION_FEATURES );
    HookTask::task_data* hook_data = sch.get_task_data<HookTask::task_data>();
    cerr << "num_actions = " << hook_data->num_actions << endl;
  }

  // Variant using the vanilla vw interface: every word is parsed from text
  // into an example, predicted, and then handed back via finish_example.
  void _run(Search::search& sch, vector<wt> & input_example, vector<uint32_t> & output)
  { output.clear();
    const size_t length = input_example.size();
    for (size_t pos = 0; pos < length; ++pos)
    { wt& token = input_example[pos];
      example* parsed = VW::read_example(vw_obj, "1 |w " + token.word);
      // Prediction tags start at 1 (pos+1); the condition refers to tag
      // `pos`, i.e. the prediction made for the preceding token.
      action chosen = Search::predictor(sch, pos + 1)
                        .set_input(*parsed)
                        .set_oracle(token.tag)
                        .set_condition(pos, 'p')
                        .predict();
      VW::finish_example(vw_obj, parsed);
      output.push_back(chosen);
    }
  }

  // Variant using ezexample: same prediction calls as _run, but the example
  // is assembled with the ezexample helper instead of being parsed from text.
  void _run2(Search::search& sch, vector<wt> & input_example, vector<uint32_t> & output)
  { output.clear();
    const size_t length = input_example.size();
    for (size_t pos = 0; pos < length; ++pos)
    { wt& token = input_example[pos];
      ezexample ex(&vw_obj);
      ex(vw_namespace('w'))(token.word);  // add the word as a feature in namespace 'w'
      action chosen = Search::predictor(sch, pos + 1)
                        .set_input(*ex.get())
                        .set_oracle(token.tag)
                        .set_condition(pos, 'p')
                        .predict();
      output.push_back(chosen);
    }
  }
};
// Runs the sequence-labeling demo on one toy sentence: three learn calls
// followed by one predict call, then prints the predicted tags on stderr
// next to the expected sequence.
void run(vw& vw_obj)
{ // The task is a local of this function so that the SequenceLabelerTask
  // destructor has already run when the caller gets to VW::finish;
  // the other order segfaults. No better arrangement is known yet.
  SequenceLabelerTask task(vw_obj);

  const uint32_t DET = 1, NOUN = 2, VERB = 3, ADJ = 4;
  const char* const words[] = { "the", "monster", "ate", "a", "big", "sandwich" };
  const uint32_t    tags[]  = {  DET,   NOUN,      VERB,  DET, ADJ,   NOUN      };
  const size_t num_tokens = sizeof(words) / sizeof(words[0]);

  vector<wt> data;
  for (size_t k = 0; k < num_tokens; ++k)
    data.push_back( wt(words[k], tags[k]) );

  vector<uint32_t> output;
  for (int pass = 0; pass < 3; ++pass)
    task.learn(data, output);
  task.predict(data, output);

  cerr << "output = [";
  for (vector<uint32_t>::const_iterator it = output.begin(); it != output.end(); ++it)
    cerr << " " << *it;
  cerr << " ]" << endl;
  cerr << "should have printed: 1 2 3 1 4 2" << endl;
}
void train()
{ // initialize VW as usual, but use 'hook' as the search_task
vw& vw_obj = *VW::initialize("--search 4 --quiet --search_task hook --ring_size 1024 -f my_model");
run(vw_obj);
VW::finish(vw_obj);
}
void predict()
{ vw& vw_obj = *VW::initialize("--quiet -t --ring_size 1024 -i my_model");
run(vw_obj);
VW::finish(vw_obj);
}
// Exercises BuiltInTask: trains a sequence model with the vw command-line
// binary, loads that model through the library, predicts on five identical
// examples and prints the predicted actions on stderr.
// If the training command reports failure the function logs it and returns,
// because sequence.model would then be missing or left over from an earlier run.
void test_buildin_task()
{ // train a model on the command line
  int ret = system("../vowpalwabbit/vw -k -c --holdout_off --passes 20 --search 4 --search_task sequence -d sequence.data -f sequence.model");
  if (ret != 0)
  { cerr << "../vowpalwabbit/vw failed" << endl;
    return;  // nothing trustworthy to load
  }

  // now, load that model using the BuiltInTask library
  vw& vw_obj = *VW::initialize("-t -i sequence.model --search_task hook");
  { // the task object gets its own scope so it is destroyed before VW::finish
    BuiltInTask task(vw_obj, &SequenceTask::task);

    const size_t num_examples = 5;
    vector<example*> V;
    for (size_t i = 0; i < num_examples; i++)
    { // NOTE(review): the cast only strips const from the literal; presumably
      // read_example does not write through the pointer -- confirm.
      V.push_back( VW::read_example(vw_obj, const_cast<char*>("1 | a")) );
    }

    vector<action> out;
    task.predict(V, out);

    cerr << "out (should be 1 2 3 4 3) =";
    for (size_t i = 0; i < out.size(); i++)
      cerr << " " << out[i];
    cerr << endl;

    // hand every example back to VW once prediction is done
    for (size_t i = 0; i < V.size(); i++)
      VW::finish_example(vw_obj, V[i]);
  }
  VW::finish(vw_obj);
}
// Program entry: runs the training demo, the prediction demo and the
// built-in-task demo, in that order. Command-line arguments are not used.
int main(int argc, char *argv[])
{
  (void)argc;
  (void)argv;
  train();
  predict();
  test_buildin_task();
  return 0;
}