Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generalize protogalaxy to multiple instances #5510

Merged
merged 18 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions barretenberg/cpp/scripts/analyze_client_ivc_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
print(
f"{column['function']:<{max_label_length}}{column['ms']:>8} {column['%']:>8}")
for key in to_keep:
time_ms = bench[key]/1e6
if key not in bench:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

making this script less brittle

time_ms = 0
else:
time_ms = bench[key]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/sum_of_kept_times_ms:>8.2%}")

# Validate that kept times account for most of the total measured time.
Expand All @@ -45,7 +48,10 @@
print(
f"{column['function']:<{max_label_length}}{column['ms']:>8} {column['%']:>7}")
for key in ['commit(t)', 'compute_combiner(t)', 'compute_perturbator(t)', 'compute_univariate(t)']:
time_ms = bench[key]/1e6
if key not in bench:
time_ms = 0
else:
time_ms = bench[key]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/sum_of_kept_times_ms:>8.2%}")

print('\nBreakdown of ProtogalaxyProver::fold_instances:')
Expand All @@ -57,7 +63,10 @@
]
max_label_length = max(len(label) for label in protogalaxy_round_labels)
for key in protogalaxy_round_labels:
time_ms = bench[key]/1e6
if key not in bench:
time_ms = 0
else:
time_ms = bench[key]/1e6
total_time_ms = bench["ProtogalaxyProver::fold_instances(t)"]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/total_time_ms:>8.2%}")

Expand Down
61 changes: 61 additions & 0 deletions barretenberg/cpp/scripts/analyze_protogalaxy_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import json
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could delete this script; was used to get a breakdown of protogalaxy bench

from pathlib import Path

PREFIX = Path("build-op-count-time")
PROTOGALAXY_BENCH_JSON = Path("protogalaxy_bench.json")
BENCHMARK = "fold_k<GoblinUltraFlavor, 3>/16"

# Single out an independent set of functions accounting for most of BENCHMARK's real_time
to_keep = [
"ProtogalaxyProver::fold_instances(t)",
]
with open(PREFIX/PROTOGALAXY_BENCH_JSON, "r") as read_file:
read_result = json.load(read_file)
for _bench in read_result["benchmarks"]:
print(_bench)
if _bench["name"] == BENCHMARK:
bench = _bench
bench_components = dict(filter(lambda x: x[0] in to_keep, bench.items()))

# For each kept time, get the proportion over all kept times.
sum_of_kept_times_ms = sum(float(time)
for _, time in bench_components.items())/1e6
max_label_length = max(len(label) for label in to_keep)
column = {"function": "function", "ms": "ms", "%": "% sum"}
print(
f"{column['function']:<{max_label_length}}{column['ms']:>8} {column['%']:>8}")
for key in to_keep:
time_ms = bench[key]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/sum_of_kept_times_ms:>8.2%}")

# Validate that kept times account for most of the total measured time.
total_time_ms = bench["real_time"]
totals = '\nTotal time accounted for: {:.0f}ms/{:.0f}ms = {:.2%}'
totals = totals.format(
sum_of_kept_times_ms, total_time_ms, sum_of_kept_times_ms/total_time_ms)
print(totals)

print("\nMajor contributors:")
print(
f"{column['function']:<{max_label_length}}{column['ms']:>8} {column['%']:>7}")
for key in ['commit(t)', 'compute_combiner(t)', 'compute_perturbator(t)', 'compute_univariate(t)']:
if key not in bench:
time_ms = 0
else:
time_ms = bench[key]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/sum_of_kept_times_ms:>8.2%}")

print('\nBreakdown of ProtogalaxyProver::fold_instances:')
protogalaxy_round_labels = [
"ProtoGalaxyProver_::preparation_round(t)",
"ProtoGalaxyProver_::perturbator_round(t)",
"ProtoGalaxyProver_::combiner_quotient_round(t)",
"ProtoGalaxyProver_::accumulator_update_round(t)"
]
max_label_length = max(len(label) for label in protogalaxy_round_labels)
for key in protogalaxy_round_labels:
time_ms = bench[key]/1e6
total_time_ms = bench["ProtogalaxyProver::fold_instances(t)"]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/total_time_ms:>8.2%}")


25 changes: 25 additions & 0 deletions barretenberg/cpp/scripts/benchmark_protogalaxy.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
set -eu

TARGET="protogalaxy_bench"
FILTER="/16$"
BUILD_DIR=build-op-count-time

# Move above script dir.
cd $(dirname $0)/..

