Skip to content

Commit

Permalink
fixed bugs in sample_from_pdf ranking
Browse files Browse the repository at this point in the history
changed variable names
  • Loading branch information
eisber committed Apr 19, 2018
1 parent 8261c87 commit 08f54ff
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 34 deletions.
23 changes: 12 additions & 11 deletions explore/explore.h
Expand Up @@ -5,6 +5,7 @@
#define E_EXPLORATION_PDF_RANKING_SIZE_MISMATCH 2

#include "explore_internal.h"
#include "hash.h"

namespace exploration {
/**
Expand Down Expand Up @@ -50,17 +51,17 @@ namespace exploration {
int generate_bag(InputIt top_actions_begin, InputIt top_actions_last, OutputIt pdf_first, OutputIt pdf_last);

/**
* @brief Updates the pdf to ensure each action is explored with at least min_prob/num_actions.
* @brief Updates the pdf to ensure each action is explored with at least minimum_uniform/num_actions.
*
* @tparam It Iterator type of the pdf. Must be a RandomAccessIterator.
* @param min_prob The minimum probability used for exploration.
* @param minimum_uniform The minimum amount of uniform distribution to impose on the pdf.
* @param update_zero_elements If true elements with zero probability are updated, otherwise those actions will be unchanged.
* @param pdf_first Iterator pointing to the pre-allocated beginning of the pdf to be generated by this function.
* @param pdf_last Iterator pointing to the pre-allocated end of the pdf to be generated by this function.
* @return int returns 0 on success, otherwise an error code as defined by E_EXPLORATION_*.
*/
template<typename It>
int enforce_minimum_probability(float min_prob, bool update_zero_elements, It pdf_first, It pdf_last);
int enforce_minimum_probability(float minimum_uniform, bool update_zero_elements, It pdf_first, It pdf_last);

/**
* @brief Sample an index from the provided pdf.
Expand Down Expand Up @@ -107,16 +108,16 @@ namespace exploration {
* @param pdf_begin Iterator pointing to the beginning of the pdf.
* @param pdf_end Iterator pointing to the end of the pdf.
* @param scores_begin Iterator pointing to the beginning of the scores.
* @param scores_end Iterator pointing to the end of the scores.
* @param scores_last Iterator pointing to the end of the scores.
* @param ranking_begin Iterator pointing to the pre-allocated beginning of the output ranking.
* @param ranking_end Iterator pointing to the pre-allocated end of the output ranking.
* @param ranking_last Iterator pointing to the pre-allocated end of the output ranking.
* @return int returns 0 on success, otherwise an error code as defined by E_EXPLORATION_*.
*/
template<typename InputPdfIt, typename InputScoreIt, typename OutputIt>
int sample_from_pdf(const char* seed, InputPdfIt pdf_begin, InputPdfIt pdf_end, InputScoreIt scores_begin, InputScoreIt scores_end, OutputIt ranking_begin, OutputIt ranking_end)
int sample_from_pdf(const char* seed, InputPdfIt pdf_begin, InputPdfIt pdf_end, InputScoreIt scores_begin, InputScoreIt scores_last, OutputIt ranking_begin, OutputIt ranking_last)
{
uint64_t seed_hash = uniform_hash(seed, strlen(seed), 0);
return sample_from_pdf(seed_hash, pdf_begin, pdf_end, scores_begin, scores_end, ranking_begin, ranking_end);
return sample_from_pdf(seed_hash, pdf_begin, pdf_end, scores_begin, scores_last, ranking_begin, ranking_last);
}

/**
Expand All @@ -130,18 +131,18 @@ namespace exploration {
* @param pdf_begin Iterator pointing to the beginning of the pdf.
* @param pdf_end Iterator pointing to the end of the pdf.
* @param scores_begin Iterator pointing to the beginning of the scores.
* @param scores_end Iterator pointing to the end of the scores.
* @param scores_last Iterator pointing to the end of the scores.
* @param ranking_begin Iterator pointing to the pre-allocated beginning of the output ranking.
* @param ranking_end Iterator pointing to the pre-allocated end of the output ranking.
* @param ranking_last Iterator pointing to the pre-allocated end of the output ranking.
* @return int returns 0 on success, otherwise an error code as defined by E_EXPLORATION_*.
*/
template<typename InputPdfIt, typename InputScoreIt, typename OutputIt>
int sample_from_pdf(uint64_t seed, InputPdfIt pdf_begin, InputPdfIt pdf_end, InputScoreIt scores_begin, InputScoreIt scores_end, OutputIt ranking_begin, OutputIt ranking_end)
int sample_from_pdf(uint64_t seed, InputPdfIt pdf_begin, InputPdfIt pdf_end, InputScoreIt scores_begin, InputScoreIt scores_last, OutputIt ranking_begin, OutputIt ranking_last)
{
typedef typename std::iterator_traits<InputPdfIt>::iterator_category pdf_category;
typedef typename std::iterator_traits<InputScoreIt>::iterator_category scores_category;
typedef typename std::iterator_traits<OutputIt>::iterator_category ranking_category;

return sample_from_pdf(seed, pdf_begin, pdf_end, pdf_category(), scores_begin, scores_end, scores_category(), ranking_begin, ranking_end, ranking_category());
return sample_from_pdf(seed, pdf_begin, pdf_end, pdf_category(), scores_begin, scores_last, scores_category(), ranking_begin, ranking_last, ranking_category());
}
}
51 changes: 28 additions & 23 deletions explore/explore_internal.h
Expand Up @@ -150,15 +150,15 @@ namespace exploration
}

