Skip to content

Commit

Permalink
[Binary parser] add test that compares json and fb joined logs's outp…
Browse files Browse the repository at this point in the history
…uts (#323)
  • Loading branch information
olgavrou committed May 19, 2021
1 parent 895134e commit 13950d0
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 13 deletions.
6 changes: 4 additions & 2 deletions external_parser/example_joiner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ void example_joiner::clear_vw_examples(v_array<example *> &examples) {

void example_joiner::clear_event_id_batch_info(const std::string &id) {
_batch_grouped_events.erase(id);
_batch_event_order.pop();
if (!_batch_event_order.empty() && _batch_event_order.front() == id) {
_batch_event_order.pop();
}
_batch_grouped_examples.erase(id);
}

Expand Down Expand Up @@ -473,7 +475,7 @@ bool example_joiner::process_joined(v_array<example *> &examples) {
return true;
}

auto &id = _batch_event_order.front();
auto id = _batch_event_order.front();
bool multiline = false;
float reward = _default_reward;
// original reward is used to record the observed reward of apprentice mode
Expand Down
1 change: 0 additions & 1 deletion external_parser/parse_example_binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ bool read_payload_type(io_buf *input, unsigned int &payload_type) {
// when we are trying to fetch the next payload and we find out that there
// is nothing left to read the file doesn't have to necessarily contain an
// EOF
VW::io::logger::log_info("Reached end of file");
payload_type = MSG_TYPE_EOF;
return true;
}
Expand Down
27 changes: 19 additions & 8 deletions external_parser/unit_tests/test_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,17 @@
namespace v2 = reinforcement_learning::messages::flatbuff::v2;

void clear_examples(v_array<example *> &examples, vw *vw) {
for (auto *ex : examples) {
VW::finish_example(*vw, *ex);
if (vw->l->is_multiline) {
multi_ex multi_exs;
for (auto *ex : examples) {
multi_exs.push_back(ex);
}
vw->finish_example(multi_exs);
multi_exs.clear();
} else {
for (auto *ex : examples) {
VW::finish_example(*vw, *ex);
}
}
examples.clear();
}
Expand All @@ -17,7 +26,7 @@ void set_buffer_as_vw_input(const std::vector<char> &buffer, vw *vw) {
vw->example_parser->input.reset();
vw->example_parser->input = std::move(reader_view_of_buffer);
vw->example_parser->input->add_file(
VW::io::create_buffer_view(buffer.data(), buffer.size()));
VW::io::create_buffer_view(buffer.data(), buffer.size()));
}

std::vector<char> read_file(std::string file_name) {
Expand All @@ -30,6 +39,7 @@ std::vector<char> read_file(std::string file_name) {

std::vector<char> buffer(size);
file.read(buffer.data(), size);
file.close();
return buffer;
}

Expand All @@ -45,9 +55,9 @@ std::string get_test_files_location() {
}
}

std::vector<const v2::JoinedEvent *>
wrap_into_joined_events(std::vector<char> &buffer,
std::vector<flatbuffers::DetachedBuffer> &detached_buffers) {
std::vector<const v2::JoinedEvent *> wrap_into_joined_events(
std::vector<char> &buffer,
std::vector<flatbuffers::DetachedBuffer> &detached_buffers) {
flatbuffers::FlatBufferBuilder fbb;

// if file is smaller than preamble size then fail
Expand All @@ -60,7 +70,7 @@ wrap_into_joined_events(std::vector<char> &buffer,

BOOST_REQUIRE_GE(event_batch->events()->size(), 1);

std::vector<const v2::JoinedEvent *> event_list {};
std::vector<const v2::JoinedEvent *> event_list{};

int day = 30;
v2::TimeStamp ts(2020, 3, day, 10, 20, 30, 0);
Expand All @@ -77,7 +87,8 @@ wrap_into_joined_events(std::vector<char> &buffer,

fbb.Finish(fb);
detached_buffers.push_back(fbb.Release());
const v2::JoinedEvent *je = flatbuffers::GetRoot<v2::JoinedEvent>(detached_buffers[i].data());
const v2::JoinedEvent *je =
flatbuffers::GetRoot<v2::JoinedEvent>(detached_buffers[i].data());
event_list.push_back(je);
}

Expand Down
1 change: 1 addition & 0 deletions external_parser/unit_tests/test_files/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Residing under `valid_joined_logs`

- cb_simple.log: generated by running `python joiner.py --problem_type_config 1` on the above files (cb_v2.fb, f-reward_v2.fb) and renaming the resulting default `merged.log`
- cb_dedup_compressed.log: generated by running `python joiner.py --problem_type_config 1` with the files (cb_v2_dedup.fb, f-reward_v2.fb) and renaming the resulting default `merged.log`
- average_reward_100_interactions.[fb|json]: generated by running `./example_gen --kind cb-loop --random_reward --count 100` and performing binary joining and dsjson joining (with average reward) to generate each file (`.fb` and `.json`)

### invalid joined logs

Expand Down
4 changes: 4 additions & 0 deletions external_parser/unit_tests/test_files/test_outputs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore
Binary file not shown.

Large diffs are not rendered by default.

56 changes: 54 additions & 2 deletions external_parser/unit_tests/test_vw_external_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ BOOST_AUTO_TEST_CASE(cb_simple) {

v_array<example *> examples;
examples.push_back(&VW::get_unused_example(vw));

set_buffer_as_vw_input(buffer, vw);

bool read_payload = false;
Expand Down Expand Up @@ -42,7 +42,8 @@ BOOST_AUTO_TEST_CASE(cb_simple) {
BOOST_AUTO_TEST_CASE(cb_dedup_compressed) {
std::string input_files = get_test_files_location();

auto buffer = read_file(input_files + "/valid_joined_logs/cb_dedup_compressed.log");
auto buffer =
read_file(input_files + "/valid_joined_logs/cb_dedup_compressed.log");

auto vw = VW::initialize("--cb_explore_adf --binary_parser --quiet", nullptr,
false, nullptr, nullptr);
Expand Down Expand Up @@ -73,4 +74,55 @@ BOOST_AUTO_TEST_CASE(cb_dedup_compressed) {

clear_examples(examples, vw);
VW::finish(*vw);
}

BOOST_AUTO_TEST_CASE(compare_dsjson_with_fb_models) {
std::string input_files = get_test_files_location();

std::string fb_model = input_files + "/test_outputs/m_average_fb.model";
std::string dsjson_model =
input_files + "/test_outputs/m_average_dsjson.model";

std::remove(fb_model.c_str());
std::remove(dsjson_model.c_str());

{
// run with flatbuffer joined logs
auto full_file_name =
input_files + "/valid_joined_logs/average_reward_100_interactions.fb";

auto vw = VW::initialize("--cb_explore_adf --binary_parser --quiet -f " +
fb_model + " -d " + full_file_name,
nullptr, false, nullptr, nullptr);

VW::start_parser(*vw);
VW::LEARNER::generic_driver(*vw);
VW::end_parser(*vw);

VW::finish(*vw);
}

{
// run with json joined logs
auto full_file_name =
input_files + "/valid_joined_logs/average_reward_100_interactions.json";

auto vw = VW::initialize("--cb_explore_adf --dsjson --quiet -f " +
dsjson_model + " -d " + full_file_name,
nullptr, false, nullptr, nullptr);

VW::start_parser(*vw);
VW::LEARNER::generic_driver(*vw);
VW::end_parser(*vw);

VW::finish(*vw);
}

// read the models and compare
auto buffer_fb_model = read_file(fb_model);
auto buffer_dsjson_model = read_file(dsjson_model);

BOOST_CHECK_EQUAL_COLLECTIONS(buffer_fb_model.begin(), buffer_fb_model.end(),
buffer_dsjson_model.begin(),
buffer_dsjson_model.end());
}

0 comments on commit 13950d0

Please sign in to comment.