From 0b15db2bea70696597911e82b60f0def595c1150 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Mon, 18 Mar 2024 12:56:56 +0000 Subject: [PATCH] feat!: Acir call opcode (#4773) --- .../dsl/acir_format/serde/acir.hpp | 142 +++++++++--------- build_manifest.yml | 1 + noir/Dockerfile.packages-test | 3 + .../noir-repo/acvm-repo/acir/codegen/acir.cpp | 56 ++++++- .../acvm-repo/acir/src/circuit/opcodes.rs | 16 ++ .../acvm/src/compiler/transformers/mod.rs | 1 + noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs | 1 + 7 files changed, 144 insertions(+), 76 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp index f72a3b2e724..49a8b588856 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp @@ -1063,18 +1063,7 @@ struct Directive { static ToLeRadix bincodeDeserialize(std::vector); }; - struct PermutationSort { - std::vector> inputs; - uint32_t tuple; - std::vector bits; - std::vector sort_by; - - friend bool operator==(const PermutationSort&, const PermutationSort&); - std::vector bincodeSerialize() const; - static PermutationSort bincodeDeserialize(std::vector); - }; - - std::variant value; + std::variant value; friend bool operator==(const Directive&, const Directive&); std::vector bincodeSerialize() const; @@ -1144,7 +1133,17 @@ struct Opcode { static MemoryInit bincodeDeserialize(std::vector); }; - std::variant value; + struct Call { + uint32_t id; + std::vector inputs; + std::vector outputs; + + friend bool operator==(const Call&, const Call&); + std::vector bincodeSerialize() const; + static Call bincodeDeserialize(std::vector); + }; + + std::variant value; friend bool operator==(const Opcode&, const Opcode&); std::vector bincodeSerialize() const; @@ -6450,68 +6449,6 @@ Circuit::Directive::ToLeRadix serde::Deserializable Directive::PermutationSort::bincodeSerialize() const -{ - auto serializer = serde::BincodeSerializer(); - serde::Serializable::serialize(*this, serializer); - return std::move(serializer).bytes(); -} - -inline Directive::PermutationSort Directive::PermutationSort::bincodeDeserialize(std::vector input) -{ - auto deserializer = serde::BincodeDeserializer(input); - auto value = serde::Deserializable::deserialize(deserializer); - if (deserializer.get_buffer_offset() < input.size()) { - throw_or_abort("Some input bytes were not read"); - } - return value; -} - -} // end of namespace Circuit - -template <> -template -void serde::Serializable::serialize(const Circuit::Directive::PermutationSort& obj, - Serializer& serializer) -{ - serde::Serializable::serialize(obj.inputs, serializer); - serde::Serializable::serialize(obj.tuple, serializer); - serde::Serializable::serialize(obj.bits, serializer); - serde::Serializable::serialize(obj.sort_by, serializer); -} - -template <> -template -Circuit::Directive::PermutationSort serde::Deserializable::deserialize( - Deserializer& deserializer) -{ - Circuit::Directive::PermutationSort obj; - obj.inputs = serde::Deserializable::deserialize(deserializer); - obj.tuple = serde::Deserializable::deserialize(deserializer); - obj.bits = serde::Deserializable::deserialize(deserializer); - obj.sort_by = serde::Deserializable::deserialize(deserializer); - return obj; -} - -namespace Circuit { - inline bool operator==(const Expression& lhs, const Expression& rhs) { if (!(lhs.mul_terms == rhs.mul_terms)) { @@ -7509,6 +7446,61 @@ Circuit::Opcode::MemoryInit serde::Deserializable:: namespace Circuit { +inline bool operator==(const Opcode::Call& lhs, const Opcode::Call& rhs) +{ + if (!(lhs.id == rhs.id)) { + return false; + } + if (!(lhs.inputs == rhs.inputs)) { + return false; + } + if (!(lhs.outputs == rhs.outputs)) { + return false; + } + return true; +} + +inline std::vector Opcode::Call::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline Opcode::Call Opcode::Call::bincodeDeserialize(std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::Opcode::Call& obj, Serializer& serializer) +{ + serde::Serializable::serialize(obj.id, serializer); + serde::Serializable::serialize(obj.inputs, serializer); + serde::Serializable::serialize(obj.outputs, serializer); +} + +template <> +template +Circuit::Opcode::Call serde::Deserializable::deserialize(Deserializer& deserializer) +{ + Circuit::Opcode::Call obj; + obj.id = serde::Deserializable::deserialize(deserializer); + obj.inputs = serde::Deserializable::deserialize(deserializer); + obj.outputs = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Circuit { + inline bool operator==(const OpcodeLocation& lhs, const OpcodeLocation& rhs) { if (!(lhs.value == rhs.value)) { diff --git a/build_manifest.yml b/build_manifest.yml index c07877468bc..10382887281 100644 --- a/build_manifest.yml +++ b/build_manifest.yml @@ -36,6 +36,7 @@ noir-packages-tests: rebuildPatterns: .rebuild_patterns_packages dependencies: - noir + - noir-packages # Builds the brillig to avm transpiler. avm-transpiler: diff --git a/noir/Dockerfile.packages-test b/noir/Dockerfile.packages-test index 33fac5120fb..b9b4ac32267 100644 --- a/noir/Dockerfile.packages-test +++ b/noir/Dockerfile.packages-test @@ -1,6 +1,9 @@ FROM aztecprotocol/noir AS noir +FROM --platform=linux/amd64 aztecprotocol/noir-packages as noir-packages FROM node:20 AS builder +COPY --from=noir-packages /usr/src/noir/packages /usr/src/noir/packages + COPY --from=noir /usr/src/noir/noir-repo/target/release /usr/src/noir/noir-repo/target/release ENV PATH=${PATH}:/usr/src/noir/noir-repo/target/release RUN curl https://sh.rustup.rs -sSf | bash -s -- -y diff --git a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp index 6fdb62c5674..4c1497a1dfb 100644 --- a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp +++ b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp @@ -1071,7 +1071,17 @@ namespace Circuit { static MemoryInit bincodeDeserialize(std::vector); }; - std::variant value; + struct Call { + uint32_t id; + std::vector inputs; + std::vector outputs; + + friend bool operator==(const Call&, const Call&); + std::vector bincodeSerialize() const; + static Call bincodeDeserialize(std::vector); + }; + + std::variant value; friend bool operator==(const Opcode&, const Opcode&); std::vector bincodeSerialize() const; @@ -6135,6 +6145,50 @@ Circuit::Opcode::MemoryInit serde::Deserializable:: return obj; } +namespace Circuit { + + inline bool operator==(const Opcode::Call &lhs, const Opcode::Call &rhs) { + if (!(lhs.id == rhs.id)) { return false; } + if (!(lhs.inputs == rhs.inputs)) { return false; } + if (!(lhs.outputs == rhs.outputs)) { return false; } + return true; + } + + inline std::vector Opcode::Call::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline Opcode::Call Opcode::Call::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::Opcode::Call &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.id, serializer); + serde::Serializable::serialize(obj.inputs, serializer); + serde::Serializable::serialize(obj.outputs, serializer); +} + +template <> +template +Circuit::Opcode::Call serde::Deserializable::deserialize(Deserializer &deserializer) { + Circuit::Opcode::Call obj; + obj.id = serde::Deserializable::deserialize(deserializer); + obj.inputs = serde::Deserializable::deserialize(deserializer); + obj.outputs = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Circuit { inline bool operator==(const OpcodeLocation &lhs, const OpcodeLocation &rhs) { diff --git a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs index f725ba8c32a..064a9d1244a 100644 --- a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs +++ b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs @@ -29,6 +29,17 @@ pub enum Opcode { block_id: BlockId, init: Vec, }, + /// Calls to functions represented as a separate circuit. A call opcode allows us + /// to build a call stack when executing the outer-most circuit. + Call { + /// Id for the function being called. It is the responsibility of the executor + /// to fetch the appropriate circuit from this id. + id: u32, + /// Inputs to the function call + inputs: Vec, + /// Outputs of the function call + outputs: Vec, + }, } impl std::fmt::Display for Opcode { @@ -86,6 +97,11 @@ impl std::fmt::Display for Opcode { write!(f, "INIT ")?; write!(f, "(id: {}, len: {}) ", block_id.0, init.len()) } + Opcode::Call { id, inputs, outputs } => { + write!(f, "CALL func {}: ", id)?; + writeln!(f, "inputs: {:?}", inputs)?; + writeln!(f, "outputs: {:?}", outputs) + } } } } diff --git a/noir/noir-repo/acvm-repo/acvm/src/compiler/transformers/mod.rs b/noir/noir-repo/acvm-repo/acvm/src/compiler/transformers/mod.rs index 214243d9360..2e549854521 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -145,6 +145,7 @@ pub(super) fn transform_internal( new_acir_opcode_positions.push(acir_opcode_positions[index]); transformed_opcodes.push(opcode); } + Opcode::Call { .. } => todo!("Handle Call opcodes in the ACVM"), } } diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs index d8323e5ef5f..0fd733a6336 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs @@ -281,6 +281,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call), res => res.map(|_| ()), }, + Opcode::Call { .. } => todo!("Handle Call opcodes in the ACVM"), }; self.handle_opcode_resolution(resolution) }