# Measure the benchmarks with ops time counting
./scripts/benchmark_remote.sh protogalaxy_bench\
"./protogalaxy_bench --benchmark_filter=$FILTER\
--benchmark_out=$TARGET.json\
--benchmark_out_format=json"\
op-count-time\
build-op-count-time

# Retrieve output from benching instance
cd $BUILD_DIR
scp $BB_SSH_KEY $BB_SSH_INSTANCE:$BB_SSH_CPP_PATH/build/$TARGET.json .

# Analyze the results
cd ../
python3 ./scripts/analyze_protogalaxy_bench.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <benchmark/benchmark.h>

#include "barretenberg/common/op_count_google_bench.hpp"
#include "barretenberg/protogalaxy/protogalaxy_prover.hpp"
#include "barretenberg/stdlib_circuit_builders/mock_circuits.hpp"
#include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp"
Expand All @@ -11,11 +12,11 @@ using namespace benchmark;
namespace bb {

// Fold one instance into an accumulator.
template <typename Flavor> void fold_one(State& state) noexcept
template <typename Flavor, size_t k> void fold_k(State& state) noexcept
{
using ProverInstance = ProverInstance_<Flavor>;
using Instance = ProverInstance;
using Instances = ProverInstances_<Flavor, 2>;
using Instances = ProverInstances_<Flavor, k + 1>;
using ProtoGalaxyProver = ProtoGalaxyProver_<Instances>;
using Builder = typename Flavor::CircuitBuilder;

Expand All @@ -28,19 +29,29 @@ template <typename Flavor> void fold_one(State& state) noexcept
MockCircuits::construct_arithmetic_circuit(builder, log2_num_gates);
return std::make_shared<ProverInstance>(builder);
};
std::vector<std::shared_ptr<Instance>> instances;
// TODO(https://github.com/AztecProtocol/barretenberg/issues/938): Parallelize this loop
for (size_t i = 0; i < k + 1; ++i) {
instances.emplace_back(construct_instance());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mb it's worth paralellising this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its outside the actual benchmark so its fine

}

std::shared_ptr<Instance> instance_1 = construct_instance();
std::shared_ptr<Instance> instance_2 = construct_instance();

ProtoGalaxyProver folding_prover({ instance_1, instance_2 });
ProtoGalaxyProver folding_prover(instances);

for (auto _ : state) {
BB_REPORT_OP_COUNT_IN_BENCH(state);
auto proof = folding_prover.fold_instances();
}
}

BENCHMARK(fold_one<UltraFlavor>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_one<GoblinUltraFlavor>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_k<UltraFlavor, 1>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_k<GoblinUltraFlavor, 1>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);

BENCHMARK(fold_k<UltraFlavor, 2>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_k<GoblinUltraFlavor, 2>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);

BENCHMARK(fold_k<UltraFlavor, 3>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_k<GoblinUltraFlavor, 3>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);

} // namespace bb

