Skip to content

Commit

Permalink
replaced safety with enforce_minimum_probability
Browse files Browse the repository at this point in the history
  • Loading branch information
eisber committed Apr 16, 2018
1 parent 362b03a commit d6a8795
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 59 deletions.
57 changes: 9 additions & 48 deletions vowpalwabbit/cb_explore.cc
Expand Up @@ -3,12 +3,14 @@
#include "rand48.h"
#include "bs.h"
#include "gen_cs_example.h"
#include "exploration.h"

using namespace LEARNER;
using namespace ACTION_SCORE;
using namespace GEN_CS;
using namespace std;
using namespace CB_ALGS;
using namespace exploration;
//All exploration algorithms return a vector of probabilities, to be used by GenericExplorer downstream

namespace CB_EXPLORE
Expand Down Expand Up @@ -72,6 +74,8 @@ template <bool is_learn>
void predict_or_learn_greedy(cb_explore& data, base_learner& base, example& ec)
{
//Explore uniform random an epsilon fraction of the time.
// TODO: pointers are copied here. What happens if base.learn/base.predict re-allocs?
// ec.pred.a_s = probs; will restore the than free'd memory
action_scores probs = ec.pred.a_s;
probs.erase();

Expand All @@ -80,11 +84,10 @@ void predict_or_learn_greedy(cb_explore& data, base_learner& base, example& ec)
else
base.predict(ec);

float prob = data.epsilon/(float)data.cbcs.num_actions;
for(uint32_t i = 0; i < data.cbcs.num_actions; i++)
probs.push_back({i,prob});
uint32_t chosen = ec.pred.multiclass-1;
probs[chosen].score += (1-data.epsilon);
// pre-allocate pdf
probs.resize(data.cbcs.num_actions);
probs.end() = probs.begin() + data.cbcs.num_actions;
generate_epsilon_greedy(data.epsilon, ec.pred.multiclass-1, begin_scores(probs), end_scores(probs));

ec.pred.a_s = probs;
}
Expand Down Expand Up @@ -116,48 +119,6 @@ void predict_or_learn_bag(cb_explore& data, base_learner& base, example& ec)
ec.pred.a_s = probs;
}

void safety(v_array<action_score>& distribution, float min_prob, bool zeros)
{
//input: a probability distribution
//output: a probability distribution with all events having probability > min_prob. This includes events with probability 0 if zeros = true
if (min_prob > 0.999) // uniform exploration
{
size_t support_size = distribution.size();
if (!zeros)
{
for (size_t i = 0; i < distribution.size(); ++i)
if (distribution[i].score == 0)
support_size--;
}
for (size_t i = 0; i < distribution.size(); ++i)
if (zeros || distribution[i].score > 0)
distribution[i].score = 1.f / support_size;
return;
}

min_prob /= distribution.size();
float touched_mass = 0.;
float untouched_mass = 0.;
for (uint32_t i = 0; i < distribution.size(); i++)
if ((distribution[i].score > 0 || (distribution[i].score ==0 && zeros)) && distribution[i].score <= min_prob)
{
touched_mass += min_prob;
distribution[i].score = min_prob;
}
else
untouched_mass += distribution[i].score;

if (touched_mass > 0.)
{
if (touched_mass > 0.999)
THROW("Cannot safety this distribution");
float ratio = (1.f - touched_mass) / untouched_mass;
for (uint32_t i = 0; i < distribution.size(); i++)
if (distribution[i].score > min_prob)
distribution[i].score = distribution[i].score * ratio;
}
}

void get_cover_probabilities(cb_explore& data, base_learner& base, example& ec, v_array<action_score>& probs)
{
float additive_probability = 1.f / (float)data.cover_size;
Expand All @@ -181,7 +142,7 @@ void get_cover_probabilities(cb_explore& data, base_learner& base, example& ec,

float min_prob = min(1.f / num_actions, 1.f / (float)sqrt(data.counter * num_actions));

safety(probs, min_prob*num_actions, false);
enforce_minimum_probability(min_prob*num_actions, false, begin_scores(probs), end_scores(probs));

data.counter++;
}
Expand Down
7 changes: 1 addition & 6 deletions vowpalwabbit/cb_explore.h
Expand Up @@ -10,9 +10,4 @@ namespace LEARNER
typedef learner<char> base_learner;
}

LEARNER::base_learner* cb_explore_setup(arguments& arg);

namespace CB_EXPLORE
{
void safety(v_array<ACTION_SCORE::action_score>& distribution, float min_prob, bool zeros);
}
LEARNER::base_learner* cb_explore_setup(arguments& arg);
10 changes: 5 additions & 5 deletions vowpalwabbit/cb_explore_adf.cc
Expand Up @@ -139,7 +139,8 @@ void predict_or_learn_first(cb_explore_adf& data, base_learner& base, v_array<ex
preds[i].score = 0.;
preds[0].score = 1.0;
}
CB_EXPLORE::safety(preds, data.epsilon, true);

enforce_minimum_probability(data.epsilon, true, begin_scores(preds), end_scores(preds));
}

template <bool is_learn>
Expand Down Expand Up @@ -201,8 +202,7 @@ void predict_or_learn_bag(cb_explore_adf& data, base_learner& base, v_array<exam
// generate distribution over actions
generate_bag(begin(top_actions), end(top_actions), begin_scores(data.action_probs), end_scores(data.action_probs));

// TODO: use exploration::safety
CB_EXPLORE::safety(data.action_probs, data.epsilon, true);
enforce_minimum_probability(data.epsilon, true, begin_scores(data.action_probs), end_scores(data.action_probs));
qsort((void*) data.action_probs.begin(), data.action_probs.size(), sizeof(action_score), reverse_order);

for (size_t i = 0; i < num_actions; i++)
Expand Down Expand Up @@ -266,7 +266,7 @@ void predict_or_learn_cover(cb_explore_adf& data, base_learner& base, v_array<ex
probs[action].score += additive_probability;
}

CB_EXPLORE::safety(data.action_probs, min_prob * num_actions, !data.nounif);
enforce_minimum_probability(min_prob * num_actions, !data.nounif, begin_scores(probs), end_scores(probs));

qsort((void*) probs.begin(), probs.size(), sizeof(action_score), reverse_order);
for (size_t i = 0; i < num_actions; i++)
Expand All @@ -286,7 +286,7 @@ void predict_or_learn_softmax(cb_explore_adf& data, base_learner& base, v_array<
v_array<action_score>& preds = examples[0]->pred.a_s;
generate_softmax(data.lambda, begin_scores(preds), end_scores(preds), begin_scores(preds), end_scores(preds));

CB_EXPLORE::safety(preds, data.epsilon, true);
enforce_minimum_probability(data.epsilon, true, begin_scores(preds), end_scores(preds));
}

void end_examples(cb_explore_adf& data)
Expand Down

0 comments on commit d6a8795

Please sign in to comment.