-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
test_search.cc
155 lines (141 loc) · 5.28 KB
/
test_search.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include "libsearch.h"
#include "vw/config/options_cli.h"
#include "vw/core/memory.h"
#include "vw/core/reductions/search/search_sequencetask.h"
#include "vw/core/vw.h"
#include <cstdio>
#include <cstdlib> // for system
#include <utility>
using std::cerr;
using std::endl;
// One token of an input sentence: the word's text together with the numeric
// tag that is handed to the search oracle as the reference label.
struct wt
{
  std::string word;  // surface form of the token
  uint32_t tag;      // reference label id for this token

  // The text is taken by value and moved into the member, so callers may pass
  // either a temporary or an existing string.
  wt(std::string text, uint32_t label) : word{std::move(text)}, tag{label} {}
};
// Sequence labeler built on SearchTask: the input is a vector of (word, tag)
// tokens and the output is one predicted action per token.
class SequenceLabelerTask : public SearchTask<std::vector<wt>, std::vector<uint32_t> >  // NOLINT
{
public:
  // Configures the search object for automatic Hamming loss and automatic
  // conditioning features, then logs the action count stored in the hook
  // task data.
  SequenceLabelerTask(VW::workspace& vw_obj)
      : SearchTask<std::vector<wt>, std::vector<uint32_t> >(vw_obj)  // must run parent constructor!
  {
    sch.set_options(Search::AUTO_HAMMING_LOSS | Search::AUTO_CONDITION_FEATURES);
    auto* hook_data = sch.get_task_data<HookTask::task_data>();
    cerr << "num_actions = " << hook_data->num_actions << endl;
  }

  // Labels the sentence using the vanilla vw interface: every word is turned
  // into a text example line, parsed, predicted on and then released.
  void _run(Search::search& sch, std::vector<wt>& input_example, std::vector<uint32_t>& output)
  {
    output.clear();
    // the index type is ptag (currently uint32_t) because it doubles as the prediction tag
    for (ptag pos = 0; pos < input_example.size(); ++pos)
    {
      const wt& token = input_example[pos];
      VW::example* parsed = VW::read_example(vw_obj, std::string("1 |w ") + token.word);
      // Prediction tags are 1-based (pos + 1); each prediction is conditioned
      // on tag `pos` under the name 'p'.
      const action predicted = Search::predictor(sch, pos + 1)
                                   .set_input(*parsed)
                                   .set_oracle(token.tag)
                                   .set_condition(pos, 'p')
                                   .predict();
      VW::finish_example(vw_obj, *parsed);
      output.push_back(predicted);
    }
  }

  // Variant of _run that builds each example by hand instead of parsing text:
  // the word is hashed into namespace 'w' with value 1 and the example is
  // then set up explicitly before predicting.
  void _run2(Search::search& sch, std::vector<wt>& input_example, std::vector<uint32_t>& output)  // NOLINT
  {
    auto& workspace = sch.get_vw_pointer_unsafe();
    output.clear();
    // the index type is ptag (currently uint32_t) because it doubles as the prediction tag
    for (ptag pos = 0; pos < input_example.size(); ++pos)
    {
      const wt& token = input_example[pos];
      VW::example built;
      auto namespace_hash = VW::hash_space(workspace, "w");
      built.indices.push_back('w');
      built.feature_space['w'].push_back(1.f, VW::hash_feature(workspace, token.word, namespace_hash));
      VW::setup_example(workspace, &built);
      const action predicted = Search::predictor(sch, pos + 1)
                                   .set_input(built)
                                   .set_oracle(token.tag)
                                   .set_condition(pos, 'p')
                                   .predict();
      output.push_back(predicted);
      VW::finish_example(workspace, built);
    }
  }
};
// Runs the sequence labeler on one six-word sentence: three learn passes
// followed by one predict pass, then prints the predicted tags alongside the
// expected ones.
void run(VW::workspace& vw_obj)
{
  // The task lives inside this function so that the SequenceLabelerTask
  // destructor runs *before* the caller finishes the workspace; otherwise
  // we get a segfault. It is not clear how else to handle this.
  SequenceLabelerTask task(vw_obj);

  // Tag ids used by the toy sentence.
  constexpr uint32_t DET = 1;
  constexpr uint32_t NOUN = 2;
  constexpr uint32_t VERB = 3;
  constexpr uint32_t ADJ = 4;

  std::vector<wt> data{
      {"the", DET}, {"monster", NOUN}, {"ate", VERB}, {"a", DET}, {"big", ADJ}, {"sandwich", NOUN}};
  std::vector<uint32_t> output;

  for (int pass = 0; pass < 3; ++pass) { task.learn(data, output); }
  task.predict(data, output);

  cerr << "output = [";
  for (const uint32_t predicted_tag : output) { cerr << " " << predicted_tag; }
  cerr << " ]" << endl;
  cerr << "should have printed: 1 2 3 1 4 2" << endl;
}
// Creates a workspace configured for search with the 'hook' task, runs the
// labeler demo on it and finishes the workspace. The arguments include
// `-f my_model`, the same file name predict() passes via `-i`.
void train()
{
  cerr << endl << endl << "##### train() #####" << endl << endl;

  // initialize VW as usual, but use 'hook' as the search_task
  const std::vector<std::string> args{
      "--search", "4", "--quiet", "--search_task", "hook", "--example_queue_limit", "1024", "-f", "my_model"};
  auto workspace = VW::initialize(VW::make_unique<VW::config::options_cli>(args));

  run(*workspace);
  workspace->finish();
}
// Creates a test-only (`-t`) workspace from the `my_model` file given via
// `-i`, runs the labeler demo on it and finishes the workspace.
void predict()
{
  cerr << endl << endl << "##### predict() #####" << endl << endl;

  const std::vector<std::string> args{"--quiet", "-t", "--example_queue_limit", "1024", "-i", "my_model"};
  auto workspace = VW::initialize(VW::make_unique<VW::config::options_cli>(args));

  run(*workspace);
  workspace->finish();
}
void test_buildin_task()
{
cerr << endl << endl << "##### run commandline vw #####" << endl << endl;
// train a model on the command line
int ret = system(
"../vowpalwabbit/vw -c -k --holdout_off --passes 20 --search 4 --search_task sequence -d "
"../../test/train-sets/sequence_data -f "
"sequence.model");
if (ret != 0) { cerr << "../vowpalwabbit/vw failed" << endl; }
// now, load that model using the BuiltInTask library
cerr << endl << endl << "##### test BuiltInTask #####" << endl << endl;
auto vw_obj =
VW::initialize(VW::make_unique<VW::config::options_cli>(std::vector<std::string>{"-t", "--search_task", "hook"}));
{ // create a new scope for the task object
BuiltInTask task(*vw_obj, &SequenceTask::task);
VW::multi_ex mult_ex;
mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a"));
mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a"));
mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a"));
mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a"));
mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a"));
std::vector<action> out;
task.predict(mult_ex, out);
cerr << "out (should be 1 2 3 4 3) =";
for (size_t i = 0; i < out.size(); i++) { cerr << " " << out[i]; }
cerr << endl;
for (size_t i = 0; i < mult_ex.size(); i++) { VW::finish_example(*vw_obj, *mult_ex[i]); }
}
vw_obj->finish();
}
// Entry point: runs the training demo, the prediction demo and the built-in
// task demo, in that order. Command-line arguments are not used.
int main(int argc, char* argv[])
{
  (void)argc;
  (void)argv;

  train();
  predict();
  test_buildin_task();
  return 0;
}