template<typename It>
int enforce_minimum_probability(float min_prob, bool update_zero_elements, It pdf_first, It pdf_last, std::random_access_iterator_tag pdf_tag)
int enforce_minimum_probability(float minimum_uniform, bool update_zero_elements, It pdf_first, It pdf_last, std::random_access_iterator_tag pdf_tag)
{
// iterators don't support <= in general
if (pdf_first == pdf_last || pdf_last < pdf_first)
return E_EXPLORATION_BAD_RANGE;

size_t num_actions = pdf_last - pdf_first;

if (min_prob > 0.999) // uniform exploration
if (minimum_uniform > 0.999) // uniform exploration
{
size_t support_size = num_actions;
if (!update_zero_elements)
Expand All @@ -175,18 +175,18 @@ namespace exploration
return S_EXPLORATION_OK;
}

min_prob /= num_actions;
minimum_uniform /= num_actions;
float touched_mass = 0.;
float untouched_mass = 0.;
uint16_t num_actions_touched = 0;

for (It d = pdf_first; d != pdf_last; ++d)
{
auto& prob = *d;
if ((prob > 0 || (prob == 0 && update_zero_elements)) && prob <= min_prob)
if ((prob > 0 || (prob == 0 && update_zero_elements)) && prob <= minimum_uniform)
{
touched_mass += min_prob;
prob = min_prob;
touched_mass += minimum_uniform;
prob = minimum_uniform;
++num_actions_touched;
}
else
Expand All @@ -197,19 +197,19 @@ namespace exploration
{
if (touched_mass > 0.999)
{
min_prob = (1.f - untouched_mass) / (float)num_actions_touched;
minimum_uniform = (1.f - untouched_mass) / (float)num_actions_touched;
for (It d = pdf_first; d != pdf_last; ++d)
{
auto& prob = *d;
if ((prob > 0 || (prob == 0 && update_zero_elements)) && prob <= min_prob)
prob = min_prob;
if ((prob > 0 || (prob == 0 && update_zero_elements)) && prob <= minimum_uniform)
prob = minimum_uniform;
}
}
else
{
float ratio = (1.f - touched_mass) / untouched_mass;
for (It d = pdf_first; d != pdf_last; ++d)
if (*d > min_prob)
if (*d > minimum_uniform)
*d *= ratio;
}
}
Expand All @@ -218,11 +218,11 @@ namespace exploration
}

template<typename It>
int enforce_minimum_probability(float min_prob, bool update_zero_elements, It pdf_first, It pdf_last)
int enforce_minimum_probability(float minimum_uniform, bool update_zero_elements, It pdf_first, It pdf_last)
{
typedef typename std::iterator_traits<It>::iterator_category pdf_category;

return enforce_minimum_probability(min_prob, update_zero_elements, pdf_first, pdf_last, pdf_category());
return enforce_minimum_probability(minimum_uniform, update_zero_elements, pdf_first, pdf_last, pdf_category());
}

template<typename InputIt>
Expand Down Expand Up @@ -265,40 +265,45 @@ namespace exploration
}

template<typename InputIt>
uint32_t sample_from_pdf(const char* seed, InputIt pdf_first, InputIt pdf_last, std::input_iterator_tag pdf_category)
int sample_from_pdf(const char* seed, InputIt pdf_first, InputIt pdf_last, uint32_t& chosen_index, std::input_iterator_tag pdf_category)
{
uint64_t seed_hash = uniform_hash(seed, strlen(seed), 0);
return sample_from_pdf(seed_hash, pdf_first, pdf_last);
return sample_from_pdf(seed_hash, pdf_first, pdf_last, chosen_index, pdf_category);
}


template<typename InputPdfIt, typename InputScoreIt, typename OutputIt>
int sample_from_pdf(uint64_t seed,
InputPdfIt pdf_begin, InputPdfIt pdf_end, std::input_iterator_tag pdf_category,
InputScoreIt scores_begin, InputScoreIt scores_end, std::random_access_iterator_tag scores_category,
OutputIt ranking_begin, OutputIt ranking_end, std::random_access_iterator_tag ranking_category)
InputScoreIt scores_begin, InputScoreIt scores_last, std::random_access_iterator_tag scores_category,
OutputIt ranking_begin, OutputIt ranking_last, std::random_access_iterator_tag ranking_category)
{
if (pdf_end < pdf_begin || ranking_end < ranking_begin)
if (pdf_end < pdf_begin || ranking_last < ranking_begin)
return E_EXPLORATION_BAD_RANGE;

size_t pdf_size = pdf_end - pdf_begin;
size_t ranking_size = ranking_end - ranking_begin;
size_t ranking_size = ranking_last - ranking_begin;

if (pdf_size == 0)
return E_EXPLORATION_BAD_RANGE;

if (pdf_size != ranking_size)
return E_EXPLORATION_PDF_RANKING_SIZE_MISMATCH;

uint32_t chosen_action = sample_from_pdf(seed, pdf_begin, pdf_end);
uint32_t chosen_action;
int ret = sample_from_pdf(seed, pdf_begin, pdf_end, chosen_action);
if (ret)
return ret;

std::iota(ranking_begin, ranking_end, 0);
std::iota(ranking_begin, ranking_last, 0);

// sort indexes based on comparing values in scores
std::sort(ranking_begin, ranking_end,
[&scores_begin, &scores_end](size_t i1, size_t i2) { return scores_begin[i1] > scores_end[i2]; });
std::sort(ranking_begin, ranking_last,
[&scores_begin](size_t i1, size_t i2) { return scores_begin[i1] > scores_begin[i2]; });

// swap top element with chosen one
std::iter_swap(ranking_begin, ranking_end + chosen_action);
if (chosen_action != 0)
std::iter_swap(ranking_begin, ranking_last + chosen_action);

return S_EXPLORATION_OK;
}
Expand Down

0 comments on commit 08f54ff

Please sign in to comment.