Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [py] add api for learner metrics #3022

Merged
merged 3 commits on May 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 18 additions & 0 deletions python/pylibvw.cc
Expand Up @@ -268,6 +268,22 @@ void my_run_parser(vw_ptr all)
VW::end_parser(*all);
}

// Expose the learner stack's accumulated metrics to Python.
// Returns an empty dict unless the model was created with --extra_metrics,
// since metrics are only collected when that option is supplied.
py::dict get_learner_metrics(vw_ptr all)
{
  py::dict result;

  if (all->options->was_supplied("extra_metrics"))
  {
    // Ask every learner in the stack to write its metrics into the sink.
    VW::metric_sink sink;
    all->l->persist_metrics(sink);

    // Flatten both typed metric lists into a single name -> value dict.
    for (const auto& [name, value] : sink.int_metrics_list) { result[name] = value; }
    for (const auto& [name, value] : sink.float_metrics_list) { result[name] = value; }
  }

  return result;
}

void my_finish(vw_ptr all)
{
VW::finish(*all, false); // don't delete all because python will do that for us!
Expand Down Expand Up @@ -1129,6 +1145,8 @@ BOOST_PYTHON_MODULE(pylibvw)
.def("__init__", py::make_constructor(my_initialize_with_log))
// .def("__del__", &my_finish, "deconstruct the VW object by calling finish")
.def("run_parser", &my_run_parser, "parse external data file")
.def("get_learner_metrics", &get_learner_metrics,
"get current learner stack metrics. returns empty dict if --extra_metrics was not supplied.")
.def("finish", &my_finish, "stop VW by calling finish (and, eg, write weights to disk)")
.def("save", &my_save, "save model to filename")
.def("learn", &my_learn, "given a pyvw example, learn (and predict) on that example")
Expand Down
3 changes: 3 additions & 0 deletions python/tests/test_cb.py
Expand Up @@ -67,6 +67,9 @@ def helper_getting_started_example(which_cb):
assert isinstance(choice, int), "choice should be int"
assert choice == 3, "predicted action should be 3 instead of " + str(choice)

# test that metrics is empty since "--extra_metrics filename" was not supplied
assert(len(vw.get_learner_metrics()) == 0)

vw.finish()

output = vw.get_log()
Expand Down
31 changes: 31 additions & 0 deletions python/tests/test_pyvw.py
Expand Up @@ -430,6 +430,37 @@ def test_dsjson():
for a,b in zip(pred, expected):
assert isclose(a, b)

def test_dsjson_with_metrics():
    """Learn and predict on dsjson CB examples with --extra_metrics enabled,
    then verify the learner metrics exposed by get_learner_metrics()."""
    vw = pyvw.vw('--extra_metrics metrics.json --cb_explore_adf --epsilon 0.2 --dsjson')

    # Labeled example: learn on it and check the pre-learn action scores.
    ex_l_str='{"_label_cost":-0.9,"_label_probability":0.5,"_label_Action":1,"_labelIndex":0,"o":[{"v":1.0,"EventId":"38cbf24f-70b2-4c76-aa0c-970d0c8d388e","ActionTaken":false}],"Timestamp":"2020-11-15T17:09:31.8350000Z","Version":"1","EventId":"38cbf24f-70b2-4c76-aa0c-970d0c8d388e","a":[1,2],"c":{ "GUser":{"id":"person5","major":"engineering","hobby":"hiking","favorite_character":"spock"}, "_multi": [ { "TAction":{"topic":"SkiConditions-VT"} }, { "TAction":{"topic":"HerbGarden"} } ] },"p":[0.5,0.5],"VWState":{"m":"N/A"}}\n'
    ex_l = vw.parse(ex_l_str)
    vw.learn(ex_l)
    pred = ex_l[0].get_action_scores()
    expected = [0.5, 0.5]
    assert len(pred) == len(expected)
    for a, b in zip(pred, expected):
        assert isclose(a, b)
    vw.finish_example(ex_l)

    # Unlabeled example: prediction should now favor the learned action.
    ex_p='{"_label_cost":-1.0,"_label_probability":0.5,"_label_Action":1,"_labelIndex":0,"o":[{"v":1.0,"EventId":"38cbf24f-70b2-4c76-aa0c-970d0c8d388e","ActionTaken":false}],"Timestamp":"2020-11-15T17:09:31.8350000Z","Version":"1","EventId":"38cbf24f-70b2-4c76-aa0c-970d0c8d388e","a":[1,2],"c":{ "GUser":{"id":"person5","major":"engineering","hobby":"hiking","favorite_character":"spock"}, "_multi": [ { "TAction":{"topic":"SkiConditions-VT"} }, { "TAction":{"topic":"HerbGarden"} } ] },"p":[0.5,0.5],"VWState":{"m":"N/A"}}\n'
    pred = vw.predict(ex_p)
    expected = [0.9, 0.1]
    assert len(pred) == len(expected)
    for a, b in zip(pred, expected):
        assert isclose(a, b)

    # One learn call and one predict call were issued; total_predict_calls is 2,
    # presumably because learn also runs a prediction internally — see pylibvw.
    learner_metric_dict = vw.get_learner_metrics()
    assert learner_metric_dict["total_predict_calls"] == 2
    assert learner_metric_dict["total_learn_calls"] == 1
    assert learner_metric_dict["cbea_labeled_ex"] == 1
    assert learner_metric_dict["cbea_predict_in_learn"] == 0
    assert learner_metric_dict["cbea_label_first_action"] == 1
    assert learner_metric_dict["cbea_label_not_first"] == 0
    assert pytest.approx(learner_metric_dict["cbea_sum_cost"]) == -0.9
    assert pytest.approx(learner_metric_dict["cbea_sum_cost_baseline"]) == -0.9
    # No metrics beyond the eight asserted above should be present.
    assert len(learner_metric_dict) == 8

    # Release the model (consistent with the other tests in this file); this
    # is also when --extra_metrics output is flushed to metrics.json.
    vw.finish()

def test_constructor_exception_is_safe():
try:
vw = pyvw.vw("--invalid_option")
Expand Down