-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
libsearch.h
126 lines (108 loc) · 4.09 KB
/
libsearch.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#ifndef LIBSEARCH_HOOKTASK_H
# define LIBSEARCH_HOOKTASK_H
# include "vw/core/parse_example.h"
# include "vw/core/parser.h"
# include "vw/core/reductions/search/search.h"
# include "vw/core/reductions/search/search_hooktask.h"
# include "vw/core/vw.h"
# include <memory>
template <class INPUT, class OUTPUT>
class SearchTask // NOLINT
{
public:
SearchTask(VW::workspace& vw_obj) : vw_obj(vw_obj), sch(*(Search::search*)vw_obj.searchstr)
{
_bogus_example = new VW::example;
VW::parsers::text::read_line(vw_obj, _bogus_example, (char*)"1 | x");
VW::setup_example(vw_obj, _bogus_example);
_trigger.push_back(_bogus_example);
HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
d->run_f = search_run_fn;
d->run_setup_f = search_setup_fn;
d->run_takedown_f = search_takedown_fn;
d->run_object = std::shared_ptr<SearchTask<INPUT, OUTPUT>>(this);
}
virtual ~SearchTask()
{
_trigger.clear(); // the individual examples get cleaned up below
delete _bogus_example;
}
virtual void _run(Search::search& sch, INPUT& input_example, OUTPUT& output) {
} // NOLINT YOU MUST DEFINE THIS FUNCTION!
void _setup(Search::search& sch, INPUT& input_example, OUTPUT& output) {} // NOLINT OPTIONAL
void _takedown(Search::search& sch, INPUT& input_example, OUTPUT& output) {} // NOLINT OPTIONAL
void learn(INPUT& input_example, OUTPUT& output)
{
_bogus_example->test_only = false;
call_vw(input_example, output);
}
void predict(INPUT& input_example, OUTPUT& output)
{
_bogus_example->test_only = true;
call_vw(input_example, output);
}
protected:
VW::workspace& vw_obj; // NOLINT
Search::search& sch; // NOLINT
private:
VW::example* _bogus_example;
VW::multi_ex _trigger;
INPUT _input;
OUTPUT _output;
void call_vw(INPUT& input_example, OUTPUT& output)
{
_input = input_example;
_output = output;
vw_obj.learn(_trigger); // this will cause our search_run_fn hook to get called
}
static void search_run_fn(Search::search& sch)
{
HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
if (d->run_object == nullptr) { THROW("error: calling search_run_fn without setting run object"); }
auto* run_obj = static_cast<SearchTask<INPUT, OUTPUT>*>(d->run_object.get());
run_obj->_run(sch, run_obj->_input, run_obj->_output);
}
static void search_setup_fn(Search::search& sch)
{
HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
if (d->run_object == nullptr) { THROW("error: calling search_setup_fn without setting run object"); }
auto* run_obj = static_cast<SearchTask<INPUT, OUTPUT>*>(d->run_object.get());
run_obj->_setup(sch, run_obj->_input, run_obj->_output);
}
static void search_takedown_fn(Search::search& sch)
{
HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
if (d->run_object == nullptr) { THROW("error: calling search_takedown_fn without setting run object"); }
auto* run_obj = static_cast<SearchTask<INPUT, OUTPUT>*>(d->run_object.get());
run_obj->_takedown(sch, run_obj->_input, run_obj->_output);
}
};
class BuiltInTask : public SearchTask<VW::multi_ex, std::vector<uint32_t>> // NOLINT
{
public:
BuiltInTask(VW::workspace& vw_obj, Search::search_task* task)
: SearchTask<VW::multi_ex, std::vector<uint32_t>>(vw_obj)
{
HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
size_t num_actions = d->num_actions;
my_task = task;
if (my_task->initialize) my_task->initialize(sch, num_actions, *d->arg);
}
~BuiltInTask()
{
if (my_task->finish) my_task->finish(sch);
}
void _run(Search::search& sch, VW::multi_ex& input_example, std::vector<uint32_t>& output)
{
my_task->run(sch, input_example);
sch.get_test_action_sequence(output);
}
protected:
Search::search_task* my_task; // NOLINT
};
#endif