Skip to content

Commit

Permalink
[multistep] Activations in client (#397)
Browse files Browse the repository at this point in the history
* [Binary parser] reward functions for ccb format (#361)

* [Binary parser] refactoring rewards (#366)

* Example gen add ccb loop for e2e testing (#368)

* [Binary parser] add e2e ccb test, compare dsjson and fb logged files create the same model (#369)

* minor binary parser cleanup (#370)

* [Binary parser] add external parser test for apprentice mode cb (#374)

* CCB apprentice reward (#373)

* [Binary parser] add metrics for cb (#371)

* [Binary parser] don't log when skip learn, more tests, skip over unknown msg type (#375)

* [binary parser] ccb skip learn (#376)

* refactor: add error message to fix config file (#377)

* Fix CI's after flatbuffer version update to 2.0 (#390)

* try set fb span minimal

* add to preprocessor definitions

* add to unit_test project file

* Revert "mac ci: continue on error true (#327)" (#385)

* Fix python build path on windows, and formatting. (#383)

* Update build_docs.yml (#391)

* only convert timestamp to string before exiting (#382)

* ntohl is a define on osx, rename the function. (#386)

* Add bunch of nice to haves CLI options and fix FB 2.0 compat. (#387)

* our build requires CMP0074 due to usage of PackageName_ROOT variables. (#393)

* our build requires CMP0074 due to usage of PackageName_ROOT variables.

* try to use cmake_policy

* Activations in multistep: first PR with schema changes only (#392)

* deferred action to multistep schema

* Multistep to problem type

* try to set cmake policy for CMP0074

* OLD -> NEW

* try default policy for cmp0074

* build fix

* flags to request_episodic_decision

* episodic decision: deferred action implementation

* report_action_taken for secondary index

* formatting fixes

Co-authored-by: cheng-tan <chengtan2013@gmail.com>
Co-authored-by: olgavrou <olgavrou@gmail.com>
Co-authored-by: Griffin Bassman <griffinbassman@gmail.com>
Co-authored-by: Eduardo Salinas <edus@microsoft.com>
Co-authored-by: zwd-ms <71728747+zwd-ms@users.noreply.github.com>
Co-authored-by: Rodrigo Kumpera <kumpera@users.noreply.github.com>
  • Loading branch information
7 people committed Aug 9, 2021
1 parent 35556dd commit 4a5b2df
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 9 deletions.
12 changes: 12 additions & 0 deletions include/live_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,17 @@ namespace reinforcement_learning {
*/
int report_action_taken(const char* event_id, api_status* status = nullptr);

/**
 * @brief Report that the action was taken for one partial outcome of an event.
 *
 * Two-identifier overload of report_action_taken(): the event is addressed by primary_id and the
 * partial outcome within it by secondary_id.
 *
 * @param primary_id The unique primary_id used when choosing an action should be presented here. This is so that
 * the action taken can be matched with feedback received.
 * @param secondary_id Index of the partial outcome, passed as a C string.
 * @param status Optional field with detailed string description if there is an error
 * @return int Return error code. This will also be returned in the api_status object
 */
int report_action_taken(const char* primary_id, const char* secondary_id, api_status* status = nullptr);

/**
* @brief Report the outcome for the top action.
*
Expand Down Expand Up @@ -415,6 +426,7 @@ namespace reinforcement_learning {

//multistep
/**
 * @brief Request a decision for one step of an episode, using the default action flags.
 *
 * Behaves like the overload below called with action_flags::DEFAULT.
 *
 * @param event_id Identifier of this step's event; it is stored on the response.
 * @param previous_id Identifier linking this step to an earlier one (NOTE(review): presumably may be null for the first step - confirm).
 * @param context_json Context for this step, as a JSON string.
 * @param resp Output ranking response; it is cleared before being filled.
 * @param episode Episode state that is updated with this step.
 * @param status Optional field with detailed string description if there is an error
 * @return int Return error code. This will also be returned in the api_status object
 */
int request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, ranking_response& resp, episode_state& episode, api_status* status = nullptr);
/**
 * @brief Request a decision for one step of an episode with explicit action flags.
 *
 * Same parameters as the overload above, plus:
 * @param flags Action flags for this decision; the action_flags::DEFERRED bit is recorded in the logged multistep event.
 */
int request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, unsigned int flags, ranking_response& resp, episode_state& episode, api_status* status = nullptr);

private:
std::unique_ptr<live_model_impl> _pimpl; //! The actual implementation details are forwarded to this object (PIMPL pattern)
Expand Down
13 changes: 12 additions & 1 deletion rlclientlib/live_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,17 @@ namespace reinforcement_learning

/**
 * Multistep decision request without explicit flags.
 *
 * Forwards to the implementation object with action_flags::DEFAULT, which makes this overload
 * equivalent to calling the flags-taking overload with the default flags.
 *
 * @param event_id Identifier of this step's event.
 * @param previous_id Identifier of the step this one follows.
 * @param context_json Context for this step, as a JSON string.
 * @param resp Output ranking response filled by the implementation.
 * @param episode Episode state passed through to the implementation.
 * @param status Optional error details.
 * @return Error code produced by the implementation (or by INIT_CHECK).
 */
int live_model::request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, ranking_response& resp, episode_state& episode, api_status* status) {
  // NOTE(review): INIT_CHECK() is a project macro - presumably it returns early when the model is not initialized; confirm.
  INIT_CHECK();
  // Single forwarding call: a second return after this one would be unreachable, so only the
  // flags-aware call is kept.
  return _pimpl->request_episodic_decision(event_id, previous_id, context_json, action_flags::DEFAULT, resp, episode, status);
}

/**
 * Multistep decision request with caller-supplied action flags.
 *
 * Thin forwarder: every argument is handed unchanged to the implementation object.
 *
 * @return Error code produced by the implementation (or by INIT_CHECK).
 */
int live_model::request_episodic_decision(
    const char* event_id,
    const char* previous_id,
    const char* context_json,
    unsigned int flags,
    ranking_response& resp,
    episode_state& episode,
    api_status* status) {
  INIT_CHECK();
  const int result_code = _pimpl->request_episodic_decision(
      event_id, previous_id, context_json, flags, resp, episode, status);
  return result_code;
}

/**
 * Report that the action was taken for the partial outcome secondary_id of event primary_id.
 *
 * Thin forwarder: both identifiers and the status object go unchanged to the implementation object.
 *
 * @return Error code produced by the implementation (or by INIT_CHECK).
 */
int live_model::report_action_taken(
    const char* primary_id,
    const char* secondary_id,
    api_status* status) {
  INIT_CHECK();
  const int result_code = _pimpl->report_action_taken(primary_id, secondary_id, status);
  return result_code;
}

}
11 changes: 9 additions & 2 deletions rlclientlib/live_model_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,13 @@ namespace reinforcement_learning {
return _outcome_logger->report_action_taken(event_id, status);
}

int live_model_impl::report_action_taken(const char* primary_id, const char* secondary_id, api_status* status) {
// Clear previous errors if any
api_status::try_clear(status);
// Send the outcome event to the backend
return _outcome_logger->report_action_taken(primary_id, secondary_id, status);
}

int live_model_impl::report_outcome(const char* event_id, const char* outcome, api_status* status) {
// Check arguments
RETURN_IF_FAIL(check_null_or_empty(event_id, outcome, _trace_logger.get(), status));
Expand Down Expand Up @@ -582,7 +589,7 @@ namespace reinforcement_learning {
return refresh_model(status);
}

int live_model_impl::request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, ranking_response& resp, episode_state& episode, api_status* status) {
int live_model_impl::request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, unsigned int flags, ranking_response& resp, episode_state& episode, api_status* status) {
resp.clear();
//clear previous errors if any
api_status::try_clear(status);
Expand All @@ -605,7 +612,7 @@ namespace reinforcement_learning {
resp.set_event_id(event_id);

RETURN_IF_FAIL(episode.update(event_id, previous_id, context_json, resp, status));
RETURN_IF_FAIL(_interaction_logger->log(episode.get_episode_id(), previous_id, context_patched.c_str(), resp, status));
RETURN_IF_FAIL(_interaction_logger->log(episode.get_episode_id(), previous_id, context_patched.c_str(), flags, resp, status));
return error_code::success;
}

Expand Down
3 changes: 2 additions & 1 deletion rlclientlib/live_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace reinforcement_learning
int request_multi_slot_decision(const char* context_json, unsigned int flags, multi_slot_response_detailed& resp, const std::vector<int>& baseline_actions, api_status* status = nullptr);

int report_action_taken(const char* event_id, api_status* status);
int report_action_taken(const char* primary_id, const char *secondary_id, api_status* status);

int report_outcome(const char* event_id, const char* outcome_data, api_status* status);
int report_outcome(const char* event_id, float reward, api_status* status);
Expand All @@ -66,7 +67,7 @@ namespace reinforcement_learning
live_model_impl& operator=(const live_model_impl&) = delete;
live_model_impl& operator=(live_model_impl&&) = delete;

int request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, ranking_response& resp, episode_state& episode, api_status* status = nullptr);
int request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, unsigned int flags, ranking_response& resp, episode_state& episode, api_status* status = nullptr);

private:
// Internal implementation methods
Expand Down
11 changes: 9 additions & 2 deletions rlclientlib/logger/logger_facade.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ namespace reinforcement_learning {
}
}

int interaction_logger_facade::log(const char* episode_id, const char* previous_id, const char* context, const ranking_response& response, api_status* status) {
int interaction_logger_facade::log(const char* episode_id, const char* previous_id, const char* context, unsigned int flags, const ranking_response& response, api_status* status) {
switch (_version) {
case 2: {
generic_event::object_list_t actions;
generic_event::payload_buffer_t payload;
event_content_type content_type;

RETURN_IF_FAIL(wrap_log_call(_ext, _multistep_serializer, context, actions, payload, content_type, status, previous_id, response));
RETURN_IF_FAIL(wrap_log_call(_ext, _multistep_serializer, context, actions, payload, content_type, status, previous_id, flags, response));
return _v2->log(episode_id, std::move(payload), _multistep_serializer.type, content_type, std::move(actions), status);
}
default: return protocol_not_supported(status);
Expand Down Expand Up @@ -248,5 +248,12 @@ namespace reinforcement_learning {
default: return protocol_not_supported(status);
}
}

/**
 * Log an "action taken" observation for the partial outcome secondary_id of event primary_id.
 *
 * Only protocol version 2 is handled; any other version yields protocol_not_supported().
 */
int observation_logger_facade::report_action_taken(const char* primary_id, const char* secondary_id, api_status* status) {
  // Guard clause: everything except version 2 is rejected.
  if (_version != 2) {
    return protocol_not_supported(status);
  }

  // The payload is built from the secondary id and logged under the primary id.
  return _v2->log(primary_id,
                  _serializer.report_action_taken(secondary_id),
                  _serializer.type,
                  event_content_type::IDENTITY,
                  status);
}
}
}
3 changes: 2 additions & 1 deletion rlclientlib/logger/logger_facade.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace reinforcement_learning
//CB v1/v2
int log(const char* context, unsigned int flags, const ranking_response& response, api_status* status, learning_mode learning_mode = ONLINE);

int log(const char* episode_id, const char* previous_id, const char* context, const ranking_response& response, api_status* status);
int log(const char* episode_id, const char* previous_id, const char* context, unsigned int flags, const ranking_response& response, api_status* status);
const multistep_serializer _multistep_serializer;
int log_decisions(std::vector<const char*>& event_ids, const char* context, unsigned int flags, const std::vector<std::vector<uint32_t>>& action_ids,
const std::vector<std::vector<float>>& pdfs, const std::string& model_version, api_status* status);
Expand Down Expand Up @@ -108,6 +108,7 @@ namespace reinforcement_learning
int log(const char* event_id, const char* index, const char* outcome, api_status* status);

int report_action_taken(const char* event_id, api_status* status);
int report_action_taken(const char* event_id, const char* index, api_status* status);
private:
const int _version;
int _serializer_shared_state;
Expand Down
13 changes: 11 additions & 2 deletions rlclientlib/serialization/payload_serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,18 @@ namespace reinforcement_learning {
fbb.Finish(fb);
return fbb.Release();
}

/**
 * Build the serialized outcome-event payload reporting "action taken" for a string index.
 *
 * @param index Index of the partial outcome, stored as a literal (string) index value.
 * @return The finished flatbuffer, released from the builder.
 */
static generic_event::payload_buffer_t report_action_taken(const char* index) {
  flatbuffers::FlatBufferBuilder builder;

  // Store the index as a string and erase its type so it fits the index union.
  const auto index_offset = builder.CreateString(index).Union();

  // No outcome value is attached (OutcomeValue_NONE, offset 0).
  // NOTE(review): the trailing `true` is presumably the action_taken flag - confirm against the schema.
  const auto outcome_event = v2::CreateOutcomeEvent(
      builder, v2::OutcomeValue_NONE, 0, v2::IndexValue_literal, index_offset, true);

  builder.Finish(outcome_event);
  return builder.Release();
}
};

struct multistep_serializer : payload_serializer<generic_event::payload_type_t::PayloadType_MultiStep> {
static generic_event::payload_buffer_t event(const char* context, const char* previous_id, const ranking_response& response) {
static generic_event::payload_buffer_t event(const char* context, const char* previous_id, unsigned int flags, const ranking_response& response) {
flatbuffers::FlatBufferBuilder fbb;
std::vector<uint64_t> action_ids;
std::vector<float> probabilities;
Expand All @@ -180,7 +188,8 @@ namespace reinforcement_learning {
std::string context_str(context);
copy(context_str.begin(), context_str.end(), std::back_inserter(_context));

auto fb = v2::CreateMultiStepEventDirect(fbb, response.get_event_id(), previous_id, &action_ids, &_context, &probabilities, response.get_model_id());
auto fb = v2::CreateMultiStepEventDirect(fbb, response.get_event_id(), previous_id, &action_ids,
&_context, &probabilities, response.get_model_id(), flags & action_flags::DEFERRED);
fbb.Finish(fb);
return fbb.Release();
}
Expand Down

0 comments on commit 4a5b2df

Please sign in to comment.