Skip to content

Commit

Permalink
Implement prediction and label reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Feb 12, 2020
1 parent 19288b8 commit 62bc213
Show file tree
Hide file tree
Showing 12 changed files with 36 additions and 29 deletions.
1 change: 1 addition & 0 deletions vowpalwabbit/cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ int read_cached_features(vw* all, v_array<example*>& examples)
ae->sorted = all->p->sorted_cache;
io_buf* input = all->p->input;

all->p->lp.default_label(ae->l);
size_t total = all->p->lp.read_cached_label(all->p->_shared_data, ae->l, *input);
if (total == 0)
return 0;
Expand Down
13 changes: 8 additions & 5 deletions vowpalwabbit/cb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ size_t read_cached_label(shared_data*, CB::label& ld, io_buf& cache)

size_t read_cached_label(shared_data* s, polylabel& v, io_buf& cache)
{
return CB::read_cached_label(s, v.init_as_cb(), cache);
return CB::read_cached_label(s, v.cb(), cache);
}

float weight(CB::label& ld) { return ld.weight; }
Expand Down Expand Up @@ -87,11 +87,12 @@ void default_label(CB::label& ld)

void default_label(polylabel& v)
{
if (v.get_type() != label_type_t::unset)
if (v.get_type() != label_type_t::cb)
{
v.reset();
v.init_as_cb();
}
CB::default_label(v.init_as_cb());
CB::default_label(v.cb());
}

bool test_label(CB::label& ld)
Expand Down Expand Up @@ -256,11 +257,13 @@ void cache_label(polylabel& v, io_buf& cache)

void default_label(polylabel& v)
{
if (v.get_type() != label_type_t::unset)
if (v.get_type() != label_type_t::cb_eval)
{
v.reset();
v.init_as_cb_eval();

}
auto& ld = v.init_as_cb_eval();
auto& ld = v.cb_eval();
CB::default_label(ld.event);
ld.action = 0;
}
Expand Down
6 changes: 4 additions & 2 deletions vowpalwabbit/ccb_label.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,13 @@ void cache_label(polylabel& v, io_buf& cache)

