Skip to content

Commit

Permalink
chore: final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasRidhuan committed Apr 10, 2024
1 parent 21bdef4 commit bdff6f8
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 73 deletions.
10 changes: 4 additions & 6 deletions barretenberg/cpp/pil/avm/avm_alu.pil
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ namespace avm_alu(256);
pol commit cf;

// Compute predicate telling whether there is a row entry in the ALU table.
#[SELECTOR_REL]
alu_sel = op_add + op_sub + op_mul + op_not + op_eq + op_lt + op_lte;
#[SELECTOR_REL_2]
cmp_sel = op_lt + op_lte;

// ========= Type Constraints =============================================
Expand Down Expand Up @@ -440,30 +438,30 @@ namespace avm_alu(256);
cmp_rng_ctr * ((1 - rng_chk_sel) * (1 - op_eq_diff_inv) + op_eq_diff_inv) - rng_chk_sel = 0;

// We perform a range check if we have some range checks remaining or we are performing a comparison op
pol RNG_CHK_OP = (rng_chk_sel + cmp_sel);
pol RNG_CHK_OP = rng_chk_sel + cmp_sel;

pol commit rng_chk_lookup_selector;
#[RNG_CHK_LOOKUP_SELECTOR]
rng_chk_lookup_selector = cmp_sel + rng_chk_sel;

// Perform 128-bit range check on lo part
#[LOWER_CMP_RNG_CHK]
a_lo = SUM_128 * (RNG_CHK_OP);
a_lo = SUM_128 * RNG_CHK_OP;


// Perform 128-bit range check on hi part
#[UPPER_CMP_RNG_CHK]
a_hi = (u16_r7 + u16_r8 * 2**16 +
u16_r9 * 2**32 + u16_r10 * 2**48 +
u16_r11 * 2**64 + u16_r12 * 2**80 +
u16_r13 * 2**96 + u16_r14 * 2**112) * (RNG_CHK_OP);
u16_r13 * 2**96 + u16_r14 * 2**112) * RNG_CHK_OP;

