Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cats via python #2668

Merged
merged 2 commits into from Nov 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 19 additions & 1 deletion python/pylibvw.cc
Expand Up @@ -41,6 +41,7 @@ const size_t lCONTEXTUAL_BANDIT = 4;
const size_t lMAX = 5;
const size_t lCONDITIONAL_CONTEXTUAL_BANDIT = 6;
const size_t lSLATES = 7;
const size_t lCONTINUOUS = 8;

const size_t pSCALAR = 0;
const size_t pSCALARS = 1;
Expand All @@ -51,6 +52,7 @@ const size_t pMULTILABELS = 5;
const size_t pPROB = 6;
const size_t pMULTICLASSPROBS = 7;
const size_t pDECISION_SCORES = 8;
const size_t pACTION_PDF_VALUE = 9;

void dont_delete_me(void* arg) {}

Expand Down Expand Up @@ -119,6 +121,8 @@ label_parser* get_label_parser(vw* all, size_t labelType)
return &CCB::ccb_label_parser;
case lSLATES:
return &VW::slates::slates_label_parser;
case lCONTINUOUS:
return &VW::cb_continuous::the_label_parser;
default:
THROW("get_label_parser called on invalid label type");
}
Expand Down Expand Up @@ -148,6 +152,10 @@ size_t my_get_label_type(vw* all)
{
return lSLATES;
}
else if (lp->parse_label == VW::cb_continuous::the_label_parser.parse_label)
{
return lCONTINUOUS;
}
else
{
THROW("unsupported label parser used");
Expand Down Expand Up @@ -176,6 +184,8 @@ size_t my_get_prediction_type(vw_ptr all)
return pMULTICLASSPROBS;
case prediction_type_t::decision_probs:
return pDECISION_SCORES;
case prediction_type_t::action_pdf_value:
return pACTION_PDF_VALUE;
default:
THROW("unsupported prediction type used");
}
Expand Down Expand Up @@ -546,6 +556,11 @@ py::list ex_get_decision_scores(example_ptr ec)
return values;
}

py::tuple ex_get_action_pdf_value(example_ptr ec)
{
return py::make_tuple(ec->pred.pdf_value.action, ec->pred.pdf_value.pdf_value);
}

