1) report_outcome() should log in a background thread. 2) Bug fixes and cleanups (#1549)

* Missed bug fix checkin

* 1) Move report_outcome() to a background thread. 2) Use an object pool for the serialization buffer, since it can be accessed from multiple threads (a sketch of the pooled-buffer pattern follows the file summary below). 3) std::ostringstream reset does not work properly, so it was replaced with data_buffer. 4) Use size_t for the action id, since it indexes an array.

* Fix linux build

* Add vw as dependency for rl_clientlib

* Remove binaries

* Remove debug code from rl_sim

* Address PR comments
rajan-chari authored and JohnLangford committed Jul 26, 2018
1 parent 246b271 commit 7af4ff4
Showing 25 changed files with 261 additions and 170 deletions.
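
The central pattern in this commit: event serialization now borrows a reusable buffer from a thread-safe object pool, guarded per call, instead of sharing a single std::ostringstream. Below is a minimal, self-contained sketch of that object-pool-plus-guard idea; demo_buffer, demo_pool, and demo_guard are simplified stand-ins invented for this sketch, not the library's actual data_buffer, object_pool, and pooled_object_guard.

```cpp
// Sketch only: the buffer/pool/guard types are simplified stand-ins,
// not the rlclientlib utility::data_buffer / object_pool / pooled_object_guard.
#include <iostream>
#include <mutex>
#include <string>
#include <vector>

// Reusable growable buffer that can be fully reset, unlike the
// seekp-based std::ostringstream reset the commit message calls out.
class demo_buffer {
public:
  void reset() { _data.clear(); }
  demo_buffer& operator<<(const std::string& s) { _data += s; return *this; }
  std::string str() const { return _data; }
private:
  std::string _data;
};

// Hands out reusable objects under a mutex so multiple threads can
// serialize events concurrently without sharing one buffer.
template <typename T>
class demo_pool {
public:
  T* get_or_create() {
    std::lock_guard<std::mutex> lock(_mutex);
    if (_free.empty()) return new T();
    T* obj = _free.back();
    _free.pop_back();
    return obj;
  }
  void put_back(T* obj) {
    std::lock_guard<std::mutex> lock(_mutex);
    _free.push_back(obj);
  }
  ~demo_pool() { for (T* obj : _free) delete obj; }
private:
  std::mutex _mutex;
  std::vector<T*> _free;
};

// RAII guard: returns the borrowed object to the pool when it goes out of scope.
template <typename T>
class demo_guard {
public:
  demo_guard(demo_pool<T>& pool, T* obj) : _pool(pool), _obj(obj) {}
  ~demo_guard() { _pool.put_back(_obj); }
  T* operator->() { return _obj; }
  T* get() { return _obj; }
private:
  demo_pool<T>& _pool;
  T* _obj;
};

int main() {
  demo_pool<demo_buffer> pool;
  {
    // Mirrors the usage added in choose_rank() and report_outcome():
    // borrow a buffer, reset it, serialize into it, read the result back.
    demo_guard<demo_buffer> guard(pool, pool.get_or_create());
    guard->reset();
    *guard.get() << "serialized ranking event";
    std::cout << guard->str() << std::endl;
  }  // the buffer is returned to the pool here, ready for reuse
  return 0;
}
```

The real code combines borrowing and guarding in one line, pooled_object_guard guard(_buffer_pool, _buffer_pool.get_or_create()), as seen in the live_model_impl diffs below.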
2 changes: 1 addition & 1 deletion Makefile
@@ -104,7 +104,7 @@ all: vw library_example java spanning_tree rl_clientlib

export

rl_clientlib:
rl_clientlib: vw
cd reinforcement_learning/rlclientlib; $(MAKE) -j $(NPROCS) things

rl_clientlib_test: vw rl_clientlib
Binary file not shown.
8 changes: 6 additions & 2 deletions reinforcement_learning/examples/rl_sim_cpp/rl_sim.cc
@@ -48,13 +48,15 @@ int rl_sim::loop() {

stats.record(p.id(), choosen_action, reward);

std::cout << stats.count() << ", ctxt, " << p.id() << ", action, " << choosen_action << ", reward, " << reward
std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << choosen_action << ", reward, " << reward
<< ", dist, " << get_dist_str(response) << ", " << stats.get_stats(p.id(), choosen_action) << std::endl;

response.clear();

std::this_thread::sleep_for(std::chrono::milliseconds(2000));
}

return 0;
}

person& rl_sim::pick_a_random_person() {
@@ -103,6 +105,8 @@ int rl_sim::init_rl() {
return -1;
}

std::cout << " API Config " << config;

return err::success;
}

@@ -157,7 +161,7 @@ rl_sim::rl_sim(boost::program_options::variables_map vm) :_options(std::move(vm)
std::string get_dist_str(const reinforcement_learning::ranking_response& response) {
std::string ret;
ret += "(";
for (auto& ap_pair : response) {
for (const auto& ap_pair : response) {
ret += "[" + to_string(ap_pair.action_id) + ",";
ret += to_string(ap_pair.probability) + "]";
ret += " ,";
Binary file not shown.
10 changes: 9 additions & 1 deletion reinforcement_learning/include/config_collection.h
@@ -9,6 +9,12 @@
#include <string>
#include <unordered_map>

namespace reinforcement_learning {namespace utility {
class config_collection;
}}

std::ostream& operator<<(std::ostream& os, const reinforcement_learning::utility::config_collection&);

namespace reinforcement_learning { namespace utility {
/**
* @brief Configuration class to initialize the API
@@ -38,9 +44,11 @@ namespace reinforcement_learning { namespace utility {
bool get_bool(const char* str, bool defval) const;
//! Gets the value as a float. If the value does not exist or if there is an error, it returns defval
float get_float(const char* name, float defval) const;
//! Friend left-shift (stream insertion) operator
friend std::ostream& ::operator<<(std::ostream& os, const config_collection&);

private:
using map_type = std::unordered_map<std::string, std::string>; //! Collection type that holds the (name,value) pairs
map_type* _pmap; //! Collection that holds the (name,value) pairs
};
}}
}}
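
The config_collection.h change above uses a common pattern for a stream operator that must live at global scope: forward-declare the class inside its namespace, declare operator<< globally, then befriend the qualified ::operator<< from inside the class so it can read the private map. A standalone sketch of that declaration pattern, with a hypothetical settings class standing in for the real config_collection:

```cpp
#include <iostream>

// Hypothetical example class; not the real config_collection.
namespace demo { namespace utility {
  class settings;
}}

// Global-scope declaration, visible before the class definition.
std::ostream& operator<<(std::ostream& os, const demo::utility::settings& s);

namespace demo { namespace utility {
  class settings {
  public:
    explicit settings(int value) : _value(value) {}
    // Befriend the previously declared global operator so it can read _value.
    friend std::ostream& ::operator<<(std::ostream& os, const settings& s);
  private:
    int _value;
  };
}}

std::ostream& operator<<(std::ostream& os, const demo::utility::settings& s) {
  return os << "settings(" << s._value << ")";
}

int main() {
  demo::utility::settings s(42);
  std::cout << s << std::endl;  // prints: settings(42)
  return 0;
}
```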
9 changes: 6 additions & 3 deletions reinforcement_learning/include/constants.h
@@ -16,9 +16,12 @@ namespace reinforcement_learning { namespace name {
const char *const OBSERVATION_EH_NAME = "observation.eventhub.name";
const char *const OBSERVATION_EH_KEY_NAME = "observation.eventhub.keyname";
const char *const OBSERVATION_EH_KEY = "observation.eventhub.key";
const char *const SEND_HIGH_WATER_MARK = "eh.send.highwatermark";
const char *const SEND_QUEUE_MAXSIZE = "eh.send.queue.maxsize";
const char *const SEND_BATCH_INTERVAL = "eh.send.batchintervalms";
const char *const INTERACTION_SEND_HIGH_WATER_MARK = "interaction.send.highwatermark";
const char *const INTERACTION_SEND_QUEUE_MAXSIZE = "interaction.send.queue.maxsize";
const char *const INTERACTION_SEND_BATCH_INTERVAL_MS = "interaction.send.batchintervalms";
const char *const OBSERVATION_SEND_HIGH_WATER_MARK = "observation.send.highwatermark";
const char *const OBSERVATION_SEND_QUEUE_MAXSIZE = "observation.send.queue.maxsize";
const char *const OBSERVATION_SEND_BATCH_INTERVAL_MS = "observation.send.batchintervalms";
const char *const EH_TEST = "eventhub.mock";
}}

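The constants.h change above splits the single eh.send.* group into separate interaction.* and observation.* keys, so the two event pipelines can be configured independently. A hedged usage sketch: get_int with a default is the config_collection accessor used elsewhere in this commit, while the include paths and fallback values here are assumptions made for illustration.

```cpp
#include <iostream>

#include "config_collection.h"  // assumed include path for this sketch
#include "constants.h"          // assumed include path for this sketch

namespace rl = reinforcement_learning;

// Reads the split send-pipeline settings; the fallback values are invented.
void print_send_settings(const rl::utility::config_collection& config) {
  // Interaction (ranking) event pipeline, tuned independently...
  const int interaction_batch_ms =
      config.get_int(rl::name::INTERACTION_SEND_BATCH_INTERVAL_MS, 1000);
  const int interaction_queue_max =
      config.get_int(rl::name::INTERACTION_SEND_QUEUE_MAXSIZE, 10000);

  // ...from the observation (outcome) event pipeline.
  const int observation_batch_ms =
      config.get_int(rl::name::OBSERVATION_SEND_BATCH_INTERVAL_MS, 1000);
  const int observation_queue_max =
      config.get_int(rl::name::OBSERVATION_SEND_QUEUE_MAXSIZE, 10000);

  std::cout << "interaction: batch " << interaction_batch_ms << " ms, queue "
            << interaction_queue_max << '\n'
            << "observation: batch " << observation_batch_ms << " ms, queue "
            << observation_queue_max << std::endl;
}
```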
4 changes: 2 additions & 2 deletions reinforcement_learning/include/ranking_response.h
@@ -17,7 +17,7 @@ namespace reinforcement_learning {
*/
struct action_prob {
//! action id
int action_id;
size_t action_id;
//! probability associated with the action id
float probability;
};
@@ -75,7 +75,7 @@ namespace reinforcement_learning {
* @param action_id
* @param prob
*/
void push_back(const int action_id, const float prob);
void push_back(const size_t action_id, const float prob);

/**
* @brief Size of the action collection.
198 changes: 92 additions & 106 deletions reinforcement_learning/rlclientlib/live_model_impl.cc
@@ -18,10 +18,10 @@
namespace e = exploration;
using namespace std;

namespace reinforcement_learning
{
namespace reinforcement_learning {
// Some namespace changes for more concise code
namespace m = model_management;
namespace u = utility;

// Some typedefs for more concise code
using vw_ptr = std::shared_ptr<safe_vw>;
@@ -41,16 +41,14 @@ namespace reinforcement_learning
return scode;
}

int live_model_impl::choose_rank(const char* uuid, const char* context, ranking_response& response, api_status* status)
{
int live_model_impl::choose_rank(const char* uuid, const char* context, ranking_response& response,
api_status* status) {
//clear previous errors if any
api_status::try_clear(status);

//check arguments
RETURN_IF_FAIL(check_null_or_empty(uuid, context, status));

int scode;
if(!_model_data_received) {
if (!_model_data_received) {
scode = explore_only(uuid, context, response, status);
RETURN_IF_FAIL(scode);
response.set_model_id("N/A");
Expand All @@ -59,25 +57,21 @@ namespace reinforcement_learning
scode = explore_exploit(uuid, context, response, status);
RETURN_IF_FAIL(scode);
}

response.set_uuid(uuid);

// Serialize the event
_buff.seekp(0, std::ostringstream::beg);
ranking_event::serialize(_buff, uuid, context, response);
_buff << std::ends;
auto sbuf = _buff.str();

u::pooled_object_guard<u::data_buffer, u::buffer_factory> guard(_buffer_pool, _buffer_pool.get_or_create());
guard->reset();
ranking_event::serialize(*guard.get(), uuid, context, response);
auto sbuf = guard->str();
// Send the ranking event to the backend
RETURN_IF_FAIL(_logger.append_ranking(sbuf, status));

return error_code::success;
}

//here the uuid is auto-generated
int live_model_impl::choose_rank(const char* context, ranking_response& response, api_status* status) {
return choose_rank(boost::uuids::to_string(boost::uuids::random_generator()()).c_str(), context, response,
status);
status);
}

int live_model_impl::report_outcome(const char* uuid, const char* outcome_data, api_status* status) {
@@ -100,107 +94,99 @@
model_factory_t* m_factory
)
: _configuration(config),
_error_cb(fn, err_context),
_data_cb(_handle_model_update, this),
_logger(config, &_error_cb),
_t_factory{t_factory},
_m_factory{m_factory},
_transport(nullptr),
_model(nullptr),
_model_download(nullptr),
_bg_model_proc(config.get_int(name::MODEL_REFRESH_INTERVAL_MS, 60 * 1000), &_error_cb) { }

int live_model_impl::init_model(api_status* status) {
const auto model_impl = _configuration.get(name::MODEL_IMPLEMENTATION, value::VW);
m::i_model* pmodel;
RETURN_IF_FAIL(_m_factory->create(&pmodel, model_impl, _configuration,status));
_model.reset(pmodel);
return error_code::success;
}

void inline live_model_impl::_handle_model_update(const m::model_data& data, live_model_impl* ctxt) {
ctxt->handle_model_update(data);
}

void live_model_impl::handle_model_update(const model_management::model_data& data) {
api_status status;
if(_model->update(data,&status) != error_code::success) {
_error_cb.report_error(status);
return;
_error_cb(fn, err_context),
_data_cb(_handle_model_update, this),
_logger(config, &_error_cb),
_t_factory{t_factory},
_m_factory{m_factory},
_transport(nullptr),
_model(nullptr),
_model_download(nullptr),
_bg_model_proc(config.get_int(name::MODEL_REFRESH_INTERVAL_MS, 60 * 1000), &_error_cb),
_buffer_pool(new u::buffer_factory()) { }

int live_model_impl::init_model(api_status* status) {
const auto model_impl = _configuration.get(name::MODEL_IMPLEMENTATION, value::VW);
m::i_model* pmodel;
RETURN_IF_FAIL(_m_factory->create(&pmodel, model_impl, _configuration,status));
_model.reset(pmodel);
return error_code::success;
}
_model_data_received = true;
}

int live_model_impl::explore_only(const char* uuid, const char* context, ranking_response& response, api_status* status) const {

// Generate egreedy pdf
size_t action_count = 0;
RETURN_IF_FAIL(utility::get_action_count(action_count, context, status));
vector<float> pdf(action_count);

// Assume that the user's top choice for action is at index 0
const auto top_action_id = 0;
auto scode = e::generate_epsilon_greedy(_initial_epsilon, top_action_id, begin(pdf), end(pdf));
if( S_EXPLORATION_OK != scode) {
RETURN_ERROR_LS(status, exploration_error) << "Exploration error code: " << scode;
void inline live_model_impl::_handle_model_update(const m::model_data& data, live_model_impl* ctxt) {
ctxt->handle_model_update(data);
}

// Pick using the pdf
uint32_t choosen_action_id;
scode = e::sample_after_normalizing(uuid, begin(pdf), end(pdf), choosen_action_id);
if ( S_EXPLORATION_OK != scode ) {
RETURN_ERROR_LS(status, exploration_error) << "Exploration error code: " << scode;
void live_model_impl::handle_model_update(const model_management::model_data& data) {
api_status status;
if (_model->update(data, &status) != error_code::success) {
_error_cb.report_error(status);
return;
}
_model_data_received = true;
}

response.push_back(top_action_id, pdf[top_action_id]);

// Setup response with pdf from prediction and choosen action
for ( size_t idx = 0; idx < pdf.size(); ++idx )
if ( top_action_id != idx )
response.push_back(idx, pdf[idx]);

response.set_choosen_action_id(top_action_id);

return error_code::success;
}

int live_model_impl::explore_exploit(const char* uuid, const char* context, ranking_response& response,
api_status* status) const {
return _model->choose_rank(uuid, context, response, status);
}

int live_model_impl::init_model_mgmt(api_status* status) {
// Initialize transport for the model using transport factory
const auto tranport_impl = _configuration.get(name::MODEL_SRC, value::AZURE_STORAGE_BLOB);
m::i_data_transport* ptransport;
RETURN_IF_FAIL(_t_factory->create(&ptransport, tranport_impl, _configuration, status));
// This class manages lifetime of transport
this->_transport.reset(ptransport);

// Initialize background process and start downloading models
this->_model_download.reset(new m::model_downloader(ptransport, &_data_cb));
return _bg_model_proc.init(_model_download.get(),status);
}
int live_model_impl::explore_only(const char* uuid, const char* context, ranking_response& response,
api_status* status) const {
// Generate egreedy pdf
size_t action_count = 0;
RETURN_IF_FAIL(utility::get_action_count(action_count, context, status));
vector<float> pdf(action_count);
// Assume that the user's top choice for action is at index 0
const auto top_action_id = 0;
auto scode = e::generate_epsilon_greedy(_initial_epsilon, top_action_id, begin(pdf), end(pdf));
if (S_EXPLORATION_OK != scode) {
RETURN_ERROR_LS(status, exploration_error) << "Exploration error code: " << scode;
}
// Pick using the pdf
uint32_t choosen_action_id;
scode = e::sample_after_normalizing(uuid, begin(pdf), end(pdf), choosen_action_id);
if (S_EXPLORATION_OK != scode) {
RETURN_ERROR_LS(status, exploration_error) << "Exploration error code: " << scode;
}
response.push_back(top_action_id, pdf[top_action_id]);
// Setup response with pdf from prediction and choosen action
for (size_t idx = 0; idx < pdf.size(); ++idx)
if (top_action_id != idx)
response.push_back(idx, pdf[idx]);
response.set_choosen_action_id(top_action_id);
return error_code::success;
}

//helper: check if at least one of the arguments is null or empty
int check_null_or_empty(const char* arg1, const char* arg2, api_status* status) {
if ( !arg1 || !arg2 || strlen(arg1) == 0 || strlen(arg2) == 0 ) {
api_status::try_update(status, error_code::invalid_argument,
"one of the arguments passed to the ds is null or empty");
return error_code::invalid_argument;
int live_model_impl::explore_exploit(const char* uuid, const char* context, ranking_response& response,
api_status* status) const {
return _model->choose_rank(uuid, context, response, status);
}

return error_code::success;
}
int live_model_impl::init_model_mgmt(api_status* status) {
// Initialize transport for the model using transport factory
const auto tranport_impl = _configuration.get(name::MODEL_SRC, value::AZURE_STORAGE_BLOB);
m::i_data_transport* ptransport;
RETURN_IF_FAIL(_t_factory->create(&ptransport, tranport_impl, _configuration, status));
// This class manages lifetime of transport
this->_transport.reset(ptransport);
// Initialize background process and start downloading models
this->_model_download.reset(new m::model_downloader(ptransport, &_data_cb));
return _bg_model_proc.init(_model_download.get(), status);
}

int check_null_or_empty(const char* arg1, api_status* status) {
if ( !arg1 || strlen(arg1) == 0) {
api_status::try_update(status, error_code::invalid_argument,
"one of the arguments passed to the ds is null or empty");
return error_code::invalid_argument;
//helper: check if at least one of the arguments is null or empty
int check_null_or_empty(const char* arg1, const char* arg2, api_status* status) {
if (!arg1 || !arg2 || strlen(arg1) == 0 || strlen(arg2) == 0) {
api_status::try_update(status, error_code::invalid_argument,
"one of the arguments passed to the ds is null or empty");
return error_code::invalid_argument;
}
return error_code::success;
}

return error_code::success;
}
int check_null_or_empty(const char* arg1, api_status* status) {
if (!arg1 || strlen(arg1) == 0) {
api_status::try_update(status, error_code::invalid_argument,
"one of the arguments passed to the ds is null or empty");
return error_code::invalid_argument;
}
return error_code::success;
}

}
10 changes: 5 additions & 5 deletions reinforcement_learning/rlclientlib/live_model_impl.h
@@ -62,14 +62,14 @@ namespace reinforcement_learning
utility::config_collection _configuration;
error_callback_fn _error_cb;
model_management::data_callback_fn _data_cb;
std::ostringstream _buff;
logger _logger;
transport_factory_t* _t_factory;
model_factory_t* _m_factory;
std::unique_ptr<model_management::i_data_transport> _transport;
std::unique_ptr<model_management::i_model> _model;
std::unique_ptr<model_management::model_downloader> _model_download;
utility::periodic_background_proc<model_management::model_downloader> _bg_model_proc;
utility::object_pool<utility::data_buffer, utility::buffer_factory> _buffer_pool;
};

template <typename D>
@@ -78,10 +78,10 @@ namespace reinforcement_learning
api_status::try_clear(status);

// Serialize outcome
_buff.seekp(0, std::ostringstream::beg);
outcome_event::serialize(_buff, uuid, outcome_data);
_buff << std::ends;
auto sbuf = _buff.str();
utility::pooled_object_guard<utility::data_buffer, utility::buffer_factory> buffer(_buffer_pool, _buffer_pool.get_or_create());
buffer->reset();
outcome_event::serialize(*buffer.get(), uuid, outcome_data);
auto sbuf = buffer->str();

// Send the outcome event to the backend
RETURN_IF_FAIL(_logger.append_outcome(sbuf, status));