// Shift all elements "across" by 2 columns
// TODO: there is an optimisation where we are able to do 1 less range check as the range check on
// P_SUB_B is implied by the other range checks.
// Briefly: given a > b and p > a and p > a - b - 1, it is sufficient confirm that p > b without a range check
// To accomplish this we would likely change the order of the range_check so we can skip p_sub_b
#[SHIFT_RELS]
#[SHIFT_RELS_0]
(a_lo' - b_lo) * rng_chk_sel' = 0;
(a_hi' - b_hi) * rng_chk_sel' = 0;
#[SHIFT_RELS_1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,6 @@ template <typename FF> struct Avm_aluRow {
inline std::string get_relation_label_avm_alu(int index)
{
switch (index) {
case 0:
return "SELECTOR_REL";

case 1:
return "SELECTOR_REL_2";

case 10:
return "ALU_ADD_SUB_1";

Expand Down Expand Up @@ -164,7 +158,7 @@ inline std::string get_relation_label_avm_alu(int index)
return "UPPER_CMP_RNG_CHK";

case 40:
return "SHIFT_RELS";
return "SHIFT_RELS_0";

case 42:
return "SHIFT_RELS_1";
Expand Down
26 changes: 15 additions & 11 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_alu_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,8 @@ std::tuple<uint8_t, uint8_t, std::vector<uint16_t>> AvmAluTraceBuilder::to_alu_s
}

/**
* @brief This is a helper function that is used to generate the range check entries for the comparison operation.
* This additionally increments the counts for the corresponding range lookups entries.
* @brief This is a helper function that is used to generate the range check entries for the comparison operation
* (LT/LTE opcodes). This additionally increments the counts for the corresponding range lookups entries.
* @param row The initial row where the comparison operation was performed
* @param hi_lo_limbs The vector of 128-bit limbs hi and lo pairs of limbs that will be range checked.
* @return A vector of AluTraceEntry rows for the range checks for the comparison operation.
Expand All @@ -480,28 +480,30 @@ std::vector<AvmAluTraceBuilder::AluTraceEntry> AvmAluTraceBuilder::cmp_range_che
std::vector<AvmAluTraceBuilder::AluTraceEntry> rows{ std::move(row) };
rows.resize(num_rows, {});

// We need to ensure that the number of rows is even
ASSERT(hi_lo_limbs.size() % 2 == 0);
// Now for each row, we need to unpack a pair from the hi_lo_limb array into the ALUs 8-bit and 16-bit registers
// The first row unpacks a_lo and a_hi, the second row unpacks b_lo and b_hi, and so on.
for (size_t j = 0; j < num_rows; j++) {
auto* r = &rows.at(j);
auto& r = rows.at(j);
uint256_t lo_limb = hi_lo_limbs.at(2 * j);
uint256_t hi_limb = hi_lo_limbs.at(2 * j + 1);
uint256_t limb = lo_limb + (hi_limb << 128);
// Unpack lo limb and handle in the 8-bit registers
auto [alu_u8_r0, alu_u8_r1, alu_u16_reg] = AvmAluTraceBuilder::to_alu_slice_registers(limb);
r->alu_u8_r0 = alu_u8_r0;
r->alu_u8_r1 = alu_u8_r1;
std::copy(alu_u16_reg.begin(), alu_u16_reg.end(), r->alu_u16_reg.begin());
r.alu_u8_r0 = alu_u8_r0;
r.alu_u8_r1 = alu_u8_r1;
std::copy(alu_u16_reg.begin(), alu_u16_reg.end(), r.alu_u16_reg.begin());

r->cmp_rng_ctr = j > 0 ? static_cast<uint8_t>(num_rows - j) : 0;
r->rng_chk_sel = j > 0;
r->alu_op_eq_diff_inv = j > 0 ? FF(num_rows - j).invert() : 0;
r.cmp_rng_ctr = j > 0 ? static_cast<uint8_t>(num_rows - j) : 0;
r.rng_chk_sel = j > 0;
r.alu_op_eq_diff_inv = j > 0 ? FF(num_rows - j).invert() : 0;

std::vector<FF> limb_arr = { hi_lo_limbs.begin() + static_cast<int>(2 * j), hi_lo_limbs.end() };
// Resizing here is probably suboptimal for performance, we can probably handle the shorter vectors and
// pad with zero during the finalise
limb_arr.resize(10, FF::zero());
r->hi_lo_limbs = limb_arr;
r.hi_lo_limbs = limb_arr;
}
return rows;
}
Expand Down Expand Up @@ -560,6 +562,7 @@ std::tuple<uint256_t, uint256_t, bool> gt_or_lte_witness(uint256_t const& a, uin
* @param a Left operand of the LT
* @param b Right operand of the LT
* @param clk Clock referring to the operation in the main trace.
* @param in_tag Instruction tag defining the number of bits for the LT.
*
* @return FF The boolean result of LT casted to a finite field element
*/
Expand All @@ -569,7 +572,7 @@ FF AvmAluTraceBuilder::op_lt(FF const& a, FF const& b, AvmMemoryTag in_tag, uint
bool c = uint256_t(a) < uint256_t(b);

// Note: This is counter-intuitive, to show that a < b we actually show that b > a
// The subtlely is here that the circuit is designed as a GT(x,y) circuit, therefor we swap the inputs a & b
// The subtlety is here that the circuit is designed as a GT(x,y) circuit, therefore we swap the inputs a & b
// Get the decomposition of b
auto [a_lo, a_hi] = decompose(b);
// Get the decomposition of a
Expand Down Expand Up @@ -616,6 +619,7 @@ FF AvmAluTraceBuilder::op_lt(FF const& a, FF const& b, AvmMemoryTag in_tag, uint
* @param a Left operand of the LTE
* @param b Right operand of the LTE
* @param clk Clock referring to the operation in the main trace.
* @param in_tag Instruction tag defining the number of bits for the LT.
*
* @return FF The boolean result of LTE casted to a finite field element
*/
Expand Down
77 changes: 45 additions & 32 deletions barretenberg/cpp/src/barretenberg/vm/tests/avm_comparison.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ namespace tests_avm {
using namespace bb::avm_trace;
namespace {

void common_validate_cmp(Row row,
Row alu_row,
void common_validate_cmp(Row const& row,
Row const& alu_row,
FF const& a,
FF const& b,
FF const& c,
Expand Down Expand Up @@ -49,7 +49,7 @@ void common_validate_cmp(Row row,

// Check the instruction tags
EXPECT_EQ(row.avm_main_r_in_tag, FF(static_cast<uint32_t>(tag)));
EXPECT_EQ(row.avm_main_w_in_tag, FF(1));
EXPECT_EQ(row.avm_main_w_in_tag, FF(static_cast<uint32_t>(AvmMemoryTag::U8)));

// Check that intermediate registers are correctly copied in Alu trace
EXPECT_EQ(alu_row.avm_alu_ia, a);
Expand Down Expand Up @@ -77,15 +77,16 @@ std::vector<ThreeOpParam> positive_op_lte_test_values = {
FF(1) } }
};

std::vector<ThreeOpParamRow> gen_three_op_params(std::vector<ThreeOpParam> operands, std::vector<AvmMemoryTag> mem_tags)
std::vector<ThreeOpParamRow> gen_three_op_params(std::vector<ThreeOpParam> operands,
std::vector<AvmMemoryTag> mem_tag_arr)
{
std::vector<ThreeOpParamRow> params;
for (size_t i = 0; i < 5; i++) {
params.emplace_back(operands[i], mem_tags[i]);
params.emplace_back(operands[i], mem_tag_arr[i]);
}
return params;
}
std::vector<AvmMemoryTag> mem_tags{
std::vector<AvmMemoryTag> mem_tag_arr{
{ AvmMemoryTag::U8, AvmMemoryTag::U16, AvmMemoryTag::U32, AvmMemoryTag::U64, AvmMemoryTag::U128 }
};
class AvmCmpTests : public ::testing::Test {
Expand All @@ -108,35 +109,44 @@ TEST_P(AvmCmpTestsLT, ParamTest)
{
const auto [params, mem_tag] = GetParam();
const auto [a, b, c] = params;
trace_builder.calldata_copy(0, 0, 3, 0, std::vector<FF>{ a, b, c });
if (mem_tag == AvmMemoryTag::FF) {
trace_builder.calldata_copy(0, 0, 2, 0, std::vector<FF>{ a, b });
} else {
trace_builder.op_set(0, uint128_t(a), 0, mem_tag);
trace_builder.op_set(0, uint128_t(b), 1, mem_tag);
}
trace_builder.op_lt(0, 0, 1, 2, mem_tag);
trace_builder.return_op(0, 0, 0);
auto trace = trace_builder.finalize();
// Validate the trace

// Get the row in the avm with the LT selector set
auto row = std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_lt == FF(1); });

// Use the row in the main trace to find the same operation in the alu trace.
FF clk = row->avm_main_clk;
auto alu_row = std::ranges::find_if(
trace.begin(), trace.end(), [clk](Row r) { return r.avm_alu_clk == clk && r.avm_alu_op_lt; });
trace.begin(), trace.end(), [clk](Row r) { return r.avm_alu_clk == clk && r.avm_alu_op_lt == FF(1); });
// Check that both rows were found
EXPECT_TRUE(row != trace.end());
EXPECT_TRUE(alu_row != trace.end());
common_validate_cmp(*row, *alu_row, a, b, c, FF(0), FF(1), FF(2), AvmMemoryTag::FF);
ASSERT_TRUE(row != trace.end());
ASSERT_TRUE(alu_row != trace.end());
common_validate_cmp(*row, *alu_row, a, b, c, FF(0), FF(1), FF(2), mem_tag);

validate_trace_proof(std::move(trace));
}
INSTANTIATE_TEST_SUITE_P(AvmCmpTests,
AvmCmpTestsLT,
testing::ValuesIn(gen_three_op_params(positive_op_lt_test_values, mem_tags)));
testing::ValuesIn(gen_three_op_params(positive_op_lt_test_values, mem_tag_arr)));

TEST_P(AvmCmpTestsLTE, ParamTest)
{
const auto [params, mem_tag] = GetParam();
const auto [a, b, c] = params;
trace_builder.calldata_copy(0, 0, 3, 0, std::vector<FF>{ a, b, c });
if (mem_tag == AvmMemoryTag::FF) {
trace_builder.calldata_copy(0, 0, 2, 0, std::vector<FF>{ a, b });
} else {
trace_builder.op_set(0, uint128_t(a), 0, mem_tag);
trace_builder.op_set(0, uint128_t(b), 1, mem_tag);
}
trace_builder.op_lte(0, 0, 1, 2, mem_tag);
trace_builder.return_op(0, 0, 0);
auto trace = trace_builder.finalize();
Expand All @@ -147,14 +157,14 @@ TEST_P(AvmCmpTestsLTE, ParamTest)
auto alu_row = std::ranges::find_if(
trace.begin(), trace.end(), [clk](Row r) { return r.avm_alu_clk == clk && r.avm_alu_op_lte; });
// Check that both rows were found
EXPECT_TRUE(row != trace.end());
EXPECT_TRUE(alu_row != trace.end());
common_validate_cmp(*row, *alu_row, a, b, c, FF(0), FF(1), FF(2), AvmMemoryTag::FF);
ASSERT_TRUE(row != trace.end());
ASSERT_TRUE(alu_row != trace.end());
common_validate_cmp(*row, *alu_row, a, b, c, FF(0), FF(1), FF(2), mem_tag);
validate_trace_proof(std::move(trace));
}
INSTANTIATE_TEST_SUITE_P(AvmCmpTests,
AvmCmpTestsLTE,
testing::ValuesIn(gen_three_op_params(positive_op_lte_test_values, mem_tags)));
testing::ValuesIn(gen_three_op_params(positive_op_lte_test_values, mem_tag_arr)));

/******************************************************************************
*
Expand All @@ -178,19 +188,17 @@ std::vector<std::tuple<std::string, CMP_FAILURES>> cmp_failures = {
{ "RES_HI", CMP_FAILURES::ResHiCheckFailed },
{ "CMP_CTR_REL_2", CMP_FAILURES::CounterRelationFailed },
{ "CTR_NON_ZERO_REL", CMP_FAILURES::CounterNonZeroCheckFailed },
{ "SHIFT_RELS", CMP_FAILURES::ShiftRelationFailed },
{ "SHIFT_RELS_0", CMP_FAILURES::ShiftRelationFailed },
{ "LOOKUP_U16_0", CMP_FAILURES::RangeCheckFailed },

};
std::vector<ThreeOpParam> neg_test_lt = { { 12023, 439321, 0 } };
std::vector<ThreeOpParam> neg_test_lte = { { 12023, 12023, 0 } };
std::vector<ThreeOpParam> neg_test_lt = { { FF::modulus - 1, FF::modulus_minus_two, 0 } };
std::vector<ThreeOpParam> neg_test_lte = { { FF::modulus - 1, FF::modulus - 1, 0 } };

using EXPECTED_ERRORS = std::tuple<std::string, CMP_FAILURES>;

std::vector<Row> gen_mutated_trace_cmp(std::vector<Row> trace,
std::function<bool(Row)>&& select_row,
FF c_mutated,
CMP_FAILURES fail_mode)
std::vector<Row> gen_mutated_trace_cmp(
std::vector<Row> trace, std::function<bool(Row)> select_row, FF c_mutated, CMP_FAILURES fail_mode, bool is_lte)
{
auto main_trace_row = std::ranges::find_if(trace.begin(), trace.end(), select_row);
auto main_clk = main_trace_row->avm_main_clk;
Expand All @@ -202,7 +210,7 @@ std::vector<Row> gen_mutated_trace_cmp(std::vector<Row> trace,
std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_alu_cmp_rng_ctr > FF(0); });
switch (fail_mode) {
case IncorrectInputDecomposition:
alu_row->avm_alu_a_lo = FF(0);
alu_row->avm_alu_a_lo = alu_row->avm_alu_a_lo + FF(1);
break;
case SubLoCheckFailed:
alu_row->avm_alu_p_a_borrow = FF::one() - alu_row->avm_alu_p_a_borrow;
Expand Down Expand Up @@ -233,11 +241,16 @@ std::vector<Row> gen_mutated_trace_cmp(std::vector<Row> trace,
mutate_ic_in_trace(trace, std::move(select_row), c_mutated, true);

// Now we have to also update the value of res_lo = (A_SUB_B_LO * IS_GT + B_SUB_A_LO * (1 - IS_GT))
// to be B_SUB_A_LO
alu_row->avm_alu_borrow = FF(0);
FF mutated_res_lo =
alu_row->avm_alu_b_lo - alu_row->avm_alu_a_lo + alu_row->avm_alu_borrow * (uint256_t(1) << 128);
FF mutated_res_hi = alu_row->avm_alu_b_hi - alu_row->avm_alu_a_hi - alu_row->avm_alu_borrow;

if (is_lte) {
mutated_res_lo = alu_row->avm_alu_a_lo - alu_row->avm_alu_b_lo - FF::one() +
alu_row->avm_alu_borrow * (uint256_t(1) << 128);
mutated_res_hi = alu_row->avm_alu_a_hi - alu_row->avm_alu_b_hi - alu_row->avm_alu_borrow;
}
alu_row->avm_alu_res_lo = mutated_res_lo;
alu_row->avm_alu_res_hi = mutated_res_hi;
// For each subsequent row that involve the range check, we need to update the shifted values
Expand Down Expand Up @@ -304,8 +317,8 @@ TEST_P(AvmCmpNegativeTestsLT, ParamTest)
trace_builder.op_lt(0, 0, 1, 2, AvmMemoryTag::FF);
trace_builder.return_op(0, 0, 0);
auto trace = trace_builder.finalize();
std::function<bool(Row)>&& select_row = [](Row r) { return r.avm_main_sel_op_lt == FF(1); };
trace = gen_mutated_trace_cmp(trace, std::move(select_row), output, failure_mode);
std::function<bool(Row)> select_row = [](Row r) { return r.avm_main_sel_op_lt == FF(1); };
trace = gen_mutated_trace_cmp(trace, select_row, output, failure_mode, false);
EXPECT_THROW_WITH_MESSAGE(validate_trace_proof(std::move(trace)), failure_string);
}

Expand All @@ -320,11 +333,11 @@ TEST_P(AvmCmpNegativeTestsLTE, ParamTest)
const auto [a, b, output] = params;
auto trace_builder = avm_trace::AvmTraceBuilder();
trace_builder.calldata_copy(0, 0, 3, 0, std::vector<FF>{ a, b, output });
trace_builder.op_lt(0, 0, 1, 2, AvmMemoryTag::FF);
trace_builder.op_lte(0, 0, 1, 2, AvmMemoryTag::FF);
trace_builder.return_op(0, 0, 0);
auto trace = trace_builder.finalize();
std::function<bool(Row)>&& select_row = [](Row r) { return r.avm_main_sel_op_lt == FF(1); };
trace = gen_mutated_trace_cmp(trace, std::move(select_row), output, failure_mode);
std::function<bool(Row)> select_row = [](Row r) { return r.avm_main_sel_op_lte == FF(1); };
trace = gen_mutated_trace_cmp(trace, select_row, output, failure_mode, true);
EXPECT_THROW_WITH_MESSAGE(validate_trace_proof(std::move(trace)), failure_string);
}
INSTANTIATE_TEST_SUITE_P(AvmCmpNegativeTests,
Expand Down

0 comments on commit bdff6f8

Please sign in to comment.