void default_label(polylabel& v)
{
if (v.get_type() != label_type_t::unset)
if (v.get_type() != label_type_t::conditional_contextual_bandit)
{
v.reset();
v.init_as_ccb();

}
CCB::label& ld = v.init_as_conditional_contextual_bandit();
CCB::label& ld = v.ccb();

// This is tested against nullptr, so unfortunately, as things stand, this must be deleted when not used.
if (ld.outcome)
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/conditional_contextual_bandit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ void learn_or_predict(ccb& data, multi_learner& base, multi_ex& examples)
// Restore ccb labels to the example objects.
for (size_t i = 0; i < examples.size(); i++)
{
examples[i]->l.init_as_conditional_contextual_bandit(std::move(data.stored_labels[i]));
examples[i]->l.init_as_ccb(std::move(data.stored_labels[i]));
}
data.stored_labels.clear();

Expand Down
8 changes: 5 additions & 3 deletions vowpalwabbit/cost_sensitive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ char* bufread_label(label& ld, char* c, io_buf& cache)

size_t read_cached_label(shared_data*, polylabel& v, io_buf& cache)
{
auto& ld = v.init_as_cs();
auto& ld = v.cs();

ld.costs.clear();
char* c;
Expand Down Expand Up @@ -92,11 +92,13 @@ void default_label(label& label) { label.costs.clear(); }

void default_label(polylabel& v)
{
if (v.get_type() != label_type_t::unset)
if (v.get_type() != label_type_t::cs)
{
v.reset();
v.init_as_cs();
}
auto& ld = v.init_as_cs();

auto& ld = v.cs();
default_label(ld);
}

Expand Down
8 changes: 4 additions & 4 deletions vowpalwabbit/label.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ struct polylabel
init_as_cb(other._cb);
break;
case (label_type_t::conditional_contextual_bandit):
init_as_conditional_contextual_bandit(other._conditional_contextual_bandit);
init_as_ccb(other._conditional_contextual_bandit);
break;
case (label_type_t::cb_eval):
init_as_cb_eval(other._cb_eval);
Expand Down Expand Up @@ -179,7 +179,7 @@ struct polylabel
init_as_cb(std::move(other._cb));
break;
case (label_type_t::conditional_contextual_bandit):
init_as_conditional_contextual_bandit(std::move(other._conditional_contextual_bandit));
init_as_ccb(std::move(other._conditional_contextual_bandit));
break;
case (label_type_t::cb_eval):
init_as_cb_eval(std::move(other._cb_eval));
Expand Down Expand Up @@ -228,7 +228,7 @@ struct polylabel
{
case (label_type_t::unset):
// Nothing to do! Whatever was in here has already been destroyed.
break;
return;
case (label_type_t::empty):
destruct(_empty);
break;
Expand Down Expand Up @@ -364,7 +364,7 @@ struct polylabel
}

template <typename... Args>
CCB::label& init_as_conditional_contextual_bandit(Args&&... args)
CCB::label& init_as_ccb(Args&&... args)
{
ensure_is_type(label_type_t::unset);
new (&_conditional_contextual_bandit) CCB::label(std::forward<Args>(args)...);
Expand Down
7 changes: 4 additions & 3 deletions vowpalwabbit/multiclass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ char* bufread_label(label_t& ld, char* c)

size_t read_cached_label(shared_data*, polylabel& v, io_buf& cache)
{
auto& ld = v.init_as_multi();
auto& ld = v.multi();
char* c;
size_t total = sizeof(ld.label) + sizeof(ld.weight);
if (cache.buf_read(c, total) < total)
Expand Down Expand Up @@ -57,11 +57,12 @@ void cache_label(polylabel& v, io_buf& cache)

void default_label(polylabel& v)
{
if (v.get_type() != label_type_t::unset)
if (v.get_type() != label_type_t::multi)
{
v.reset();
v.init_as_multi();
}
auto& ld = v.init_as_multi();
auto& ld = v.multi();
ld.label = (uint32_t)-1;
ld.weight = 1.;
}
Expand Down
7 changes: 4 additions & 3 deletions vowpalwabbit/multilabel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ char* bufread_label(labels& ld, char* c, io_buf& cache)

size_t read_cached_label(shared_data*, polylabel& v, io_buf& cache)
{
auto& ld = v.init_as_multilabels();
auto& ld = v.multilabels();
ld.label_v.clear();
char* c;
size_t total = sizeof(size_t);
Expand Down Expand Up @@ -66,11 +66,12 @@ void cache_label(polylabel& v, io_buf& cache)

void default_label(polylabel& v)
{
if (v.get_type() != label_type_t::unset)
if (v.get_type() != label_type_t::multilabels)
{
v.reset();
v.init_as_multilabels();
}
auto& ld = v.init_as_multilabels();
auto& ld = v.multilabels();
ld.label_v.clear();
}

Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/no_label.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void cache_no_label(polylabel&, io_buf&) {}
// This is wasted work; ideally, empty and unset should be the same thing.
void default_no_label(polylabel& label)
{
if (label.get_type() != label_type_t::empty)
if (label.get_type() != label_type_t::empty && label.get_type() != label_type_t::empty)
{
label.reset();
label.init_as_empty();
Expand Down
7 changes: 2 additions & 5 deletions vowpalwabbit/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,9 @@ void setup_example(vw& all, example* ae)
ae->total_sum_feat_sq += new_features_sum_feat_sq;

// Prediction type should be preinitialized for the given reduction's expected type.
if(ae->pred.get_type() == prediction_type_t::unset)
if(ae->pred.get_type() != all.l->pred_type)
{
ae->pred.reset();
switch (all.l->pred_type)
{
case (prediction_type_t::scalar):
Expand Down Expand Up @@ -895,10 +896,6 @@ void empty_example(vw& /*all*/, example& ec)
for (features& fs : ec)
fs.clear();

// TODO - This is inefficient as we are losing allocated buffers. Once tests are passing this should be removed.
ec.l.reset();
ec.pred.reset();

ec.indices.clear();
ec.tag.clear();
ec.sorted = false;
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/prediction.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ struct polyprediction
{
case (prediction_type_t::unset):
// Nothing to do! Whatever was in here has already been destroyed.
break;
return;
case (prediction_type_t::scalar):
destruct(_scalar);
break;
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/simple_label.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ char* bufread_simple_label(shared_data* sd, label_data& ld, char* c)

size_t read_cached_simple_label(shared_data* sd, polylabel& in_ld, io_buf& cache)
{
auto& ld = in_ld.init_as_simple();
auto& ld = in_ld.simple();
char* c;
size_t total = sizeof(ld.label) + sizeof(ld.weight) + sizeof(ld.initial);
if (cache.buf_read(c, total) < total)
Expand Down

0 comments on commit 62bc213

Please sign in to comment.