From 94f436ed12f17d2671dbaea8bf581fc0cda0986d Mon Sep 17 00:00:00 2001 From: Lucas Xia Date: Mon, 12 Feb 2024 16:53:24 -0500 Subject: [PATCH] refactor: updating field conversion code without pointer hack (#4537) We currently use a pointer hack for functions like `calc_num_bn254_frs` and `convert_from_bn254_frs`. This PR aims to clean these up using traits and template metaprogramming. Also closes https://github.com/AztecProtocol/barretenberg/issues/846, by just sending and receiving an std::array instead of an AllValues object. --- .../src/barretenberg/ecc/curves/bn254/fq.hpp | 3 + .../src/barretenberg/ecc/curves/bn254/fr.hpp | 3 + .../ecc/fields/field_conversion.cpp | 59 +--- .../ecc/fields/field_conversion.hpp | 225 ++++--------- .../ecc/groups/affine_element.hpp | 5 +- .../barretenberg/polynomials/univariate.hpp | 14 + .../primitives/field/field_conversion.hpp | 313 +++++------------- .../field/field_conversion.test.cpp | 37 +-- .../recursion/honk/transcript/transcript.hpp | 4 +- .../src/barretenberg/sumcheck/sumcheck.hpp | 2 +- 10 files changed, 198 insertions(+), 467 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fq.hpp b/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fq.hpp index 782702d5a52..ffda71bbd47 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fq.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fq.hpp @@ -60,6 +60,9 @@ class Bn254FqParams { // used in msgpack schema serialization static constexpr char schema_name[] = "fq"; static constexpr bool has_high_2adicity = false; + + // The modulus is larger than BN254 scalar field modulus, so it maps to two BN254 scalars + static constexpr size_t NUM_BN254_SCALARS = 2; }; using fq = field; diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fr.hpp b/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fr.hpp index bfa55e67e2e..cf43e95d569 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fr.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fr.hpp @@ -66,6 +66,9 @@ class Bn254FrParams { // used in msgpack schema serialization static constexpr char schema_name[] = "fr"; static constexpr bool has_high_2adicity = true; + + // This is a BN254 scalar, so it represents one BN254 scalar + static constexpr size_t NUM_BN254_SCALARS = 1; }; using fr = field; diff --git a/barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.cpp b/barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.cpp index 0d09d6b5752..dc37e2f8d36 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.cpp @@ -7,18 +7,6 @@ namespace bb::field_conversion { static constexpr uint64_t NUM_LIMB_BITS = plonk::NUM_LIMB_BITS_IN_FIELD_SIMULATION; static constexpr uint64_t TOTAL_BITS = 254; -bb::fr convert_from_bn254_frs(std::span fr_vec, bb::fr* /*unused*/) -{ - ASSERT(fr_vec.size() == 1); - return fr_vec[0]; -} - -bool convert_from_bn254_frs(std::span fr_vec, bool* /*unused*/) -{ - ASSERT(fr_vec.size() == 1); - return fr_vec[0] != 0; -} - /** * @brief Converts 2 bb::fr elements to grumpkin::fr * @details First, this function must take in 2 bb::fr elements because the grumpkin::fr field has a larger modulus than @@ -32,7 +20,7 @@ bool convert_from_bn254_frs(std::span fr_vec, bool* /*unused*/) * @param high_bits_in * @return grumpkin::fr */ -grumpkin::fr convert_from_bn254_frs(std::span fr_vec, grumpkin::fr* /*unused*/) +grumpkin::fr convert_grumpkin_fr_from_bn254_frs(std::span fr_vec) { // Combines the two elements into one uint256_t, and then convert that to a grumpkin::fr ASSERT(uint256_t(fr_vec[0]) < (uint256_t(1) << (NUM_LIMB_BITS * 2))); // lower 136 bits @@ -42,25 +30,6 @@ grumpkin::fr convert_from_bn254_frs(std::span fr_vec, grumpkin::fr return result; } -curve::BN254::AffineElement convert_from_bn254_frs(std::span fr_vec, - curve::BN254::AffineElement* /*unused*/) -{ - curve::BN254::AffineElement val; - val.x = convert_from_bn254_frs(fr_vec.subspan(0, 2)); - val.y = convert_from_bn254_frs(fr_vec.subspan(2, 2)); - return val; -} - -curve::Grumpkin::AffineElement convert_from_bn254_frs(std::span fr_vec, - curve::Grumpkin::AffineElement* /*unused*/) -{ - ASSERT(fr_vec.size() == 2); - curve::Grumpkin::AffineElement val; - val.x = fr_vec[0]; - val.y = fr_vec[1]; - return val; -} - /** * @brief Converts grumpkin::fr to 2 bb::fr elements * @details First, this function must return 2 bb::fr elements because the grumpkin::fr field has a larger modulus than @@ -74,7 +43,7 @@ curve::Grumpkin::AffineElement convert_from_bn254_frs(std::span fr * @param input * @return std::array */ -std::vector convert_to_bn254_frs(const grumpkin::fr& val) +std::vector convert_grumpkin_fr_to_bn254_frs(const grumpkin::fr& val) { // Goal is to slice up the 64 bit limbs of grumpkin::fr/uint256_t to mirror the 68 bit limbs of bigfield // We accomplish this by dividing the grumpkin::fr's value into two 68*2=136 bit pieces. @@ -89,30 +58,6 @@ std::vector convert_to_bn254_frs(const grumpkin::fr& val) return result; } -std::vector convert_to_bn254_frs(const bb::fr& val) -{ - std::vector fr_vec{ val }; - return fr_vec; -} - -std::vector convert_to_bn254_frs(const curve::BN254::AffineElement& val) -{ - auto fr_vec_x = convert_to_bn254_frs(val.x); - auto fr_vec_y = convert_to_bn254_frs(val.y); - std::vector fr_vec(fr_vec_x.begin(), fr_vec_x.end()); - fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end()); - return fr_vec; -} - -std::vector convert_to_bn254_frs(const curve::Grumpkin::AffineElement& val) -{ - auto fr_vec_x = convert_to_bn254_frs(val.x); - auto fr_vec_y = convert_to_bn254_frs(val.y); - std::vector fr_vec(fr_vec_x.begin(), fr_vec_x.end()); - fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end()); - return fr_vec; -} - grumpkin::fr convert_to_grumpkin_fr(const bb::fr& f) { const uint64_t NUM_BITS_IN_TWO_LIMBS = 2 * NUM_LIMB_BITS; // the number of bits in 2 bigfield limbs which is 136 diff --git a/barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.hpp b/barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.hpp index a3abc63673d..8e4962dd525 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.hpp @@ -4,6 +4,7 @@ #include "barretenberg/ecc/curves/bn254/fr.hpp" #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" #include "barretenberg/polynomials/univariate.hpp" +#include "barretenberg/proof_system/types/circuit_type.hpp" namespace bb::field_conversion { @@ -15,48 +16,22 @@ namespace bb::field_conversion { * @tparam T * @return constexpr size_t */ -template constexpr size_t calc_num_bn254_frs(); - -constexpr size_t calc_num_bn254_frs(bb::fr* /*unused*/) -{ - return 1; -} - -constexpr size_t calc_num_bn254_frs(grumpkin::fr* /*unused*/) -{ - return 2; -} - -template constexpr size_t calc_num_bn254_frs(T* /*unused*/) -{ - return 1; // meant for integral types that are less than 254 bits -} - -constexpr size_t calc_num_bn254_frs(curve::BN254::AffineElement* /*unused*/) -{ - return 2 * calc_num_bn254_frs(); -} - -constexpr size_t calc_num_bn254_frs(curve::Grumpkin::AffineElement* /*unused*/) -{ - return 2 * calc_num_bn254_frs(); -} - -template constexpr size_t calc_num_bn254_frs(std::array* /*unused*/) -{ - return N * calc_num_bn254_frs(); -} - -template constexpr size_t calc_num_bn254_frs(bb::Univariate* /*unused*/) -{ - return N * calc_num_bn254_frs(); -} - template constexpr size_t calc_num_bn254_frs() { - return calc_num_bn254_frs(static_cast(nullptr)); + if constexpr (IsAnyOf) { + return 1; + } else if constexpr (IsAnyOf) { + return T::Params::NUM_BN254_SCALARS; + } else if constexpr (IsAnyOf) { + return 2 * calc_num_bn254_frs(); + } else { + // Array or Univariate + return calc_num_bn254_frs() * (std::tuple_size::value); + } } +grumpkin::fr convert_grumpkin_fr_from_bn254_frs(std::span fr_vec); + /** * @brief Conversions from vector of bb::fr elements to transcript types. * @details We want to support the following types: bool, size_t, uint32_t, uint64_t, bb::fr, grumpkin::fr, @@ -68,75 +43,40 @@ template constexpr size_t calc_num_bn254_frs() * @param fr_vec * @return T */ -template T convert_from_bn254_frs(std::span fr_vec); - -bool convert_from_bn254_frs(std::span fr_vec, bool* /*unused*/); - -template inline T convert_from_bn254_frs(std::span fr_vec, T* /*unused*/) -{ - ASSERT(fr_vec.size() == 1); - return static_cast(fr_vec[0]); -} - -bb::fr convert_from_bn254_frs(std::span fr_vec, bb::fr* /*unused*/); - -grumpkin::fr convert_from_bn254_frs(std::span fr_vec, grumpkin::fr* /*unused*/); - -curve::BN254::AffineElement convert_from_bn254_frs(std::span fr_vec, - curve::BN254::AffineElement* /*unused*/); - -curve::Grumpkin::AffineElement convert_from_bn254_frs(std::span fr_vec, - curve::Grumpkin::AffineElement* /*unused*/); - -template -inline std::array convert_from_bn254_frs(std::span fr_vec, std::array* /*unused*/) -{ - std::array val; - for (size_t i = 0; i < N; ++i) { - val[i] = fr_vec[i]; - } - return val; -} - -template -inline std::array convert_from_bn254_frs(std::span fr_vec, - std::array* /*unused*/) -{ - std::array val; - for (size_t i = 0; i < N; ++i) { - std::vector fr_vec_tmp{ fr_vec[2 * i], - fr_vec[2 * i + 1] }; // each pair of consecutive elements is a grumpkin::fr - val[i] = convert_from_bn254_frs(fr_vec_tmp); - } - return val; -} - -template -inline Univariate convert_from_bn254_frs(std::span fr_vec, Univariate* /*unused*/) -{ - Univariate val; - for (size_t i = 0; i < N; ++i) { - val.evaluations[i] = fr_vec[i]; - } - return val; -} - -template -inline Univariate convert_from_bn254_frs(std::span fr_vec, - Univariate* /*unused*/) +template T convert_from_bn254_frs(std::span fr_vec) { - Univariate val; - for (size_t i = 0; i < N; ++i) { - std::vector fr_vec_tmp{ fr_vec[2 * i], fr_vec[2 * i + 1] }; - val.evaluations[i] = convert_from_bn254_frs(fr_vec_tmp); + if constexpr (IsAnyOf) { + ASSERT(fr_vec.size() == 1); + return bool(fr_vec[0]); + } else if constexpr (IsAnyOf) { + ASSERT(fr_vec.size() == 1); + return static_cast(fr_vec[0]); + } else if constexpr (IsAnyOf) { + ASSERT(fr_vec.size() == 2); + return convert_grumpkin_fr_from_bn254_frs(fr_vec); + } else if constexpr (IsAnyOf) { + using BaseField = typename T::Fq; + constexpr size_t BASE_FIELD_SCALAR_SIZE = calc_num_bn254_frs(); + ASSERT(fr_vec.size() == 2 * BASE_FIELD_SCALAR_SIZE); + T val; + val.x = convert_from_bn254_frs(fr_vec.subspan(0, BASE_FIELD_SCALAR_SIZE)); + val.y = convert_from_bn254_frs(fr_vec.subspan(BASE_FIELD_SCALAR_SIZE, BASE_FIELD_SCALAR_SIZE)); + return val; + } else { + // Array or Univariate + T val; + constexpr size_t FieldScalarSize = calc_num_bn254_frs(); + ASSERT(fr_vec.size() == FieldScalarSize * std::tuple_size::value); + size_t i = 0; + for (auto& x : val) { + x = convert_from_bn254_frs(fr_vec.subspan(FieldScalarSize * i, FieldScalarSize)); + ++i; + } + return val; } - return val; } -template T convert_from_bn254_frs(std::span fr_vec) -{ - return convert_from_bn254_frs(fr_vec, static_cast(nullptr)); -} +std::vector convert_grumpkin_fr_to_bn254_frs(const grumpkin::fr& val); /** * @brief Conversion from transcript values to bb::frs @@ -147,65 +87,28 @@ template T convert_from_bn254_frs(std::span fr_vec) * @param val * @return std::vector */ -template std::vector inline convert_to_bn254_frs(const T& val) -{ - std::vector fr_vec{ val }; - return fr_vec; -} - -std::vector convert_to_bn254_frs(const grumpkin::fr& val); - -std::vector convert_to_bn254_frs(const bb::fr& val); - -std::vector convert_to_bn254_frs(const curve::BN254::AffineElement& val); - -std::vector convert_to_bn254_frs(const curve::Grumpkin::AffineElement& val); - -template std::vector inline convert_to_bn254_frs(const std::array& val) -{ - std::vector fr_vec(val.begin(), val.end()); - return fr_vec; -} - -template std::vector inline convert_to_bn254_frs(const std::array& val) -{ - std::vector fr_vec; - for (size_t i = 0; i < N; ++i) { - auto tmp_vec = convert_to_bn254_frs(val[i]); - fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); - } - return fr_vec; -} - -template std::vector inline convert_to_bn254_frs(const bb::Univariate& val) -{ - std::vector fr_vec; - for (size_t i = 0; i < N; ++i) { - auto tmp_vec = convert_to_bn254_frs(val.evaluations[i]); - fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); - } - return fr_vec; -} - -template std::vector inline convert_to_bn254_frs(const bb::Univariate& val) -{ - std::vector fr_vec; - for (size_t i = 0; i < N; ++i) { - auto tmp_vec = convert_to_bn254_frs(val.evaluations[i]); - fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); - } - return fr_vec; -} - -template std::vector inline convert_to_bn254_frs(const AllValues& val) -{ - auto data = val.get_all(); - std::vector fr_vec; - for (auto& item : data) { - auto tmp_vec = convert_to_bn254_frs(item); - fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); +template std::vector convert_to_bn254_frs(const T& val) +{ + if constexpr (IsAnyOf) { + std::vector fr_vec{ val }; + return fr_vec; + } else if constexpr (IsAnyOf) { + return convert_grumpkin_fr_to_bn254_frs(val); + } else if constexpr (IsAnyOf) { + auto fr_vec_x = convert_to_bn254_frs(val.x); + auto fr_vec_y = convert_to_bn254_frs(val.y); + std::vector fr_vec(fr_vec_x.begin(), fr_vec_x.end()); + fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end()); + return fr_vec; + } else { + // Array or Univariate + std::vector fr_vec; + for (auto& x : val) { + auto tmp_vec = convert_to_bn254_frs(x); + fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); + } + return fr_vec; } - return fr_vec; } grumpkin::fr convert_to_grumpkin_fr(const bb::fr& f); diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp index 117dab0ff2b..be823c3f176 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp @@ -8,8 +8,11 @@ namespace bb::group_elements { template concept SupportsHashToCurve = T::can_hash_to_curve; -template class alignas(64) affine_element { +template class alignas(64) affine_element { public: + using Fq = Fq_; + using Fr = Fr_; + using in_buf = const uint8_t*; using vec_in_buf = const uint8_t*; using out_buf = uint8_t*; diff --git a/barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp b/barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp index 188c0bed3c4..03467a23f7b 100644 --- a/barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp +++ b/barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp @@ -25,6 +25,8 @@ template class Univariate static constexpr size_t LENGTH = domain_end - domain_start; using View = UnivariateView; + using value_type = Fr; // used to get the type of the elements consistently with std::array + // TODO(https://github.com/AztecProtocol/barretenberg/issues/714) Try out std::valarray? std::array evaluations; @@ -337,6 +339,13 @@ template class Univariate result *= full_numerator_value; return result; }; + + // Begin iterators + auto begin() { return evaluations.begin(); } + auto begin() const { return evaluations.begin(); } + // End iterators + auto end() { return evaluations.end(); } + auto end() const { return evaluations.end(); } }; template @@ -497,3 +506,8 @@ template std::array array_to_array }; } // namespace bb + +namespace std { +template struct tuple_size> : std::integral_constant {}; + +} // namespace std \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field_conversion.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field_conversion.hpp index f14b05b8cd5..160743853a0 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field_conversion.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field_conversion.hpp @@ -28,10 +28,10 @@ template inline T convert_challenge(Builder& buil } } -template inline std::array, 2> convert_grumpkin_fr_to_bn254_frs(const fq& input) +template inline std::vector> convert_grumpkin_fr_to_bn254_frs(const fq& input) { fr shift(static_cast(1) << NUM_LIMB_BITS); - std::array, 2> result; + std::vector> result(2); result[0] = input.binary_basis_limbs[0].element + (input.binary_basis_limbs[1].element * shift); result[1] = input.binary_basis_limbs[2].element + (input.binary_basis_limbs[3].element * shift); return result; @@ -39,251 +39,122 @@ template inline std::array, 2> convert_grumpkin_f /** * @brief Calculates the size of a types (in their native form) in terms of frs * @details We want to support the following types: fr, fq, - * bn254_element, bb::Univariate, std::array, for + * bn254_element, grumpkin_element, std::array, for * FF = fr or fq, and N is arbitrary + * @tparam Builder * @tparam T * @return constexpr size_t */ -template constexpr size_t calc_num_bn254_frs(); - -template constexpr size_t calc_num_bn254_frs(fr* /*unused*/) -{ - return 1; -} - -template constexpr size_t calc_num_bn254_frs(fq* /*unused*/) -{ - return 2; -} - -template constexpr size_t calc_num_bn254_frs(bn254_element* /*unused*/) -{ - return 2 * calc_num_bn254_frs>(); -} - -template constexpr size_t calc_num_bn254_frs(grumpkin_element* /*unused*/) -{ - return 2 * calc_num_bn254_frs>(); -} - -template constexpr size_t calc_num_bn254_frs(std::array* /*unused*/) -{ - return N * calc_num_bn254_frs(); -} - -template constexpr size_t calc_num_bn254_frs(bb::Univariate* /*unused*/) -{ - return N * calc_num_bn254_frs(); -} - -template constexpr size_t calc_num_bn254_frs() -{ - return calc_num_bn254_frs(static_cast(nullptr)); +template constexpr size_t calc_num_bn254_frs() +{ + if constexpr (IsAnyOf>) { + return Bn254FrParams::NUM_BN254_SCALARS; + } else if constexpr (IsAnyOf>) { + return Bn254FqParams::NUM_BN254_SCALARS; + } else if constexpr (IsAnyOf>) { + return 2 * calc_num_bn254_frs>(); + } else if constexpr (IsAnyOf>) { + return 2 * calc_num_bn254_frs>(); + } else { + // Array or Univariate + return calc_num_bn254_frs() * (std::tuple_size::value); + } } /** * @brief Conversions from vector of fr elements to transcript types. * @details We want to support the following types: fr, fq, - * bn254_element, bb::Univariate, std::array, for + * bn254_element, grumpkin_element, std::array, for * FF = fr or fq, and N is arbitrary + * @tparam Builder * @tparam T + * @param builder * @param fr_vec * @return T */ -template T convert_from_bn254_frs(Builder& builder, std::span> fr_vec); - -template -inline fr convert_from_bn254_frs(const Builder& /*unused*/, - std::span> fr_vec, - fr* /*unused*/) -{ - ASSERT(fr_vec.size() == 1); - return fr_vec[0]; -} - -template -inline fq convert_from_bn254_frs(const Builder& /*unused*/, - std::span> fr_vec, - fq* /*unused*/) -{ - ASSERT(fr_vec.size() == 2); - bigfield result(fr_vec[0], fr_vec[1], 0, 0); - return result; -} - -template -inline bn254_element convert_from_bn254_frs(Builder& builder, - std::span> fr_vec, - bn254_element* /*unused*/) -{ - ASSERT(fr_vec.size() == 4); - bn254_element val; - val.x = convert_from_bn254_frs>(builder, fr_vec.subspan(0, 2)); - val.y = convert_from_bn254_frs>(builder, fr_vec.subspan(2, 2)); - return val; -} - -template -inline grumpkin_element convert_from_bn254_frs(Builder& builder, - std::span> fr_vec, - grumpkin_element* /*unused*/) -{ - ASSERT(fr_vec.size() == 2); - grumpkin_element val(convert_from_bn254_frs>(builder, fr_vec.subspan(0, 1)), - convert_from_bn254_frs>(builder, fr_vec.subspan(1, 1)), - false); - return val; -} - -template -inline std::array, N> convert_from_bn254_frs(const Builder& /*unused*/, - std::span> fr_vec, - std::array, N>* /*unused*/) -{ - std::array, N> val; - for (size_t i = 0; i < N; ++i) { - val[i] = fr_vec[i]; - } - return val; -} - -template -inline std::array, N> convert_from_bn254_frs(Builder& builder, - std::span> fr_vec, - std::array, N>* /*unused*/) -{ - std::array, N> val; - for (size_t i = 0; i < N; ++i) { - std::vector> fr_vec_tmp{ fr_vec[2 * i], - fr_vec[2 * i + 1] }; // each pair of consecutive elements is a fq - val[i] = convert_from_bn254_frs>(builder, fr_vec_tmp); +template T convert_from_bn254_frs(Builder& builder, std::span> fr_vec) +{ + if constexpr (IsAnyOf>) { + ASSERT(fr_vec.size() == 1); + return fr_vec[0]; + } else if constexpr (IsAnyOf>) { + ASSERT(fr_vec.size() == 2); + fq result(fr_vec[0], fr_vec[1], 0, 0); + return result; + } else if constexpr (IsAnyOf>) { + using BaseField = fq; + constexpr size_t BASE_FIELD_SCALAR_SIZE = calc_num_bn254_frs(); + ASSERT(fr_vec.size() == 2 * BASE_FIELD_SCALAR_SIZE); + bn254_element result; + result.x = convert_from_bn254_frs(builder, fr_vec.subspan(0, BASE_FIELD_SCALAR_SIZE)); + result.y = convert_from_bn254_frs( + builder, fr_vec.subspan(BASE_FIELD_SCALAR_SIZE, BASE_FIELD_SCALAR_SIZE)); + return result; + } else if constexpr (IsAnyOf>) { + using BaseField = fr; + constexpr size_t BASE_FIELD_SCALAR_SIZE = calc_num_bn254_frs(); + ASSERT(fr_vec.size() == 2 * BASE_FIELD_SCALAR_SIZE); + grumpkin_element result( + convert_from_bn254_frs>(builder, fr_vec.subspan(0, BASE_FIELD_SCALAR_SIZE)), + convert_from_bn254_frs>( + builder, fr_vec.subspan(BASE_FIELD_SCALAR_SIZE, BASE_FIELD_SCALAR_SIZE)), + false); + return result; + } else { + // Array or Univariate + T val; + constexpr size_t FieldScalarSize = calc_num_bn254_frs(); + ASSERT(fr_vec.size() == FieldScalarSize * std::tuple_size::value); + size_t i = 0; + for (auto& x : val) { + x = convert_from_bn254_frs( + builder, fr_vec.subspan(FieldScalarSize * i, FieldScalarSize)); + ++i; + } + return val; } - return val; -} - -template -inline bb::Univariate, N> convert_from_bn254_frs(const Builder& /*unused*/, - std::span> fr_vec, - bb::Univariate, N>* /*unused*/) -{ - bb::Univariate, N> val; - for (size_t i = 0; i < N; ++i) { - val.evaluations[i] = fr_vec[i]; - } - return val; -} - -template -inline bb::Univariate, N> convert_from_bn254_frs(Builder& builder, - std::span> fr_vec, - bb::Univariate, N>* /*unused*/) -{ - bb::Univariate, N> val; - for (size_t i = 0; i < N; ++i) { - std::vector> fr_vec_tmp{ fr_vec[2 * i], fr_vec[2 * i + 1] }; - val.evaluations[i] = convert_from_bn254_frs>(builder, fr_vec_tmp); - } - return val; -} - -template -inline T convert_from_bn254_frs(Builder& builder, std::span> fr_vec) -{ - return convert_from_bn254_frs(builder, fr_vec, static_cast(nullptr)); } /** * @brief Conversion from transcript values to frs * @details We want to support the following types: bool, size_t, uint32_t, uint64_t, fr, fq, - * bn254_element, curve::Grumpkin::AffineElement, bb::Univariate, std::array, grumpkin_element, std::array, for FF = fr/fq, and N is arbitrary. + * @tparam Builder * @tparam T * @param val * @return std::vector> */ -template inline std::vector> convert_to_bn254_frs(const fq& val) -{ - auto fr_arr = convert_grumpkin_fr_to_bn254_frs(val); - std::vector> fr_vec(fr_arr.begin(), fr_arr.end()); - return fr_vec; -} - -template inline std::vector> convert_to_bn254_frs(const fr& val) -{ - std::vector> fr_vec{ val }; - return fr_vec; -} - -template inline std::vector> convert_to_bn254_frs(const bn254_element& val) -{ - auto fr_vec_x = convert_to_bn254_frs(val.x); - auto fr_vec_y = convert_to_bn254_frs(val.y); - std::vector> fr_vec(fr_vec_x.begin(), fr_vec_x.end()); - fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end()); - return fr_vec; -} - -template inline std::vector> convert_to_bn254_frs(const grumpkin_element& val) -{ - auto fr_vec_x = convert_to_bn254_frs(val.x); - auto fr_vec_y = convert_to_bn254_frs(val.y); - std::vector> fr_vec(fr_vec_x.begin(), fr_vec_x.end()); - fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end()); - return fr_vec; -} - -template -inline std::vector> convert_to_bn254_frs(const std::array, N>& val) -{ - std::vector> fr_vec(val.begin(), val.end()); - return fr_vec; -} - -template -inline std::vector> convert_to_bn254_frs(const std::array, N>& val) -{ - std::vector> fr_vec; - for (size_t i = 0; i < N; ++i) { - auto tmp_vec = convert_to_bn254_frs(val[i]); - fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); - } - return fr_vec; -} - -template -inline std::vector> convert_to_bn254_frs(const bb::Univariate, N>& val) -{ - std::vector> fr_vec; - for (size_t i = 0; i < N; ++i) { - auto tmp_vec = convert_to_bn254_frs(val.evaluations[i]); - fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); - } - return fr_vec; -} - -template -inline std::vector> convert_to_bn254_frs(const bb::Univariate, N>& val) -{ - std::vector> fr_vec; - for (size_t i = 0; i < N; ++i) { - auto tmp_vec = convert_to_bn254_frs(val.evaluations[i]); - fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); - } - return fr_vec; -} - -// TODO(https://github.com/AztecProtocol/barretenberg/issues/846): solve this annoying asymmetry - AllValues vs -// std::array, N> -template -inline std::vector> convert_to_bn254_frs(const AllValues& val) -{ - auto data = val.get_all(); - std::vector> fr_vec; - for (auto& item : data) { - auto tmp_vec = convert_to_bn254_frs(item); - fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); +template std::vector> convert_to_bn254_frs(const T& val) +{ + if constexpr (IsAnyOf>) { + std::vector> fr_vec{ val }; + return fr_vec; + } else if constexpr (IsAnyOf>) { + return convert_grumpkin_fr_to_bn254_frs(val); + } else if constexpr (IsAnyOf>) { + using BaseField = fq; + auto fr_vec_x = convert_to_bn254_frs(val.x); + auto fr_vec_y = convert_to_bn254_frs(val.y); + std::vector> fr_vec(fr_vec_x.begin(), fr_vec_x.end()); + fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end()); + return fr_vec; + } else if constexpr (IsAnyOf>) { + using BaseField = fr; + auto fr_vec_x = convert_to_bn254_frs(val.x); + auto fr_vec_y = convert_to_bn254_frs(val.y); + std::vector> fr_vec(fr_vec_x.begin(), fr_vec_x.end()); + fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end()); + return fr_vec; + } else { + // Array or Univariate + std::vector> fr_vec; + for (auto& x : val) { + auto tmp_vec = convert_to_bn254_frs(x); + fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end()); + } + return fr_vec; } - return fr_vec; } } // namespace bb::stdlib::field_conversion \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field_conversion.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field_conversion.test.cpp index 51624b50abe..73edec370e7 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field_conversion.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field_conversion.test.cpp @@ -1,4 +1,5 @@ #include "barretenberg/stdlib/primitives/field/field_conversion.hpp" +#include "barretenberg/common/zip_view.hpp" #include namespace bb::stdlib::field_conversion_tests { @@ -12,34 +13,22 @@ template class StdlibFieldConversionTests : public ::testing: public: template void check_conversion(Builder& builder, T x) { - size_t len = bb::stdlib::field_conversion::calc_num_bn254_frs(); - auto frs = bb::stdlib::field_conversion::convert_to_bn254_frs(x); + size_t len = bb::stdlib::field_conversion::calc_num_bn254_frs(); + auto frs = bb::stdlib::field_conversion::convert_to_bn254_frs(x); EXPECT_EQ(len, frs.size()); auto y = bb::stdlib::field_conversion::convert_from_bn254_frs(builder, frs); EXPECT_EQ(x.get_value(), y.get_value()); } - template void check_conversion_array(Builder& builder, T x) + template void check_conversion_iterable(Builder& builder, T x) { - size_t len = bb::stdlib::field_conversion::calc_num_bn254_frs(); - auto frs = bb::stdlib::field_conversion::convert_to_bn254_frs(x); + size_t len = bb::stdlib::field_conversion::calc_num_bn254_frs(); + auto frs = bb::stdlib::field_conversion::convert_to_bn254_frs(x); EXPECT_EQ(len, frs.size()); auto y = bb::stdlib::field_conversion::convert_from_bn254_frs(builder, frs); EXPECT_EQ(x.size(), y.size()); - for (size_t i = 0; i < x.size(); i++) { - EXPECT_EQ(x[i].get_value(), y[i].get_value()); - } - } - - template void check_conversion_univariate(Builder& builder, T x) - { - size_t len = bb::stdlib::field_conversion::calc_num_bn254_frs(); - auto frs = bb::stdlib::field_conversion::convert_to_bn254_frs(x); - EXPECT_EQ(len, frs.size()); - auto y = bb::stdlib::field_conversion::convert_from_bn254_frs(builder, frs); - EXPECT_EQ(x.evaluations.size(), y.evaluations.size()); - for (size_t i = 0; i < x.evaluations.size(); i++) { - EXPECT_EQ(x.evaluations[i].get_value(), y.evaluations[i].get_value()); + for (auto [val1, val2] : zip_view(x, y)) { + EXPECT_EQ(val1.get_value(), val2.get_value()); } } }; @@ -132,7 +121,7 @@ TYPED_TEST(StdlibFieldConversionTests, FieldConversionArrayBn254Fr) std::array, 4> x1{ fr(&builder, 1), fr(&builder, 2), fr(&builder, 3), fr(&builder, 4) }; - this->check_conversion_array(builder, x1); + this->check_conversion_iterable(builder, x1); std::array, 7> x2{ fr(&builder, bb::fr::modulus_minus_two), fr(&builder, bb::fr::modulus_minus_two - 123), @@ -141,7 +130,7 @@ TYPED_TEST(StdlibFieldConversionTests, FieldConversionArrayBn254Fr) fr(&builder, 367032), fr(&builder, 12985028), fr(&builder, bb::fr::modulus_minus_two - 125015028) }; - this->check_conversion_array(builder, x2); + this->check_conversion_iterable(builder, x2); } /** @@ -167,7 +156,7 @@ TYPED_TEST(StdlibFieldConversionTests, FieldConversionArrayGrumpkinFr) &builder, static_cast(std::string("018555a8eb50cf07f64b019ebaf3af3c925c93e631f3ecd455db07bbb52bbdd3"))), }; - this->check_conversion_array(builder, x1); + this->check_conversion_iterable(builder, x1); } /** @@ -182,7 +171,7 @@ TYPED_TEST(StdlibFieldConversionTests, FieldConversionUnivariateBn254Fr) Univariate, 4> x{ { fr(&builder, 1), fr(&builder, 2), fr(&builder, 3), fr(&builder, 4) } }; - this->check_conversion_univariate(builder, x); + this->check_conversion_iterable(builder, x); } /** @@ -208,7 +197,7 @@ TYPED_TEST(StdlibFieldConversionTests, FieldConversionUnivariateGrumpkinFr) static_cast( std::string("2bf1eaf87f7d27e8dc4056e9af975985bccc89077a21891d6c7b6ccce0631f95"))) } }; - this->check_conversion_univariate(builder, x); + this->check_conversion_iterable(builder, x); } /** diff --git a/barretenberg/cpp/src/barretenberg/stdlib/recursion/honk/transcript/transcript.hpp b/barretenberg/cpp/src/barretenberg/stdlib/recursion/honk/transcript/transcript.hpp index 9fef4b874f9..dcc23cdfa6a 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/recursion/honk/transcript/transcript.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/recursion/honk/transcript/transcript.hpp @@ -37,7 +37,7 @@ template struct StdlibTranscriptParams { } template static constexpr size_t calc_num_bn254_frs() { - return bb::stdlib::field_conversion::calc_num_bn254_frs(); + return bb::stdlib::field_conversion::calc_num_bn254_frs(); } template static inline T convert_from_bn254_frs(std::span frs) { @@ -48,7 +48,7 @@ template struct StdlibTranscriptParams { template static inline std::vector convert_to_bn254_frs(const T& element) { Builder* builder = element.get_context(); - return bb::stdlib::field_conversion::convert_to_bn254_frs(*builder, element); + return bb::stdlib::field_conversion::convert_to_bn254_frs(*builder, element); } }; diff --git a/barretenberg/cpp/src/barretenberg/sumcheck/sumcheck.hpp b/barretenberg/cpp/src/barretenberg/sumcheck/sumcheck.hpp index 09e2f41de68..7fffcb15269 100644 --- a/barretenberg/cpp/src/barretenberg/sumcheck/sumcheck.hpp +++ b/barretenberg/cpp/src/barretenberg/sumcheck/sumcheck.hpp @@ -123,7 +123,7 @@ template class SumcheckProver { zip_view(multivariate_evaluations.get_all(), partially_evaluated_polynomials.get_all())) { eval = poly[0]; } - transcript->send_to_verifier("Sumcheck:evaluations", multivariate_evaluations); + transcript->send_to_verifier("Sumcheck:evaluations", multivariate_evaluations.get_all()); return { multivariate_challenge, multivariate_evaluations }; };