Skip to content

Commit

Permalink
test: add mem_tags and simplify counter clearing
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasRidhuan committed Apr 10, 2024
1 parent f8fdeeb commit e0f4bdb
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 90 deletions.
59 changes: 39 additions & 20 deletions barretenberg/cpp/src/barretenberg/vm/tests/avm_comparison.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,17 @@ void common_validate_cmp(Row row,
EXPECT_EQ(alu_row.avm_alu_ic, c);
}
} // namespace
using ThreeOpParamRow = std::array<FF, 3>;
std::vector<ThreeOpParamRow> positive_op_lt_test_values = { { { FF(1), FF(1), FF(0) },
{ FF(5323), FF(321), FF(0) },
{ FF(13793), FF(10590617LLU), FF(1) },
{ FF(0x7bff744e3cdf79LLU), FF(0x14ccccccccb6LLU), FF(0) },
{ FF(uint256_t{ 0xb900000000000001 }),
FF(uint256_t{ 0x1006021301080000 } << 64) +
uint256_t{ 0x000000000000001080876844827 },
1 } } };
std::vector<ThreeOpParamRow> positive_op_lte_test_values = {
using ThreeOpParam = std::array<FF, 3>;
using ThreeOpParamRow = std::tuple<ThreeOpParam, AvmMemoryTag>;
std::vector<ThreeOpParam> positive_op_lt_test_values = { { { FF(1), FF(1), FF(0) },
{ FF(5323), FF(321), FF(0) },
{ FF(13793), FF(10590617LLU), FF(1) },
{ FF(0x7bff744e3cdf79LLU), FF(0x14ccccccccb6LLU), FF(0) },
{ FF(uint256_t{ 0xb900000000000001 }),
FF(uint256_t{ 0x1006021301080000 } << 64) +
uint256_t{ 0x000000000000001080876844827 },
1 } } };
std::vector<ThreeOpParam> positive_op_lte_test_values = {
{ { FF(1), FF(1), FF(1) },
{ FF(5323), FF(321), FF(0) },
{ FF(13793), FF(10590617LLU), FF(1) },
Expand All @@ -75,6 +76,18 @@ std::vector<ThreeOpParamRow> positive_op_lte_test_values = {
FF(uint256_t{ 0x1006021301080000 } << 64) + uint256_t{ 0x000000000000001080876844827 },
FF(1) } }
};

std::vector<ThreeOpParamRow> gen_three_op_params(std::vector<ThreeOpParam> operands, std::vector<AvmMemoryTag> mem_tags)
{
std::vector<ThreeOpParamRow> params;
for (size_t i = 0; i < 5; i++) {
params.emplace_back(operands[i], mem_tags[i]);
}
return params;
}
std::vector<AvmMemoryTag> mem_tags{
{ AvmMemoryTag::U8, AvmMemoryTag::U16, AvmMemoryTag::U32, AvmMemoryTag::U64, AvmMemoryTag::U128 }
};
class AvmCmpTests : public ::testing::Test {
public:
AvmTraceBuilder trace_builder;
Expand All @@ -93,9 +106,10 @@ class AvmCmpTestsLTE : public AvmCmpTests, public testing::WithParamInterface<Th
******************************************************************************/
TEST_P(AvmCmpTestsLT, ParamTest)
{
const auto [a, b, c] = GetParam();
const auto [params, mem_tag] = GetParam();
const auto [a, b, c] = params;
trace_builder.calldata_copy(0, 0, 3, 0, std::vector<FF>{ a, b, c });
trace_builder.op_lt(0, 0, 1, 2, AvmMemoryTag::FF); // [1,254,0,0,....]
trace_builder.op_lt(0, 0, 1, 2, mem_tag);
trace_builder.return_op(0, 0, 0);
auto trace = trace_builder.finalize();
// Validate the trace
Expand All @@ -114,13 +128,16 @@ TEST_P(AvmCmpTestsLT, ParamTest)

validate_trace_proof(std::move(trace));
}
INSTANTIATE_TEST_SUITE_P(AvmCmpTests, AvmCmpTestsLT, testing::ValuesIn(positive_op_lt_test_values));
INSTANTIATE_TEST_SUITE_P(AvmCmpTests,
AvmCmpTestsLT,
testing::ValuesIn(gen_three_op_params(positive_op_lt_test_values, mem_tags)));

TEST_P(AvmCmpTestsLTE, ParamTest)
{
const auto [a, b, c] = GetParam();
const auto [params, mem_tag] = GetParam();
const auto [a, b, c] = params;
trace_builder.calldata_copy(0, 0, 3, 0, std::vector<FF>{ a, b, c });
trace_builder.op_lte(0, 0, 1, 2, AvmMemoryTag::FF); // [1,254,0,0,....]
trace_builder.op_lte(0, 0, 1, 2, mem_tag);
trace_builder.return_op(0, 0, 0);
auto trace = trace_builder.finalize();
auto row = std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_lte == FF(1); });
Expand All @@ -135,7 +152,9 @@ TEST_P(AvmCmpTestsLTE, ParamTest)
common_validate_cmp(*row, *alu_row, a, b, c, FF(0), FF(1), FF(2), AvmMemoryTag::FF);
validate_trace_proof(std::move(trace));
}
INSTANTIATE_TEST_SUITE_P(AvmCmpTests, AvmCmpTestsLTE, testing::ValuesIn(positive_op_lte_test_values));
INSTANTIATE_TEST_SUITE_P(AvmCmpTests,
AvmCmpTestsLTE,
testing::ValuesIn(gen_three_op_params(positive_op_lte_test_values, mem_tags)));

