Skip to content

Commit

Permalink
fix: StandardCircuitBuilder create_logic_constraint and uint logic_op…
Browse files Browse the repository at this point in the history
…erator (#4530)

Fixes create_logic_constraint in StandardCircuitBuilder. Previously it
didn't constrain the quads to be computed from the input bits.
Fixes logic_operator in standard uint. Previously it didn't constrain
the input values to be equal the new values after the operation.

---------

Co-authored-by: Innokentii Sennovskii <isennovskiy@gmail.com>
  • Loading branch information
Sarkoxed and Rumata888 committed Feb 9, 2024
1 parent 7b4c6e7 commit ce51d20
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,15 @@ std::vector<uint32_t> StandardCircuitBuilder_<FF>::decompose_into_base4_accumula
return accumulators;
}

/**
* @brief Create an AND or an XOR constraint
*
* @param a The first argument variable index
* @param b The second argument variable index
* @param num_bits The width of arguments. Has to be even
* @param is_xor_gate If true, create an xor gate, otherwise an and gate
* @return accumulator_triple_<FF> Accumulated witnesses (steps) for input arguments and output
*/
template <typename FF>
accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(const uint32_t a,
const uint32_t b,
Expand All @@ -311,9 +320,14 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con

accumulator_triple_<FF> accumulators;

// Get the values of inputs
const uint256_t left_witness_value(this->get_variable(a));
const uint256_t right_witness_value(this->get_variable(b));

ASSERT(left_witness_value < (uint256_t(1) << num_bits));
ASSERT(right_witness_value < (uint256_t(1) << num_bits));

// We are starting accumulation with zeros
FF left_accumulator = FF::zero();
FF right_accumulator = FF::zero();
FF out_accumulator = FF::zero();
Expand All @@ -323,23 +337,34 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
uint32_t out_accumulator_idx = this->zero_idx;
constexpr FF four = FF(4);
constexpr FF neg_two = -FF(2);
constexpr FF two = FF(2);

// Num bits is expected to be even
ASSERT(num_bits % 2 == 0);

// Accumulate the bits in quads starting from the high ones
for (size_t i = num_bits - 1; i < num_bits; i -= 2) {
// Get bit values of arguments
bool left_hi_val = left_witness_value.get_bit(i);
bool left_lo_val = left_witness_value.get_bit(i - 1);
bool right_hi_val = right_witness_value.get_bit((i));
bool right_lo_val = right_witness_value.get_bit(i - 1);

// Convert to wintesses
uint32_t left_hi_idx = this->add_variable(left_hi_val ? FF::one() : FF::zero());
uint32_t left_lo_idx = this->add_variable(left_lo_val ? FF::one() : FF::zero());
uint32_t right_hi_idx = this->add_variable(right_hi_val ? FF::one() : FF::zero());
uint32_t right_lo_idx = this->add_variable(right_lo_val ? FF::one() : FF::zero());

// Compute resulting bits
bool out_hi_val = is_xor_gate ? left_hi_val ^ right_hi_val : left_hi_val & right_hi_val;
bool out_lo_val = is_xor_gate ? left_lo_val ^ right_lo_val : left_lo_val & right_lo_val;

// Convert to witnesses
uint32_t out_hi_idx = this->add_variable(out_hi_val ? FF::one() : FF::zero());
uint32_t out_lo_idx = this->add_variable(out_lo_val ? FF::one() : FF::zero());

// Constrain all individual bit witnesses to be boolean
create_bool_gate(left_hi_idx);
create_bool_gate(right_hi_idx);
create_bool_gate(out_hi_idx);
Expand All @@ -348,6 +373,7 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
create_bool_gate(right_lo_idx);
create_bool_gate(out_lo_idx);

// Create 2 individual xor or and gates
// a & b = ab
// a ^ b = a + b - 2ab
create_poly_gate({ left_hi_idx,
Expand All @@ -368,21 +394,36 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
FF::neg_one(),
FF::zero() });

// Reconstruct the value of the left quad and add as witness
FF left_quad =
this->get_variable(left_lo_idx) + this->get_variable(left_hi_idx) + this->get_variable(left_hi_idx);
uint32_t left_quad_idx = this->add_variable(left_quad);

// Connect the left bits to the left quad
create_add_gate({ left_hi_idx, left_lo_idx, left_quad_idx, two, FF::one(), FF::neg_one(), FF::zero() });

// Reconstruct the value of the right quad and add as witness
FF right_quad =
this->get_variable(right_lo_idx) + this->get_variable(right_hi_idx) + this->get_variable(right_hi_idx);
FF out_quad = this->get_variable(out_lo_idx) + this->get_variable(out_hi_idx) + this->get_variable(out_hi_idx);

uint32_t left_quad_idx = this->add_variable(left_quad);
uint32_t right_quad_idx = this->add_variable(right_quad);

// Connect the left bits to the left quad
create_add_gate({ right_hi_idx, right_lo_idx, right_quad_idx, two, FF::one(), FF::neg_one(), FF::zero() });

// Reconstruct the value of the output quad and add as witness
FF out_quad = this->get_variable(out_lo_idx) + this->get_variable(out_hi_idx) + this->get_variable(out_hi_idx);
uint32_t out_quad_idx = this->add_variable(out_quad);

// Connect the out bits to the out quad
create_add_gate({ out_hi_idx, out_lo_idx, out_quad_idx, two, FF::one(), FF::neg_one(), FF::zero() });

// Compute the value of the left accumulator and add as witness
FF new_left_accumulator = left_accumulator + left_accumulator;
new_left_accumulator = new_left_accumulator + new_left_accumulator;
new_left_accumulator = new_left_accumulator + left_quad;
uint32_t new_left_accumulator_idx = this->add_variable(new_left_accumulator);

// Connect the left quad, previous accumulator and current accumulator
create_add_gate({ left_accumulator_idx,
left_quad_idx,
new_left_accumulator_idx,
Expand All @@ -391,11 +432,13 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
FF::neg_one(),
FF::zero() });

// Compute the value of the right accumulator and add as witness
FF new_right_accumulator = right_accumulator + right_accumulator;
new_right_accumulator = new_right_accumulator + new_right_accumulator;
new_right_accumulator = new_right_accumulator + right_quad;
uint32_t new_right_accumulator_idx = this->add_variable(new_right_accumulator);

// Connect the right quad, previous accumulator and current accumulator
create_add_gate({ right_accumulator_idx,
right_quad_idx,
new_right_accumulator_idx,
Expand All @@ -404,18 +447,21 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
FF::neg_one(),
FF::zero() });

// Compute the value of the out accumulator and add as witness
FF new_out_accumulator = out_accumulator + out_accumulator;
new_out_accumulator = new_out_accumulator + new_out_accumulator;
new_out_accumulator = new_out_accumulator + out_quad;
uint32_t new_out_accumulator_idx = this->add_variable(new_out_accumulator);

// Connect the out quad, previous accumulator and current accumulator
create_add_gate(
{ out_accumulator_idx, out_quad_idx, new_out_accumulator_idx, four, FF::one(), FF::neg_one(), FF::zero() });

accumulators.left.emplace_back(new_left_accumulator_idx);
accumulators.right.emplace_back(new_right_accumulator_idx);
accumulators.out.emplace_back(new_out_accumulator_idx);

// Update current accumulators
left_accumulator = new_left_accumulator;
left_accumulator_idx = new_left_accumulator_idx;

Expand All @@ -425,6 +471,9 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
out_accumulator = new_out_accumulator;
out_accumulator_idx = new_out_accumulator_idx;
}
// Connect the accumulators to inputs
this->assert_equal(accumulators.left.back(), a);
this->assert_equal(accumulators.right.back(), b);
return accumulators;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ uint<Builder, Native> uint<Builder, Native>::logic_operator(const uint& other, c
const uint256_t rhs = other.get_value();
uint256_t out = 0;

// Compute the value of the result
switch (op_type) {
case AND: {
out = lhs & rhs;
Expand All @@ -473,11 +474,14 @@ uint<Builder, Native> uint<Builder, Native>::logic_operator(const uint& other, c
}
}

// If both inputs are constants, just output a new constant uint with the result
if (is_constant() && other.is_constant()) {
// returns a constant uint.
return uint<Builder, Native>(ctx, out);
}

// If one of the inputs is a constant, we need to create a witness from it, because we can only perform logical
// constraints between witnesses
const uint32_t lhs_idx = is_constant() ? ctx->add_variable(lhs) : witness_index;
const uint32_t rhs_idx = other.is_constant() ? ctx->add_variable(rhs) : other.witness_index;

Expand Down

0 comments on commit ce51d20

Please sign in to comment.