diff --git a/barretenberg/acir_tests/flows/all_cmds.sh b/barretenberg/acir_tests/flows/all_cmds.sh index 97f9f8ea4c1..a65159351ed 100755 --- a/barretenberg/acir_tests/flows/all_cmds.sh +++ b/barretenberg/acir_tests/flows/all_cmds.sh @@ -9,6 +9,7 @@ FLAGS="-c $CRS_PATH $VFLAG" $BIN gates $FLAGS $BFLAG > /dev/null $BIN prove -o proof $FLAGS $BFLAG $BIN write_vk -o vk $FLAGS $BFLAG +$BIN write_pk -o pk $FLAGS $BFLAG $BIN verify -k vk -p proof $FLAGS # Check supplemental functions. diff --git a/barretenberg/cpp/src/barretenberg/bb/main.cpp b/barretenberg/cpp/src/barretenberg/bb/main.cpp index 11791b2d383..0cc46d9dd24 100644 --- a/barretenberg/cpp/src/barretenberg/bb/main.cpp +++ b/barretenberg/cpp/src/barretenberg/bb/main.cpp @@ -1,5 +1,6 @@ #include "barretenberg/dsl/acir_format/acir_format.hpp" #include "barretenberg/dsl/types.hpp" +#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp" #include "config.hpp" #include "get_bytecode.hpp" #include "get_crs.hpp" @@ -183,7 +184,7 @@ bool verify(const std::string& proof_path, bool recursive, const std::string& vk * @param bytecodePath Path to the file containing the serialized circuit * @param outputPath Path to write the verification key to */ -void writeVk(const std::string& bytecodePath, const std::string& outputPath) +void write_vk(const std::string& bytecodePath, const std::string& outputPath) { auto constraint_system = get_constraint_system(bytecodePath); auto acir_composer = init(constraint_system); @@ -199,6 +200,22 @@ void writeVk(const std::string& bytecodePath, const std::string& outputPath) } } +void write_pk(const std::string& bytecodePath, const std::string& outputPath) +{ + auto constraint_system = get_constraint_system(bytecodePath); + auto acir_composer = init(constraint_system); + auto pk = acir_composer.init_proving_key(constraint_system); + auto serialized_pk = to_buffer(*pk); + + if (outputPath == "-") { + writeRawBytesToStdout(serialized_pk); + vinfo("pk written to stdout"); + } else { + write_file(outputPath, serialized_pk); + vinfo("pk written to: ", outputPath); + } +} + /** * @brief Writes a Solidity verifier contract for an ACIR circuit to a file * @@ -253,7 +270,7 @@ void contract(const std::string& output_path, const std::string& vk_path) * @param vk_path Path to the file containing the serialized verification key * @param output_path Path to write the proof to */ -void proofAsFields(const std::string& proof_path, std::string const& vk_path, const std::string& output_path) +void proof_as_fields(const std::string& proof_path, std::string const& vk_path, const std::string& output_path) { auto acir_composer = init(); auto vk_data = from_buffer(read_file(vk_path)); @@ -282,7 +299,7 @@ void proofAsFields(const std::string& proof_path, std::string const& vk_path, co * @param vk_path Path to the file containing the serialized verification key * @param output_path Path to write the verification key to */ -void vkAsFields(const std::string& vk_path, const std::string& output_path) +void vk_as_fields(const std::string& vk_path, const std::string& output_path) { auto acir_composer = init(); auto vk_data = from_buffer(read_file(vk_path)); @@ -311,7 +328,7 @@ void vkAsFields(const std::string& vk_path, const std::string& output_path) * * @param output_path Path to write the information to */ -void acvmInfo(const std::string& output_path) +void acvm_info(const std::string& output_path) { const char* jsonData = R"({ @@ -335,12 +352,12 @@ void acvmInfo(const std::string& output_path) } } -bool flagPresent(std::vector& args, const std::string& flag) +bool flag_present(std::vector& args, const std::string& flag) { return std::find(args.begin(), args.end(), flag) != args.end(); } -std::string getOption(std::vector& args, const std::string& option, const std::string& defaultValue) +std::string get_option(std::vector& args, const std::string& option, const std::string& defaultValue) { auto itr = std::find(args.begin(), args.end(), option); return (itr != args.end() && std::next(itr) != args.end()) ? *(std::next(itr)) : defaultValue; @@ -350,7 +367,7 @@ int main(int argc, char* argv[]) { try { std::vector args(argv + 1, argv + argc); - verbose = flagPresent(args, "-v") || flagPresent(args, "--verbose"); + verbose = flag_present(args, "-v") || flag_present(args, "--verbose"); if (args.empty()) { std::cerr << "No command provided.\n"; @@ -359,12 +376,13 @@ int main(int argc, char* argv[]) std::string command = args[0]; - std::string bytecode_path = getOption(args, "-b", "./target/acir.gz"); - std::string witness_path = getOption(args, "-w", "./target/witness.gz"); - std::string proof_path = getOption(args, "-p", "./proofs/proof"); - std::string vk_path = getOption(args, "-k", "./target/vk"); - CRS_PATH = getOption(args, "-c", "./crs"); - bool recursive = flagPresent(args, "-r") || flagPresent(args, "--recursive"); + std::string bytecode_path = get_option(args, "-b", "./target/acir.gz"); + std::string witness_path = get_option(args, "-w", "./target/witness.gz"); + std::string proof_path = get_option(args, "-p", "./proofs/proof"); + std::string vk_path = get_option(args, "-k", "./target/vk"); + std::string pk_path = get_option(args, "-r", "./target/pk"); + CRS_PATH = get_option(args, "-c", "./crs"); + bool recursive = flag_present(args, "-r") || flag_present(args, "--recursive"); // Skip CRS initialization for any command which doesn't require the CRS. if (command == "--version") { @@ -372,8 +390,8 @@ int main(int argc, char* argv[]) return 0; } if (command == "info") { - std::string output_path = getOption(args, "-o", "info.json"); - acvmInfo(output_path); + std::string output_path = get_option(args, "-o", "info.json"); + acvm_info(output_path); return 0; } @@ -381,24 +399,27 @@ int main(int argc, char* argv[]) return proveAndVerify(bytecode_path, witness_path, recursive) ? 0 : 1; } if (command == "prove") { - std::string output_path = getOption(args, "-o", "./proofs/proof"); + std::string output_path = get_option(args, "-o", "./proofs/proof"); prove(bytecode_path, witness_path, recursive, output_path); } else if (command == "gates") { gateCount(bytecode_path); } else if (command == "verify") { return verify(proof_path, recursive, vk_path) ? 0 : 1; } else if (command == "contract") { - std::string output_path = getOption(args, "-o", "./target/contract.sol"); + std::string output_path = get_option(args, "-o", "./target/contract.sol"); contract(output_path, vk_path); } else if (command == "write_vk") { - std::string output_path = getOption(args, "-o", "./target/vk"); - writeVk(bytecode_path, output_path); + std::string output_path = get_option(args, "-o", "./target/vk"); + write_vk(bytecode_path, output_path); + } else if (command == "write_pk") { + std::string output_path = get_option(args, "-o", "./target/pk"); + write_pk(bytecode_path, output_path); } else if (command == "proof_as_fields") { - std::string output_path = getOption(args, "-o", proof_path + "_fields.json"); - proofAsFields(proof_path, vk_path, output_path); + std::string output_path = get_option(args, "-o", proof_path + "_fields.json"); + proof_as_fields(proof_path, vk_path, output_path); } else if (command == "vk_as_fields") { - std::string output_path = getOption(args, "-o", vk_path + "_fields.json"); - vkAsFields(vk_path, output_path); + std::string output_path = get_option(args, "-o", vk_path + "_fields.json"); + vk_as_fields(vk_path, output_path); } else { std::cerr << "Unknown command: " << command << "\n"; return 1; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/acir_composer.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/acir_composer.cpp index 5f7cee439c6..0dc4a117735 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/acir_composer.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/acir_composer.cpp @@ -4,6 +4,7 @@ #include "barretenberg/dsl/acir_format/acir_format.hpp" #include "barretenberg/dsl/acir_format/recursion_constraint.hpp" #include "barretenberg/dsl/types.hpp" +#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp" #include "barretenberg/plonk/proof_system/proving_key/serialize.hpp" #include "barretenberg/plonk/proof_system/verification_key/sol_gen.hpp" #include "barretenberg/plonk/proof_system/verification_key/verification_key.hpp" @@ -30,12 +31,14 @@ void AcirComposer::create_circuit(acir_format::acir_format& constraint_system) vinfo("gates: ", builder_.get_total_circuit_size()); } -void AcirComposer::init_proving_key(acir_format::acir_format& constraint_system) +std::shared_ptr AcirComposer::init_proving_key( + acir_format::acir_format& constraint_system) { create_circuit(constraint_system); acir_format::Composer composer; vinfo("computing proving key..."); proving_key_ = composer.compute_proving_key(builder_); + return proving_key_; } std::vector AcirComposer::create_proof(acir_format::acir_format& constraint_system, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/acir_composer.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/acir_composer.hpp index 32b678268e3..db78f067a22 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/acir_composer.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/acir_composer.hpp @@ -14,7 +14,7 @@ class AcirComposer { void create_circuit(acir_format::acir_format& constraint_system); - void init_proving_key(acir_format::acir_format& constraint_system); + std::shared_ptr init_proving_key(acir_format::acir_format& constraint_system); std::vector create_proof(acir_format::acir_format& constraint_system, acir_format::WitnessVector& witness, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.cpp index 0bdfbb519d2..b92213f9724 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.cpp @@ -6,6 +6,7 @@ #include "barretenberg/common/serialize.hpp" #include "barretenberg/common/slab_allocator.hpp" #include "barretenberg/dsl/acir_format/acir_format.hpp" +#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp" #include "barretenberg/plonk/proof_system/verification_key/verification_key.hpp" #include "barretenberg/srs/global_crs.hpp" #include @@ -73,6 +74,15 @@ WASM_EXPORT void acir_get_verification_key(in_ptr acir_composer_ptr, uint8_t** o *out = to_heap_buffer(to_buffer(*vk)); } +WASM_EXPORT void acir_get_proving_key(in_ptr acir_composer_ptr, uint8_t const* acir_vec, uint8_t** out) +{ + auto acir_composer = reinterpret_cast(*acir_composer_ptr); + auto constraint_system = acir_format::circuit_buf_to_acir_format(from_buffer>(acir_vec)); + auto pk = acir_composer->init_proving_key(constraint_system); + // We flatten to a vector first, as that's how we treat it on the calling side. + *out = to_heap_buffer(to_buffer(*pk)); +} + WASM_EXPORT void acir_verify_proof(in_ptr acir_composer_ptr, uint8_t const* proof_buf, bool const* is_recursive, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.hpp index e17af7a260d..5ffa298b2fc 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.hpp @@ -38,6 +38,8 @@ WASM_EXPORT void acir_init_verification_key(in_ptr acir_composer_ptr); WASM_EXPORT void acir_get_verification_key(in_ptr acir_composer_ptr, uint8_t** out); +WASM_EXPORT void acir_get_proving_key(in_ptr acir_composer_ptr, uint8_t const* acir_vec, uint8_t** out); + WASM_EXPORT void acir_verify_proof(in_ptr acir_composer_ptr, uint8_t const* proof_buf, bool const* is_recursive, diff --git a/barretenberg/ts/src/barretenberg_api/index.ts b/barretenberg/ts/src/barretenberg_api/index.ts index b47f0d8f0c7..ea1aac55dad 100644 --- a/barretenberg/ts/src/barretenberg_api/index.ts +++ b/barretenberg/ts/src/barretenberg_api/index.ts @@ -381,6 +381,18 @@ export class BarretenbergApi { return out[0]; } + async acirGetProvingKey(acirComposerPtr: Ptr, constraintSystemBuf: Uint8Array): Promise { + const inArgs = [acirComposerPtr, constraintSystemBuf].map(serializeBufferable); + const outTypes: OutputType[] = [BufferDeserializer()]; + const result = await this.wasm.callWasmExport( + 'acir_get_proving_key', + inArgs, + outTypes.map(t => t.SIZE_IN_BYTES), + ); + const out = result.map((r, i) => outTypes[i].fromBuffer(r)); + return out[0]; + } + async acirVerifyProof(acirComposerPtr: Ptr, proofBuf: Uint8Array, isRecursive: boolean): Promise { const inArgs = [acirComposerPtr, proofBuf, isRecursive].map(serializeBufferable); const outTypes: OutputType[] = [BoolDeserializer()]; diff --git a/barretenberg/ts/src/main.ts b/barretenberg/ts/src/main.ts index a985ea59e79..016c8a63c04 100755 --- a/barretenberg/ts/src/main.ts +++ b/barretenberg/ts/src/main.ts @@ -219,6 +219,25 @@ export async function writeVk(bytecodePath: string, crsPath: string, outputPath: } } +export async function writePk(bytecodePath: string, crsPath: string, outputPath: string) { + const { api, acirComposer } = await init(bytecodePath, crsPath); + try { + debug('initing proving key...'); + const bytecode = getBytecode(bytecodePath); + const pk = await api.acirGetProvingKey(acirComposer, bytecode); + + if (outputPath === '-') { + process.stdout.write(pk); + debug(`pk written to stdout`); + } else { + writeFileSync(outputPath, pk); + debug(`pk written to: ${outputPath}`); + } + } finally { + await api.destroy(); + } +} + export async function proofAsFields(proofPath: string, vkPath: string, outputPath: string) { const { api, acirComposer } = await initLite(); @@ -347,6 +366,16 @@ program await writeVk(bytecodePath, crsPath, outputPath); }); +program + .command('write_pk') + .description('Output proving key.') + .option('-b, --bytecode-path ', 'Specify the bytecode path', './target/acir.gz') + .requiredOption('-o, --output-path ', 'Specify the path to write the key') + .action(async ({ bytecodePath, outputPath, crsPath }) => { + handleGlobalOptions(); + await writePk(bytecodePath, crsPath, outputPath); + }); + program .command('proof_as_fields') .description('Return the proof as fields elements')