Skip to content

Commit

Permalink
feat: ECCVM witness generation optimisation (#5211)
Browse files Browse the repository at this point in the history
This PR modifies the witness generation code for the ECCVM circuit
builder.

In our ivc benchmarks, the overall proportion of work performed by
ECCVM::create_prover has reduced from 10% to less than 1%.

Key changes are multithreading witness generation, as well as removing a
substantial number of field inversions that we were unnecessarily
performing. The inversions are now more effectively performed via
calling `field_t::batch_invert`

```
Benchmarking lock created at ~/BENCHMARK_IN_PROGRESS.
client_ivc_bench                                                                                                                                                                  100%   15MB  47.2MB/s   00:00    
2024-03-18T10:50:07+00:00
Running ./client_ivc_bench
Run on (16 X 3631.57 MHz CPU s)
CPU Caches:
  L1 Data 32 KiB (x8)
  L1 Instruction 32 KiB (x8)
  L2 Unified 1024 KiB (x8)
  L3 Unified 36608 KiB (x1)
Load Average: 1.16, 0.82, 0.33
--------------------------------------------------------------------------------
Benchmark                      Time             CPU   Iterations UserCounters...
--------------------------------------------------------------------------------
ClientIVCBench/Full/6      23697 ms        18934 ms            1 Decider::construct_proof=1 Decider::construct_proof(t)=755.044M ECCVMComposer::compute_commitment_key=1 ECCVMComposer::compute_commitment_key(t)=3.77177M ECCVMComposer::compute_witness=1 ECCVMComposer::compute_witness(t)=129.434M ECCVMComposer::create_prover=1 ECCVMComposer::create_prover(t)=149.26M ECCVMComposer::create_proving_key=1 ECCVMComposer::create_proving_key(t)=15.833M ECCVMProver::construct_proof=1 ECCVMProver::construct_proof(t)=1.78177G Goblin::merge=11 Goblin::merge(t)=128.554M GoblinTranslatorCircuitBuilder::constructor=1 GoblinTranslatorCircuitBuilder::constructor(t)=58.2017M GoblinTranslatorComposer::create_prover=1 GoblinTranslatorComposer::create_prover(t)=121.617M GoblinTranslatorProver::construct_proof=1 GoblinTranslatorProver::construct_proof(t)=928.122M ProtoGalaxyProver_::accumulator_update_round=10 ProtoGalaxyProver_::accumulator_update_round(t)=727.574M ProtoGalaxyProver_::combiner_quotient_round=10 ProtoGalaxyProver_::combiner_quotient_round(t)=7.29332G ProtoGalaxyProver_::perturbator_round=10 ProtoGalaxyProver_::perturbator_round(t)=1.32753G ProtoGalaxyProver_::preparation_round=10 ProtoGalaxyProver_::preparation_round(t)=4.16456G ProtogalaxyProver::fold_instances=10 ProtogalaxyProver::fold_instances(t)=13.513G ProverInstance(Circuit&)=11 ProverInstance(Circuit&)(t)=1.96494G batch_mul_with_endomorphism=30 batch_mul_with_endomorphism(t)=567.025M commit=425 commit(t)=4.03553G compute_combiner=10 compute_combiner(t)=7.29114G compute_perturbator=9 compute_perturbator(t)=1.32717G compute_univariate=48 compute_univariate(t)=1.43152G construct_circuits=6 construct_circuits(t)=4.27911G
Benchmarking lock deleted.
client_ivc_bench.json                                                                                                                                                             100% 4027   130.8KB/s   00:00    
function                                        ms     % sum
construct_circuits(t)                         4279    18.12%
ProverInstance(Circuit&)(t)                   1965     8.32%
ProtogalaxyProver::fold_instances(t)         13513    57.21%
Decider::construct_proof(t)                    755     3.20%
ECCVMComposer::create_prover(t)                149     0.63%
GoblinTranslatorComposer::create_prover(t)     122     0.51%
ECCVMProver::construct_proof(t)               1782     7.54%
GoblinTranslatorProver::construct_proof(t)     928     3.93%
Goblin::merge(t)                               129     0.54%

Total time accounted for: 23621ms/23697ms = 99.68%

Major contributors:
function                                        ms    % sum
commit(t)                                     4036   17.08%
compute_combiner(t)                           7291   30.87%
compute_perturbator(t)                        1327    5.62%
compute_univariate(t)                         1432    6.06%

Breakdown of ECCVMProver::create_prover:
ECCVMComposer::compute_witness(t)              129    86.72%
ECCVMComposer::create_proving_key(t)            16    10.61%

Breakdown of ProtogalaxyProver::fold_instances:
ProtoGalaxyProver_::preparation_round(t)           4165    30.82%
ProtoGalaxyProver_::perturbator_round(t)           1328     9.82%
ProtoGalaxyProver_::combiner_quotient_round(t)     7293    53.97%
ProtoGalaxyProver_::accumulator_update_round(t)     728     5.38%
```
  • Loading branch information
zac-williamson committed Mar 18, 2024
1 parent dd56a0e commit 85ac726
Show file tree
Hide file tree
Showing 9 changed files with 801 additions and 389 deletions.
7 changes: 0 additions & 7 deletions barretenberg/cpp/scripts/analyze_client_ivc_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,6 @@
time_ms = bench[key]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/sum_of_kept_times_ms:>8.2%}")