py::list ex_get_multilabel_predictions(example_ptr ec)
{
py::list values;
Expand Down Expand Up @@ -923,6 +938,7 @@ BOOST_PYTHON_MODULE(pylibvw)
.def_readonly("lConditionalContextualBandit", lCONDITIONAL_CONTEXTUAL_BANDIT,
"Conditional Contextual bandit label type -- used as input to the example() initializer")
.def_readonly("lSlates", lSLATES, "Slates label type -- used as input to the example() initializer")
.def_readonly("lContinuous", lCONTINUOUS, "Continuous label type -- used as input to the example() initializer")

.def_readonly("pSCALAR", pSCALAR, "Scalar prediction type")
.def_readonly("pSCALARS", pSCALARS, "Multiple scalar-valued prediction type")
Expand All @@ -932,7 +948,8 @@ BOOST_PYTHON_MODULE(pylibvw)
.def_readonly("pMULTILABELS", pMULTILABELS, "Multilabel prediction type")
.def_readonly("pPROB", pPROB, "Probability prediction type")
.def_readonly("pMULTICLASSPROBS", pMULTICLASSPROBS, "Multiclass probabilities prediction type")
.def_readonly("pDECISION_SCORES", pDECISION_SCORES, "Decision scores prediction type");
.def_readonly("pDECISION_SCORES", pDECISION_SCORES, "Decision scores prediction type")
.def_readonly("pACTION_PDF_VALUE", pACTION_PDF_VALUE, "Action pdf value prediction type");

// define the example class
py::class_<example, example_ptr, boost::noncopyable>("example", py::no_init)
Expand Down Expand Up @@ -1001,6 +1018,7 @@ BOOST_PYTHON_MODULE(pylibvw)
.def("get_scalars", &ex_get_scalars, "Get scalar values from example prediction")
.def("get_action_scores", &ex_get_action_scores, "Get action scores from example prediction")
.def("get_decision_scores", &ex_get_decision_scores, "Get decision scores from example prediction")
.def("get_action_pdf_value", &ex_get_action_pdf_value, "Get action and pdf value from example prediction")
.def("get_multilabel_predictions", &ex_get_multilabel_predictions,
"Get multilabel predictions from example prediction")
.def("get_costsensitive_prediction", &ex_get_costsensitive_prediction,
Expand Down
19 changes: 19 additions & 0 deletions python/tests/test_cats.py
@@ -0,0 +1,19 @@
from vowpalwabbit import pyvw

import pytest

def test_cats():
min_value = 10
max_value = 20

vw = pyvw.vw("--cats 4 --min_value " + str(min_value) + " --max_value " + str(max_value) + " --bandwidth 1")
vw_example = vw.parse("ca 15:0.657567:6.20426e-05 | f1 f2 f3 f4", pyvw.vw.lContinuous)
vw.learn(vw_example)
vw.finish_example(vw_example)

assert vw.get_prediction_type() == vw.pACTION_PDF_VALUE, "prediction_type should be action_pdf_value"

action, pdf_value = vw.predict("| f1 f2 f3 f4")
assert action >= 10
assert action <= 20
vw.finish()
10 changes: 10 additions & 0 deletions python/vowpalwabbit/pyvw.py
Expand Up @@ -98,6 +98,8 @@ def example(self, initStringOrDict=None, labelType=pylibvw.vw.lDefault):
- 4 : lCONTEXTUAL_BANDIT
- 5 : lMAX
- 6 : lCONDITIONAL_CONTEXTUAL_BANDIT
- 7 : lSLATES
- 8 : lCONTINUOUS
The integer is used to map the corresponding labelType using the
above available options

Expand Down Expand Up @@ -149,6 +151,7 @@ def get_prediction(ec, prediction_type):
- 6: pPROB
- 7: pMULTICLASSPROBS
- 8: pDECISION_SCORES
- 9: pACTION_PDF_VALUE

Examples
--------
Expand Down Expand Up @@ -176,6 +179,7 @@ def get_prediction(ec, prediction_type):
pylibvw.vw.pPROB: ec.get_prob,
pylibvw.vw.pMULTICLASSPROBS: ec.get_scalars,
pylibvw.vw.pDECISION_SCORES: ec.get_decision_scores,
pylibvw.vw.pACTION_PDF_VALUE: ec.get_action_pdf_value,
}
return switch_prediction_type[prediction_type]()

Expand Down Expand Up @@ -263,6 +267,8 @@ def parse(self, str_ex, labelType=pylibvw.vw.lDefault):
- 4 : lCONTEXTUAL_BANDIT
- 5 : lMAX
- 6 : lCONDITIONAL_CONTEXTUAL_BANDIT
- 7 : lSLATES
- 8 : lCONTINUOUS
The integer is used to map the corresponding labelType using the
above available options

Expand Down Expand Up @@ -501,6 +507,8 @@ def example(self, stringOrDict=None, labelType=pylibvw.vw.lDefault):
- 4 : lCONTEXTUAL_BANDIT
- 5 : lMAX
- 6 : lCONDITIONAL_CONTEXTUAL_BANDIT
- 7 : lSLATES
- 8 : lCONTINUOUS
The integer is used to map the corresponding labelType using the
above available options

Expand Down Expand Up @@ -1029,6 +1037,8 @@ def __init__(
- 4 : lCONTEXTUAL_BANDIT
- 5 : lMAX
- 6 : lCONDITIONAL_CONTEXTUAL_BANDIT
- 7 : lSLATES
- 8 : lCONTINUOUS
The integer is used to map the corresponding labelType using the
above available options

Expand Down