Skip to content

Commit

Permalink
refactor: cleanup cb.h header (#4427)
Browse files Browse the repository at this point in the history
* refactor: cleanup cb.h header

* fix inline

* deprecated usage

* csharp fix

* format
  • Loading branch information
jackgerrits committed Jan 5, 2023
1 parent b38e95e commit b5fe70d
Show file tree
Hide file tree
Showing 56 changed files with 288 additions and 269 deletions.
5 changes: 3 additions & 2 deletions cs/cli/vw_example.cpp
Expand Up @@ -2,6 +2,7 @@
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "vw/core/cb.h"
#define NOMINMAX

#include "vowpalwabbit.h"
Expand Down Expand Up @@ -75,9 +76,9 @@ ILabel^ VowpalWabbitExample::Label::get()
auto lp = m_owner->Native->m_vw->example_parser->lbl_parser;
if (!memcmp(&lp, &VW::simple_label_parser_global, sizeof(lp)))
label = gcnew SimpleLabel();
else if (!memcmp(&lp, &CB::cb_label, sizeof(lp)))
else if (!memcmp(&lp, &VW::cb_label_parser_global, sizeof(lp)))
label = gcnew ContextualBanditLabel();
else if (!memcmp(&lp, &CB_EVAL::cb_eval, sizeof(lp)))
else if (!memcmp(&lp, &VW::cb_eval_label_parser_global, sizeof(lp)))
label = gcnew SimpleLabel();
else if (!memcmp(&lp, &VW::cs_label_parser_global, sizeof(lp)))
label = gcnew SimpleLabel();
Expand Down
12 changes: 6 additions & 6 deletions cs/cli/vw_label.h
Expand Up @@ -100,10 +100,10 @@ public ref class ContextualBanditLabel sealed : ILabel

virtual void ReadFromExample(example* ex)
{
CB::label* ld = &ex->l.cb;
VW::cb_label* ld = &ex->l.cb;
if (ld->costs.size() > 0)
{
CB::cb_class& f = ld->costs[0];
VW::cb_class& f = ld->costs[0];

m_action = f.action;
m_cost = f.cost;
Expand All @@ -113,8 +113,8 @@ public ref class ContextualBanditLabel sealed : ILabel

virtual void UpdateExample(VW::workspace* vw, example* ex)
{
CB::label* ld = &ex->l.cb;
CB::cb_class f;
VW::cb_label* ld = &ex->l.cb;
VW::cb_class f;

f.partial_prediction = 0.;
f.action = m_action;
Expand Down Expand Up @@ -153,8 +153,8 @@ public ref class SharedLabel sealed : ILabel

virtual void UpdateExample(VW::workspace* vw, example* ex)
{
CB::label* ld = &ex->l.cb;
CB::cb_class f;
VW::cb_label* ld = &ex->l.cb;
VW::cb_class f;

f.partial_prediction = 0.;
f.action = m_action;
Expand Down
2 changes: 1 addition & 1 deletion cs/vw.net.native/vw.net.cbutil.cc
Expand Up @@ -2,6 +2,6 @@

API float GetCbUnbiasedCost(uint32_t actionObservered, uint32_t actionTaken, float cost, float probability)
{
CB::cb_class observation(cost, actionObservered, probability);
VW::cb_class observation(cost, actionObservered, probability);
return CB_ALGS::get_cost_estimate(observation, actionTaken);
}
16 changes: 8 additions & 8 deletions cs/vw.net.native/vw.net.labels.cc
Expand Up @@ -31,16 +31,16 @@ API void SimpleLabelUpdateExample(vw_net_native::workspace_context* workspace, V
VW::count_label(*workspace->vw->sd, ld->label);
}

API CB::cb_class* CbLabelReadFromExampleDangerous(VW::example* ex)
API VW::cb_class* CbLabelReadFromExampleDangerous(VW::example* ex)
{
CB::label* ld = &ex->l.cb;
VW::cb_label* ld = &ex->l.cb;

return (ld->costs.size() > 0) ? &ld->costs[0] : nullptr;
}

API void CbLabelUpdateExample(VW::example* ex, const CB::cb_class* f)
API void CbLabelUpdateExample(VW::example* ex, const VW::cb_class* f)
{
CB::label* ld = &ex->l.cb;
VW::cb_label* ld = &ex->l.cb;

// TODO: Should we be clearing the costs here?
// ld->costs.clear();
Expand Down Expand Up @@ -86,8 +86,8 @@ API char* ComputeDiffDescriptionSimpleLabels(VW::example* ex1, VW::example* ex2)

API char* ComputeDiffDescriptionCbLabels(VW::example* ex1, VW::example* ex2)
{
CB::label ld1 = ex1->l.cb;
CB::label ld2 = ex2->l.cb;
VW::cb_label ld1 = ex1->l.cb;
VW::cb_label ld2 = ex2->l.cb;

std::stringstream sstream;
if (ld1.costs.size() != ld2.costs.size())
Expand All @@ -99,8 +99,8 @@ API char* ComputeDiffDescriptionCbLabels(VW::example* ex1, VW::example* ex2)
{
for (size_t i = 0; i < ld1.costs.size(); i++)
{
CB::cb_class c1 = ld1.costs[i];
CB::cb_class c2 = ld2.costs[i];
VW::cb_class c1 = ld1.costs[i];
VW::cb_class c2 = ld2.costs[i];

if (c1.action != c2.action)
{
Expand Down
4 changes: 2 additions & 2 deletions cs/vw.net.native/vw.net.labels.h
Expand Up @@ -14,8 +14,8 @@ extern "C"
API void SimpleLabelUpdateExample(vw_net_native::workspace_context* workspace, VW::example* ex, float label,
float* maybe_weight, float* maybe_initial);

API CB::cb_class* CbLabelReadFromExampleDangerous(VW::example* ex);
API void CbLabelUpdateExample(VW::example* ex, const CB::cb_class* f);
API VW::cb_class* CbLabelReadFromExampleDangerous(VW::example* ex);
API void CbLabelUpdateExample(VW::example* ex, const VW::cb_class* f);

API vw_net_native::ERROR_CODE StringLabelParseAndUpdateExample(vw_net_native::workspace_context* workspace,
VW::example* ex, const char* label, size_t label_len, VW::experimental::api_status* status = nullptr);
Expand Down
14 changes: 7 additions & 7 deletions java/src/main/c++/jni_spark_vw.cc
Expand Up @@ -710,8 +710,8 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_setContex

try
{
CB::label* ld = &ex->l.cb;
CB::cb_class f;
VW::cb_label* ld = &ex->l.cb;
VW::cb_class f;

f.action = (uint32_t)action;
f.cost = (float)cost;
Expand All @@ -732,8 +732,8 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_setShared
try
{
// https://github.com/VowpalWabbit/vowpal_wabbit/blob/master/vowpalwabbit/parse_example_json.h#L437
CB::label* ld = &ex->l.cb;
CB::cb_class f;
VW::cb_label* ld = &ex->l.cb;
VW::cb_class f;

f.partial_prediction = 0.;
f.action = (uint32_t)VW::uniform_hash("shared", 6 /*length of string*/, 0);
Expand Down Expand Up @@ -889,16 +889,16 @@ JNIEXPORT jstring JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_toStri
const auto& red_fts = ex->ex_reduction_features.template get<VW::simple_label_reduction_features>();
ostr << "simple " << ld->label << ":" << red_fts.weight << ":" << red_fts.initial;
}
else if (!memcmp(&lp, &CB::cb_label, sizeof(lp)))
else if (!memcmp(&lp, &VW::cb_label_parser_global, sizeof(lp)))
{
CB::label* ld = &ex->l.cb;
VW::cb_label* ld = &ex->l.cb;
ostr << "CB " << ld->costs.size();

if (ld->costs.size() > 0)
{
ostr << " ";

CB::cb_class& f = ld->costs[0];
VW::cb_class& f = ld->costs[0];

// Ignore checking if f.action == VW::uniform_hash("shared")
if (f.partial_prediction == 0 && f.cost == FLT_MAX && f.probability == -1.f)
Expand Down
8 changes: 4 additions & 4 deletions python/pylibvw.cc
Expand Up @@ -402,15 +402,15 @@ VW::label_parser* get_label_parser(VW::workspace* all, size_t labelType)
case lCOST_SENSITIVE:
return &VW::cs_label_parser_global;
case lCONTEXTUAL_BANDIT:
return &CB::cb_label;
return &VW::cb_label_parser_global;
case lCONDITIONAL_CONTEXTUAL_BANDIT:
return &VW::ccb_label_parser_global;
case lSLATES:
return &VW::slates::slates_label_parser;
case lCONTINUOUS:
return &VW::cb_continuous::the_label_parser;
case lCONTEXTUAL_BANDIT_EVAL:
return &CB_EVAL::cb_eval;
return &VW::cb_eval_label_parser_global;
case lMULTILABEL:
return &MULTILABEL::multilabel;
default:
Expand All @@ -424,8 +424,8 @@ size_t my_get_label_type(VW::workspace* all)
if (lp->parse_label == VW::simple_label_parser_global.parse_label) { return lSIMPLE; }
else if (lp->parse_label == VW::multiclass_label_parser_global.parse_label) { return lMULTICLASS; }
else if (lp->parse_label == VW::cs_label_parser_global.parse_label) { return lCOST_SENSITIVE; }
else if (lp->parse_label == CB::cb_label.parse_label) { return lCONTEXTUAL_BANDIT; }
else if (lp->parse_label == CB_EVAL::cb_eval.parse_label) { return lCONTEXTUAL_BANDIT_EVAL; }
else if (lp->parse_label == VW::cb_label_parser_global.parse_label) { return lCONTEXTUAL_BANDIT; }
else if (lp->parse_label == VW::cb_eval_label_parser_global.parse_label) { return lCONTEXTUAL_BANDIT_EVAL; }
else if (lp->parse_label == VW::ccb_label_parser_global.parse_label) { return lCONDITIONAL_CONTEXTUAL_BANDIT; }
else if (lp->parse_label == VW::slates::slates_label_parser.parse_label) { return lSLATES; }
else if (lp->parse_label == VW::cb_continuous::the_label_parser.parse_label) { return lCONTINUOUS; }
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/include/vw/core/automl_impl.h
Expand Up @@ -296,8 +296,8 @@ class automl
}
}
// This fn gets called before learning any example
void one_step(VW::LEARNER::multi_learner& base, multi_ex& ec, CB::cb_class& logged, uint64_t labelled_action);
void offset_learn(VW::LEARNER::multi_learner& base, multi_ex& ec, CB::cb_class& logged, uint64_t labelled_action);
void one_step(VW::LEARNER::multi_learner& base, multi_ex& ec, VW::cb_class& logged, uint64_t labelled_action);
void offset_learn(VW::LEARNER::multi_learner& base, multi_ex& ec, VW::cb_class& logged, uint64_t labelled_action);
};
} // namespace automl

Expand Down
70 changes: 47 additions & 23 deletions vowpalwabbit/core/include/vw/core/cb.h
Expand Up @@ -5,6 +5,7 @@

#include "vw/core/io_buf.h"
#include "vw/core/label_parser.h"
#include "vw/core/multi_ex.h"
#include "vw/core/v_array.h"
#include "vw/core/vw_fwd.h"

Expand All @@ -15,11 +16,6 @@

namespace VW
{
class example;
using multi_ex = std::vector<example*>;
} // namespace VW
namespace CB
{
// By default a cb class does not contain an observed cost.
class cb_class
{
Expand All @@ -41,7 +37,7 @@ class cb_class
constexpr bool has_observed_cost() const { return (cost != FLT_MAX && probability > .0); }
};

class label
class cb_label
{
public:
std::vector<cb_class> costs;
Expand All @@ -52,36 +48,64 @@ class label
void reset_to_default();
};

extern VW::label_parser cb_label; // for learning
bool ec_is_example_header(VW::example const& ec); // example headers look like "shared"
extern VW::label_parser cb_label_parser_global;

std::pair<bool, cb_class> get_observed_cost_cb(const label& ld);
// example headers look like "shared"
bool ec_is_example_header_cb(VW::example const& ec);

void print_update(VW::workspace& all, bool is_test, const VW::example& ec, const VW::multi_ex* ec_seq,
bool action_scores, const CB::cb_class* known_cost);
} // namespace CB
std::pair<bool, cb_class> get_observed_cost_cb(const cb_label& ld);

namespace CB_EVAL
} // namespace VW
namespace VW
{
namespace details
{
void print_update_cb(VW::workspace& all, bool is_test, const VW::example& ec, const VW::multi_ex* ec_seq,
bool action_scores, const VW::cb_class* known_cost);
}
} // namespace VW

namespace VW
{
class label
class cb_eval_label
{
public:
uint32_t action = 0;
CB::label event;
cb_label event;
};

extern VW::label_parser cb_eval; // for evaluation of an arbitrary policy.
} // namespace CB_EVAL
extern VW::label_parser cb_eval_label_parser_global; // for evaluation of an arbitrary policy.
} // namespace VW

namespace VW
{
namespace model_utils
{
size_t read_model_field(io_buf&, CB::cb_class&);
size_t write_model_field(io_buf&, const CB::cb_class&, const std::string&, bool);
size_t read_model_field(io_buf&, CB::label&);
size_t write_model_field(io_buf&, const CB::label&, const std::string&, bool);
size_t read_model_field(io_buf&, CB_EVAL::label&);
size_t write_model_field(io_buf&, const CB_EVAL::label&, const std::string&, bool);
size_t read_model_field(io_buf&, VW::cb_class&);
size_t write_model_field(io_buf&, const VW::cb_class&, const std::string&, bool);
size_t read_model_field(io_buf&, VW::cb_label&);
size_t write_model_field(io_buf&, const VW::cb_label&, const std::string&, bool);
size_t read_model_field(io_buf&, VW::cb_eval_label&);
size_t write_model_field(io_buf&, const VW::cb_eval_label&, const std::string&, bool);
} // namespace model_utils
} // namespace VW

namespace CB
{
using cb_class VW_DEPRECATED("Renamed to VW::cb_class") = VW::cb_class;
using label VW_DEPRECATED("Renamed to VW::cb_label") = VW::cb_label;

VW_DEPRECATED("Renamed to VW::ec_is_example_header_cb")
inline bool ec_is_example_header(VW::example const& ec) { return VW::ec_is_example_header_cb(ec); }

VW_DEPRECATED("Renamed to VW::get_observed_cost_cb")
inline std::pair<bool, VW::cb_class> get_observed_cost_cb(const VW::cb_label& ld)
{
return VW::get_observed_cost_cb(ld);
}
} // namespace CB

namespace CB_EVAL
{
using label VW_DEPRECATED("Renamed to VW::cb_eval_label") = VW::cb_eval_label;
}
4 changes: 2 additions & 2 deletions vowpalwabbit/core/include/vw/core/example.h
Expand Up @@ -41,11 +41,11 @@ class polylabel
VW::simple_label simple;
VW::multiclass_label multi;
VW::cs_label cs;
CB::label cb;
VW::cb_label cb;
VW::cb_continuous::continuous_label cb_cont;
VW::ccb_label conditional_contextual_bandit;
VW::slates::label slates;
CB_EVAL::label cb_eval;
VW::cb_eval_label cb_eval;
MULTILABEL::labels multilabels;
};

Expand Down
19 changes: 10 additions & 9 deletions vowpalwabbit/core/include/vw/core/gen_cs_example.h
Expand Up @@ -22,7 +22,7 @@ class cb_to_cs
float last_pred_reg = 0.f;
float last_correct_cost = 0.f;

CB::cb_class known_cost;
VW::cb_class known_cost;
};

class cb_to_cs_adf
Expand All @@ -38,17 +38,17 @@ class cb_to_cs_adf

// for DR
VW::cs_label pred_scores;
CB::cb_class known_cost;
VW::cb_class known_cost;
VW::LEARNER::single_learner* scorer = nullptr;
};

float safe_probability(float prob, VW::io::logger& logger);

void gen_cs_example_ips(
cb_to_cs& c, const CB::label& ld, VW::cs_label& cs_ld, VW::io::logger& logger, float clip_p = 0.f);
cb_to_cs& c, const VW::cb_label& ld, VW::cs_label& cs_ld, VW::io::logger& logger, float clip_p = 0.f);

template <bool is_learn>
void gen_cs_example_dm(cb_to_cs& c, VW::example& ec, const CB::label& ld, VW::cs_label& cs_ld)
void gen_cs_example_dm(cb_to_cs& c, VW::example& ec, const VW::cb_label& ld, VW::cs_label& cs_ld)
{ // this implements the direct estimation method, where costs are directly specified by the learned regressor.

float min = FLT_MAX;
Expand Down Expand Up @@ -141,7 +141,8 @@ void gen_cs_label(cb_to_cs& c, VW::example& ec, VW::cs_label& cs_ld, uint32_t ac
}

template <bool is_learn>
void gen_cs_example_dr(cb_to_cs& c, VW::example& ec, const CB::label& ld, VW::cs_label& cs_ld, float /*clip_p*/ = 0.f)
void gen_cs_example_dr(
cb_to_cs& c, VW::example& ec, const VW::cb_label& ld, VW::cs_label& cs_ld, float /*clip_p*/ = 0.f)
{
// this implements the doubly robust method
VW_DBG(ec) << "gen_cs_example_dr:" << is_learn << std::endl;
Expand Down Expand Up @@ -169,7 +170,7 @@ void gen_cs_example_dr(cb_to_cs& c, VW::example& ec, const CB::label& ld, VW::cs
}

template <bool is_learn>
void gen_cs_example(cb_to_cs& c, VW::example& ec, const CB::label& ld, VW::cs_label& cs_ld, VW::io::logger& logger)
void gen_cs_example(cb_to_cs& c, VW::example& ec, const VW::cb_label& ld, VW::cs_label& cs_ld, VW::io::logger& logger)
{
switch (c.cb_type)
{
Expand Down Expand Up @@ -222,7 +223,7 @@ void gen_cs_example_dr(cb_to_cs_adf& c, VW::multi_ex& examples, VW::cs_label& cs
wc.x = CB_ALGS::get_cost_pred<is_learn>(c.scorer, c.known_cost, *(examples[i]), 0, 2);
c.known_cost.action = known_index;
}
else { wc.x = CB_ALGS::get_cost_pred<is_learn>(c.scorer, CB::cb_class{}, *(examples[i]), 0, 2); }
else { wc.x = CB_ALGS::get_cost_pred<is_learn>(c.scorer, VW::cb_class{}, *(examples[i]), 0, 2); }

c.pred_scores.costs.push_back(wc); // done

Expand Down Expand Up @@ -255,12 +256,12 @@ void gen_cs_example(cb_to_cs_adf& c, VW::multi_ex& ec_seq, VW::cs_label& cs_labe
}
}

void cs_prep_labels(VW::multi_ex& examples, std::vector<CB::label>& cb_labels, VW::cs_label& cs_labels,
void cs_prep_labels(VW::multi_ex& examples, std::vector<VW::cb_label>& cb_labels, VW::cs_label& cs_labels,
std::vector<VW::cs_label>& prepped_cs_labels, uint64_t offset);

template <bool is_learn>
void cs_ldf_learn_or_predict(VW::LEARNER::multi_learner& base, VW::multi_ex& examples,
std::vector<CB::label>& cb_labels, VW::cs_label& cs_labels, std::vector<VW::cs_label>& prepped_cs_labels,
std::vector<VW::cb_label>& cb_labels, VW::cs_label& cs_labels, std::vector<VW::cs_label>& prepped_cs_labels,
bool predict_first, uint64_t offset, size_t id = 0)
{
VW_DBG(*examples[0]) << "cs_ldf_" << (is_learn ? "<learn>" : "<predict>") << ": ex=" << examples[0]->example_counter
Expand Down

0 comments on commit b5fe70d

Please sign in to comment.