print('\nBreakdown of ECCVMProver::create_prover:')
for key in ["ECCVMComposer::compute_witness(t)", "ECCVMComposer::create_proving_key(t)"]:
time_ms = bench[key]/1e6
total_time_ms = bench["ECCVMComposer::create_prover(t)"]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/total_time_ms:>8.2%}")

print('\nBreakdown of ProtogalaxyProver::fold_instances:')
protogalaxy_round_labels = [
"ProtoGalaxyProver_::preparation_round(t)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ TYPED_TEST(ECCVMComposerTests, EqFails)
.z1 = 0,
.z2 = 0,
.mul_scalar_full = 0 });
builder.op_queue->num_transcript_rows++;
auto composer = ECCVMComposer_<Flavor>();
auto prover = composer.create_prover(builder);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ template <typename CycleGroup> struct ScalarMul {
typename CycleGroup::affine_element base_point;
std::array<int, NUM_WNAF_SLICES> wnaf_slices;
bool wnaf_skew;
std::array<typename CycleGroup::affine_element, POINT_TABLE_SIZE> precomputed_table;
// size bumped by 1 to record base_point.dbl()
std::array<typename CycleGroup::affine_element, POINT_TABLE_SIZE + 1> precomputed_table;
};

template <typename CycleGroup> using MSM = std::vector<ScalarMul<CycleGroup>>;
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -36,75 +36,76 @@ template <typename Flavor> class ECCVMPrecomputedTablesBuilder {
static std::vector<PrecomputeState> compute_precompute_state(
const std::vector<bb::eccvm::ScalarMul<CycleGroup>>& ecc_muls)
{
std::vector<PrecomputeState> precompute_state;
static constexpr size_t num_rows_per_scalar = NUM_WNAF_SLICES / WNAF_SLICES_PER_ROW;
const size_t num_precompute_rows = num_rows_per_scalar * ecc_muls.size() + 1;
std::vector<PrecomputeState> precompute_state(num_precompute_rows);

// start with empty row (shiftable polynomials must have 0 as first coefficient)
precompute_state.push_back(PrecomputeState{});
static constexpr size_t num_rows_per_scalar = NUM_WNAF_SLICES / WNAF_SLICES_PER_ROW;
precompute_state[0] = PrecomputeState{};

// current impl doesn't work if not 4
static_assert(WNAF_SLICES_PER_ROW == 4);

for (const auto& entry : ecc_muls) {
const auto& slices = entry.wnaf_slices;
uint256_t scalar_sum = 0;

const Element point = entry.base_point;
const Element d2 = point.dbl();

for (size_t i = 0; i < num_rows_per_scalar; ++i) {
PrecomputeState row;
const int slice0 = slices[i * WNAF_SLICES_PER_ROW];
const int slice1 = slices[i * WNAF_SLICES_PER_ROW + 1];
const int slice2 = slices[i * WNAF_SLICES_PER_ROW + 2];
const int slice3 = slices[i * WNAF_SLICES_PER_ROW + 3];

const int slice0base2 = (slice0 + 15) / 2;
const int slice1base2 = (slice1 + 15) / 2;
const int slice2base2 = (slice2 + 15) / 2;
const int slice3base2 = (slice3 + 15) / 2;

// convert into 2-bit chunks
row.s1 = slice0base2 >> 2;
row.s2 = slice0base2 & 3;
row.s3 = slice1base2 >> 2;
row.s4 = slice1base2 & 3;
row.s5 = slice2base2 >> 2;
row.s6 = slice2base2 & 3;
row.s7 = slice3base2 >> 2;
row.s8 = slice3base2 & 3;
bool last_row = (i == num_rows_per_scalar - 1);

row.skew = last_row ? entry.wnaf_skew : false;

row.scalar_sum = scalar_sum;

// N.B. we apply a constraint that requires slice1 to be positive for the 1st row of each scalar sum.
// This ensures we do not have WNAF representations of negative values
const int row_chunk = slice3 + slice2 * (1 << 4) + slice1 * (1 << 8) + slice0 * (1 << 12);

bool chunk_negative = row_chunk < 0;

scalar_sum = scalar_sum << (WNAF_SLICE_BITS * WNAF_SLICES_PER_ROW);
if (chunk_negative) {
scalar_sum -= static_cast<uint64_t>(-row_chunk);
} else {
scalar_sum += static_cast<uint64_t>(row_chunk);
run_loop_in_parallel(ecc_muls.size(), [&](size_t start, size_t end) {
for (size_t j = start; j < end; j++) {
const auto& entry = ecc_muls[j];
const auto& slices = entry.wnaf_slices;
uint256_t scalar_sum = 0;

for (size_t i = 0; i < num_rows_per_scalar; ++i) {
PrecomputeState row;
const int slice0 = slices[i * WNAF_SLICES_PER_ROW];
const int slice1 = slices[i * WNAF_SLICES_PER_ROW + 1];
const int slice2 = slices[i * WNAF_SLICES_PER_ROW + 2];
const int slice3 = slices[i * WNAF_SLICES_PER_ROW + 3];

const int slice0base2 = (slice0 + 15) / 2;
const int slice1base2 = (slice1 + 15) / 2;
const int slice2base2 = (slice2 + 15) / 2;
const int slice3base2 = (slice3 + 15) / 2;

// convert into 2-bit chunks
row.s1 = slice0base2 >> 2;
row.s2 = slice0base2 & 3;
row.s3 = slice1base2 >> 2;
row.s4 = slice1base2 & 3;
row.s5 = slice2base2 >> 2;
row.s6 = slice2base2 & 3;
row.s7 = slice3base2 >> 2;
row.s8 = slice3base2 & 3;
bool last_row = (i == num_rows_per_scalar - 1);

row.skew = last_row ? entry.wnaf_skew : false;

row.scalar_sum = scalar_sum;

// N.B. we apply a constraint that requires slice1 to be positive for the 1st row of each scalar
// sum. This ensures we do not have WNAF representations of negative values
const int row_chunk = slice3 + slice2 * (1 << 4) + slice1 * (1 << 8) + slice0 * (1 << 12);

bool chunk_negative = row_chunk < 0;

scalar_sum = scalar_sum << (WNAF_SLICE_BITS * WNAF_SLICES_PER_ROW);
if (chunk_negative) {
scalar_sum -= static_cast<uint64_t>(-row_chunk);
} else {
scalar_sum += static_cast<uint64_t>(row_chunk);
}
row.round = static_cast<uint32_t>(i);
row.point_transition = last_row;
row.pc = entry.pc;

if (last_row) {
ASSERT(scalar_sum - entry.wnaf_skew == entry.scalar);
}

row.precompute_double = entry.precomputed_table[bb::eccvm::POINT_TABLE_SIZE];
// fill accumulator in reverse order i.e. first row = 15[P], then 13[P], ..., 1[P]
row.precompute_accumulator = entry.precomputed_table[bb::eccvm::POINT_TABLE_SIZE - 1 - i];
precompute_state[j * num_rows_per_scalar + i + 1] = (row);
}
row.round = static_cast<uint32_t>(i);
row.point_transition = last_row;
row.pc = entry.pc;

if (last_row) {
ASSERT(scalar_sum - entry.wnaf_skew == entry.scalar);
}

row.precompute_double = d2;
// fill accumulator in reverse order i.e. first row = 15[P], then 13[P], ..., 1[P]
row.precompute_accumulator = entry.precomputed_table[bb::eccvm::POINT_TABLE_SIZE - 1 - i];
precompute_state.emplace_back(row);
}
}
});
return precompute_state;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ template <typename Flavor> class ECCVMTranscriptBuilder {
static std::vector<TranscriptState> compute_transcript_state(
const std::vector<bb::eccvm::VMOperation<CycleGroup>>& vm_operations, const uint32_t total_number_of_muls)
{
std::vector<TranscriptState> transcript_state;
const size_t num_transcript_entries = vm_operations.size() + 2;

std::vector<TranscriptState> transcript_state(num_transcript_entries);
std::vector<FF> inverse_trace(num_transcript_entries - 2);
VMState state{
.pc = total_number_of_muls,
.count = 0,
Expand All @@ -69,11 +72,10 @@ template <typename Flavor> class ECCVMTranscriptBuilder {
.is_accumulator_empty = true,
};
VMState updated_state;

// add an empty row. 1st row all zeroes because of our shiftable polynomials
transcript_state.emplace_back(TranscriptState{});
transcript_state[0] = (TranscriptState{});
for (size_t i = 0; i < vm_operations.size(); ++i) {
TranscriptState row;
TranscriptState& row = transcript_state[i + 1];
const bb::eccvm::VMOperation<CycleGroup>& entry = vm_operations[i];

const bool is_mul = entry.mul;
Expand Down Expand Up @@ -158,28 +160,32 @@ template <typename Flavor> class ECCVMTranscriptBuilder {
ASSERT((row.msm_output_x != row.accumulator_x) &&
"eccvm: attempting msm. Result point x-coordinate matches accumulator x-coordinate.");
state.msm_accumulator = CycleGroup::affine_point_at_infinity;
row.collision_check = (row.msm_output_x - row.accumulator_x).invert();
inverse_trace[i] = (row.msm_output_x - row.accumulator_x);
} else if (entry.add && !row.accumulator_empty) {
ASSERT((row.base_x != row.accumulator_x) &&
"eccvm: attempting to add points with matching x-coordinates");
row.collision_check = (row.base_x - row.accumulator_x).invert();
inverse_trace[i] = (row.base_x - row.accumulator_x);
} else {
inverse_trace[i] = (0);
}

state = updated_state;

if (entry.mul && next_not_msm) {
state.msm_accumulator = CycleGroup::affine_point_at_infinity;
}
transcript_state.emplace_back(row);
}

TranscriptState final_row;
FF::batch_invert(&inverse_trace[0], inverse_trace.size());
for (size_t i = 0; i < inverse_trace.size(); ++i) {
transcript_state[i + 1].collision_check = inverse_trace[i];
}
TranscriptState& final_row = transcript_state.back();
final_row.pc = updated_state.pc;
final_row.accumulator_x = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.x;
final_row.accumulator_y = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.y;
final_row.accumulator_empty = updated_state.is_accumulator_empty;

transcript_state.push_back(final_row);
return transcript_state;
}
};
Expand Down
Loading

0 comments on commit 85ac726

Please sign in to comment.