diff --git a/vowpalwabbit/cb_explore.cc b/vowpalwabbit/cb_explore.cc index 759306c6a8b..7e9ffcbc63e 100644 --- a/vowpalwabbit/cb_explore.cc +++ b/vowpalwabbit/cb_explore.cc @@ -3,12 +3,14 @@ #include "rand48.h" #include "bs.h" #include "gen_cs_example.h" +#include "exploration.h" using namespace LEARNER; using namespace ACTION_SCORE; using namespace GEN_CS; using namespace std; using namespace CB_ALGS; +using namespace exploration; //All exploration algorithms return a vector of probabilities, to be used by GenericExplorer downstream namespace CB_EXPLORE @@ -72,6 +74,8 @@ template void predict_or_learn_greedy(cb_explore& data, base_learner& base, example& ec) { //Explore uniform random an epsilon fraction of the time. + // TODO: pointers are copied here. What happens if base.learn/base.predict re-allocs? + // ec.pred.a_s = probs; will restore the than free'd memory action_scores probs = ec.pred.a_s; probs.erase(); @@ -80,11 +84,10 @@ void predict_or_learn_greedy(cb_explore& data, base_learner& base, example& ec) else base.predict(ec); - float prob = data.epsilon/(float)data.cbcs.num_actions; - for(uint32_t i = 0; i < data.cbcs.num_actions; i++) - probs.push_back({i,prob}); - uint32_t chosen = ec.pred.multiclass-1; - probs[chosen].score += (1-data.epsilon); + // pre-allocate pdf + probs.resize(data.cbcs.num_actions); + probs.end() = probs.begin() + data.cbcs.num_actions; + generate_epsilon_greedy(data.epsilon, ec.pred.multiclass-1, begin_scores(probs), end_scores(probs)); ec.pred.a_s = probs; } @@ -116,48 +119,6 @@ void predict_or_learn_bag(cb_explore& data, base_learner& base, example& ec) ec.pred.a_s = probs; } -void safety(v_array& distribution, float min_prob, bool zeros) -{ - //input: a probability distribution - //output: a probability distribution with all events having probability > min_prob. This includes events with probability 0 if zeros = true - if (min_prob > 0.999) // uniform exploration - { - size_t support_size = distribution.size(); - if (!zeros) - { - for (size_t i = 0; i < distribution.size(); ++i) - if (distribution[i].score == 0) - support_size--; - } - for (size_t i = 0; i < distribution.size(); ++i) - if (zeros || distribution[i].score > 0) - distribution[i].score = 1.f / support_size; - return; - } - - min_prob /= distribution.size(); - float touched_mass = 0.; - float untouched_mass = 0.; - for (uint32_t i = 0; i < distribution.size(); i++) - if ((distribution[i].score > 0 || (distribution[i].score ==0 && zeros)) && distribution[i].score <= min_prob) - { - touched_mass += min_prob; - distribution[i].score = min_prob; - } - else - untouched_mass += distribution[i].score; - - if (touched_mass > 0.) - { - if (touched_mass > 0.999) - THROW("Cannot safety this distribution"); - float ratio = (1.f - touched_mass) / untouched_mass; - for (uint32_t i = 0; i < distribution.size(); i++) - if (distribution[i].score > min_prob) - distribution[i].score = distribution[i].score * ratio; - } -} - void get_cover_probabilities(cb_explore& data, base_learner& base, example& ec, v_array& probs) { float additive_probability = 1.f / (float)data.cover_size; @@ -181,7 +142,7 @@ void get_cover_probabilities(cb_explore& data, base_learner& base, example& ec, float min_prob = min(1.f / num_actions, 1.f / (float)sqrt(data.counter * num_actions)); - safety(probs, min_prob*num_actions, false); + enforce_minimum_probability(min_prob*num_actions, false, begin_scores(probs), end_scores(probs)); data.counter++; } diff --git a/vowpalwabbit/cb_explore.h b/vowpalwabbit/cb_explore.h index 928f25b4891..e196ae1a5e6 100644 --- a/vowpalwabbit/cb_explore.h +++ b/vowpalwabbit/cb_explore.h @@ -10,9 +10,4 @@ namespace LEARNER typedef learner base_learner; } -LEARNER::base_learner* cb_explore_setup(arguments& arg); - -namespace CB_EXPLORE -{ -void safety(v_array& distribution, float min_prob, bool zeros); -} +LEARNER::base_learner* cb_explore_setup(arguments& arg); \ No newline at end of file diff --git a/vowpalwabbit/cb_explore_adf.cc b/vowpalwabbit/cb_explore_adf.cc index 0a250cb58d7..e0dedf1445d 100644 --- a/vowpalwabbit/cb_explore_adf.cc +++ b/vowpalwabbit/cb_explore_adf.cc @@ -139,7 +139,8 @@ void predict_or_learn_first(cb_explore_adf& data, base_learner& base, v_array @@ -201,8 +202,7 @@ void predict_or_learn_bag(cb_explore_adf& data, base_learner& base, v_array& preds = examples[0]->pred.a_s; generate_softmax(data.lambda, begin_scores(preds), end_scores(preds), begin_scores(preds), end_scores(preds)); - CB_EXPLORE::safety(preds, data.epsilon, true); + enforce_minimum_probability(data.epsilon, true, begin_scores(preds), end_scores(preds)); } void end_examples(cb_explore_adf& data)