BENCHMARK_MAIN();
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ void _bench_round(::benchmark::State& state, void (*F)(ProtoGalaxyProver_<Prover
return std::make_shared<ProverInstance>(builder);
};

// TODO(https://github.com/AztecProtocol/barretenberg/issues/938): Parallelize this loop, also extend to more than
// k=1
std::shared_ptr<ProverInstance> prover_instance_1 = construct_instance();
std::shared_ptr<ProverInstance> prover_instance_2 = construct_instance();

Expand Down
83 changes: 82 additions & 1 deletion barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,94 @@ template <class Fr, size_t domain_end, size_t domain_start = 0> class Univariate

std::copy(evaluations.begin(), evaluations.end(), result.evaluations.begin());

static constexpr Fr inverse_two = Fr(2).invert();
if constexpr (LENGTH == 2) {
Fr delta = value_at(1) - value_at(0);
static_assert(EXTENDED_LENGTH != 0);
for (size_t idx = domain_start; idx < EXTENDED_DOMAIN_END - 1; idx++) {
for (size_t idx = domain_end - 1; idx < EXTENDED_DOMAIN_END - 1; idx++) {
result.value_at(idx + 1) = result.value_at(idx) + delta;
}
return result;
} else if constexpr (LENGTH == 3) {
// Based off https://hackmd.io/@aztec-network/SyR45cmOq?type=view
// The technique used here is the same as the length == 3 case below.
Fr a = (value_at(2) + value_at(0)) * inverse_two - value_at(1);
Fr b = value_at(1) - a - value_at(0);
Fr a2 = a + a;
Fr a_mul = a2;
for (size_t i = 0; i < domain_end - 2; i++) {
a_mul += a2;
}
Fr extra = a_mul + a + b;
for (size_t idx = domain_end - 1; idx < EXTENDED_DOMAIN_END - 1; idx++) {
result.value_at(idx + 1) = result.value_at(idx) + extra;
extra += a2;
}
return result;
} else if constexpr (LENGTH == 4) {
static constexpr Fr inverse_six = Fr(6).invert(); // computed at compile time for efficiency

// To compute a barycentric extension, we can compute the coefficients of the univariate.
// We have the evaluation of the polynomial at the domain (which is assumed to be 0, 1, 2, 3).
// Therefore, we have the 4 linear equations from plugging into f(x) = ax^3 + bx^2 + cx + d:
// a*0 + b*0 + c*0 + d = f(0)
// a*1 + b*1 + c*1 + d = f(1)
// a*2^3 + b*2^2 + c*2 + d = f(2)
// a*3^3 + b*3^2 + c*3 + d = f(3)
// These equations can be rewritten as a matrix equation M * [a, b, c, d] = [f(0), f(1), f(2), f(3)], where
// M is:
// 0, 0, 0, 1
// 1, 1, 1, 1
// 2^3, 2^2, 2, 1
// 3^3, 3^2, 3, 1
// We can invert this matrix in order to compute a, b, c, d:
// -1/6, 1/2, -1/2, 1/6
// 1, -5/2, 2, -1/2
// -11/6, 3, -3/2, 1/3
// 1, 0, 0, 0
// To compute these values, we can multiply everything by 6 and multiply by inverse_six at the end for each
// coefficient The resulting computation here does 18 field adds, 6 subtracts, 3 muls to compute a, b, c,
// and d.
Fr zero_times_3 = value_at(0) + value_at(0) + value_at(0);
Fr zero_times_6 = zero_times_3 + zero_times_3;
Fr zero_times_12 = zero_times_6 + zero_times_6;
Fr one_times_3 = value_at(1) + value_at(1) + value_at(1);
Fr one_times_6 = one_times_3 + one_times_3;
Fr two_times_3 = value_at(2) + value_at(2) + value_at(2);
Fr three_times_2 = value_at(3) + value_at(3);
Fr three_times_3 = three_times_2 + value_at(3);

Fr one_minus_two_times_3 = one_times_3 - two_times_3;
Fr one_minus_two_times_6 = one_minus_two_times_3 + one_minus_two_times_3;
Fr one_minus_two_times_12 = one_minus_two_times_6 + one_minus_two_times_6;
Fr a = (one_minus_two_times_3 + value_at(3) - value_at(0)) * inverse_six; // compute a in 1 muls and 4 adds
Fr b = (zero_times_6 - one_minus_two_times_12 - one_times_3 - three_times_3) * inverse_six;
Fr c = (value_at(0) - zero_times_12 + one_minus_two_times_12 + one_times_6 + two_times_3 + three_times_2) *
inverse_six;

// Then, outside of the a, b, c, d computation, we need to do some extra precomputation
// This work is 3 field muls, 8 adds
Fr a_plus_b = a + b;
Fr a_plus_b_times_2 = a_plus_b + a_plus_b;
size_t start_idx_sqr = (domain_end - 1) * (domain_end - 1);
size_t idx_sqr_three = start_idx_sqr + start_idx_sqr + start_idx_sqr;
Fr idx_sqr_three_times_a = Fr(idx_sqr_three) * a;
Fr x_a_term = Fr(6 * (domain_end - 1)) * a;
Fr three_a = a + a + a;
Fr six_a = three_a + three_a;

Fr three_a_plus_two_b = a_plus_b_times_2 + a;
Fr linear_term = Fr(domain_end - 1) * three_a_plus_two_b + (a_plus_b + c);
// For each new evaluation, we do only 6 field additions and 0 muls.
for (size_t idx = domain_end - 1; idx < EXTENDED_DOMAIN_END - 1; idx++) {
result.value_at(idx + 1) = result.value_at(idx) + idx_sqr_three_times_a + linear_term;

idx_sqr_three_times_a += x_a_term + three_a;
x_a_term += six_a;

linear_term += three_a_plus_two_b;
}
return result;
} else {
for (size_t k = domain_end; k != EXTENDED_DOMAIN_END; ++k) {
result.value_at(k) = 0;
Expand Down
Loading
Loading