/******************************************************************************
*
Expand Down Expand Up @@ -163,8 +182,8 @@ std::vector<std::tuple<std::string, CMP_FAILURES>> cmp_failures = {
{ "LOOKUP_U16_0", CMP_FAILURES::RangeCheckFailed },

};
std::vector<ThreeOpParamRow> neg_test_lt = { { 12023, 439321, 0 } };
std::vector<ThreeOpParamRow> neg_test_lte = { { 12023, 12023, 0 } };
std::vector<ThreeOpParam> neg_test_lt = { { 12023, 439321, 0 } };
std::vector<ThreeOpParam> neg_test_lte = { { 12023, 12023, 0 } };

using EXPECTED_ERRORS = std::tuple<std::string, CMP_FAILURES>;

Expand Down Expand Up @@ -271,9 +290,9 @@ std::vector<Row> gen_mutated_trace_cmp(std::vector<Row> trace,
return trace;
}
class AvmCmpNegativeTestsLT : public AvmCmpTests,
public testing::WithParamInterface<std::tuple<EXPECTED_ERRORS, ThreeOpParamRow>> {};
public testing::WithParamInterface<std::tuple<EXPECTED_ERRORS, ThreeOpParam>> {};
class AvmCmpNegativeTestsLTE : public AvmCmpTests,
public testing::WithParamInterface<std::tuple<EXPECTED_ERRORS, ThreeOpParamRow>> {};
public testing::WithParamInterface<std::tuple<EXPECTED_ERRORS, ThreeOpParam>> {};

TEST_P(AvmCmpNegativeTestsLT, ParamTest)
{
Expand Down
89 changes: 19 additions & 70 deletions barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,154 +83,103 @@ void mutate_ic_in_trace(std::vector<Row>& trace, std::function<bool(Row)>&& sele
void clear_range_check_counters(std::vector<Row>& trace, uint256_t previous_value)
{
// Find the main row where the old u8 value in the first register is looked up
FF lookup_value = static_cast<uint8_t>(previous_value);
auto lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_8 == FF(1);
});
size_t lookup_value = static_cast<uint8_t>(previous_value);
// Decrement the counter
lookup_row->lookup_u8_0_counts = lookup_row->lookup_u8_0_counts - 1;
trace.at(lookup_value).lookup_u8_0_counts = trace.at(lookup_value).lookup_u8_0_counts - 1;
previous_value >>= 8;
lookup_value = static_cast<uint8_t>(previous_value);
// Find the main row where the old u16 value in the second register is looked up
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_8 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u8_1_counts = lookup_row->lookup_u8_1_counts - 1;
trace.at(lookup_value).lookup_u8_1_counts = trace.at(lookup_value).lookup_u8_1_counts - 1;
previous_value >>= 8;

// U_16_0: Find the main row where the old u16 value in the first register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_0_counts = lookup_row->lookup_u16_0_counts - 1;
trace.at(lookup_value).lookup_u16_0_counts = trace.at(lookup_value).lookup_u16_0_counts - 1;
previous_value >>= 16;

// U_16_1: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_1_counts = lookup_row->lookup_u16_1_counts - 1;
trace.at(lookup_value).lookup_u16_1_counts = trace.at(lookup_value).lookup_u16_1_counts - 1;
previous_value >>= 16;

// U_16_2: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_2_counts = lookup_row->lookup_u16_2_counts - 1;
trace.at(lookup_value).lookup_u16_2_counts = trace.at(lookup_value).lookup_u16_2_counts - 1;
previous_value >>= 16;

// U_16_3: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_3_counts = lookup_row->lookup_u16_3_counts - 1;
trace.at(lookup_value).lookup_u16_3_counts = trace.at(lookup_value).lookup_u16_3_counts - 1;
previous_value >>= 16;

// U_16_4: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_4_counts = lookup_row->lookup_u16_4_counts - 1;
trace.at(lookup_value).lookup_u16_4_counts = trace.at(lookup_value).lookup_u16_4_counts - 1;
previous_value >>= 16;

// U_16_5: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_5_counts = lookup_row->lookup_u16_5_counts - 1;
trace.at(lookup_value).lookup_u16_5_counts = trace.at(lookup_value).lookup_u16_5_counts - 1;
previous_value >>= 16;

// U_16_6: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_6_counts = lookup_row->lookup_u16_6_counts - 1;
trace.at(lookup_value).lookup_u16_6_counts = trace.at(lookup_value).lookup_u16_6_counts - 1;
previous_value >>= 16;

// U_16_7: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_7_counts = lookup_row->lookup_u16_7_counts - 1;
trace.at(lookup_value).lookup_u16_7_counts = trace.at(lookup_value).lookup_u16_7_counts - 1;
previous_value >>= 16;

// U_16_8: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_8_counts = lookup_row->lookup_u16_8_counts - 1;
trace.at(lookup_value).lookup_u16_8_counts = trace.at(lookup_value).lookup_u16_8_counts - 1;
previous_value >>= 16;

// U_16_9: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_9_counts = lookup_row->lookup_u16_9_counts - 1;
trace.at(lookup_value).lookup_u16_9_counts = trace.at(lookup_value).lookup_u16_9_counts - 1;
previous_value >>= 16;

// U_16_10: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_10_counts = lookup_row->lookup_u16_10_counts - 1;
trace.at(lookup_value).lookup_u16_10_counts = trace.at(lookup_value).lookup_u16_10_counts - 1;
previous_value >>= 16;

// U_16_11: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_11_counts = lookup_row->lookup_u16_11_counts - 1;
trace.at(lookup_value).lookup_u16_11_counts = trace.at(lookup_value).lookup_u16_11_counts - 1;
previous_value >>= 16;

// U_16_12: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_12_counts = lookup_row->lookup_u16_12_counts - 1;
trace.at(lookup_value).lookup_u16_12_counts = trace.at(lookup_value).lookup_u16_12_counts - 1;
previous_value >>= 16;

// U_16_13: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_13_counts = lookup_row->lookup_u16_13_counts - 1;
trace.at(lookup_value).lookup_u16_13_counts = trace.at(lookup_value).lookup_u16_13_counts - 1;
previous_value >>= 16;

// U_16_14: Find the main row where the old u16 value in the second register is looked up
lookup_value = static_cast<uint16_t>(previous_value);
lookup_row = std::ranges::find_if(trace.begin(), trace.end(), [lookup_value](Row r) {
return r.avm_main_clk == lookup_value && r.avm_main_sel_rng_16 == FF(1);
});
// Decrement the counter
lookup_row->lookup_u16_14_counts = lookup_row->lookup_u16_14_counts - 1;
trace.at(lookup_value).lookup_u16_14_counts = trace.at(lookup_value).lookup_u16_14_counts - 1;
previous_value >>= 16;
}
} // namespace tests_avm

0 comments on commit e0f4bdb

Please sign in to comment.