diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..7520c79 --- /dev/null +++ b/.clang-format @@ -0,0 +1,30 @@ +Language: Cpp +BasedOnStyle: Google +DerivePointerAlignment: false +AllowShortFunctionsOnASingleLine: Empty +BinPackArguments: false +BinPackParameters: false + +# The include rules below structures includes as +# +# - STL headers (anything without an extension, tbp) +# - Other headers (anything that ends with .h) +# - External SCL headers (anything of the form ) +# - Internal SCL headers (anything of the form "scl/...") +# +# The only exception is when a .cc file includes a header file with the same +# name at the same path. + +IncludeCategories: + - Regex: '^' + Priority: 2 + SortPriority: 0 + - Regex: '^<.*' + Priority: 1 + SortPriority: 0 + - Regex: '^scl/.*' + Priority: 5 + SortPriority: 0 diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 0000000..7428941 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,27 @@ +Checks: '-*,bugprone-*,performance-*,readability-*,google-global-names-in-headers,cert-dcl59-cpp,-bugprone-easily-swappable-parameters,-readability-identifier-length,-readability-magic-numbers,-readability-function-cognitive-complexity,-readability-function-size' + +# Enabled checks: +# - bugprone +# - performance +# - readability +# - google-global-names-in-headers +# - cert-dcl59-cpp +# +# Specific disabled checks +# +# bugprone-easily-swappable-parameters: +# Doesn't make sense to exclude functions taking multiple ints in SCL. +# +# readability-identifier-length: +# Short identifiers make sense. +# +# readability-magic-numbers: +# Too strict. +# +# readability-function-cognitive-complexity +# Catch2. +# +# readability-function-size +# Catch2. + +AnalyzeTemporaryDtors: false diff --git a/.github/workflows/Checks.yml b/.github/workflows/Checks.yml index 05adccc..93089d7 100644 --- a/.github/workflows/Checks.yml +++ b/.github/workflows/Checks.yml @@ -9,7 +9,7 @@ on: jobs: documentation: name: Documentation - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 @@ -39,11 +39,11 @@ jobs: - uses: actions/checkout@v2 - name: Setup - run: sudo apt-get install -y clang-format-12 + run: sudo apt-get install -y clang-format - name: Check shell: bash run: | - find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n --style=Google {} \; &> checks.txt + find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n {} \; &> checks.txt cat checks.txt test ! -s checks.txt diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index c52ab45..28c491c 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -12,7 +12,7 @@ env: jobs: build: name: Coverage and Linting - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 diff --git a/CMakeLists.txt b/CMakeLists.txt index 25be133..5ac7c25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,7 @@ cmake_minimum_required( VERSION 3.14 ) -project( scl VERSION 3.0.0 DESCRIPTION "Secure Computation Library" ) +project( scl VERSION 4.0.0 DESCRIPTION "Secure Computation Library" ) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) @@ -35,10 +35,12 @@ set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -Wall -Wextra -pedantic -Werror -std=gnu++17") set(SCL_SOURCE_FILES - src/scl/prg.cc - src/scl/hash.cc + src/scl/util/str.cc + + src/scl/primitives/prg.cc + src/scl/primitives/sha3.cc + src/scl/primitives/sha256.cc - src/scl/math/str.cc src/scl/math/mersenne61.cc src/scl/math/mersenne127.cc @@ -87,8 +89,9 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") set(SCL_TEST_SOURCE_FILES test/scl/main.cc - test/scl/test_hash.cc - test/scl/test_prg.cc + test/scl/primitives/test_prg.cc + test/scl/primitives/test_sha3.cc + test/scl/primitives/test_sha256.cc test/scl/gf7.cc test/scl/math/test_mersenne61.cc @@ -102,6 +105,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") test/scl/ss/test_additive.cc test/scl/ss/test_poly.cc test/scl/ss/test_shamir.cc + test/scl/ss/test_feldman.cc test/scl/net/util.cc test/scl/net/test_config.cc diff --git a/README.md b/README.md index 820fd48..14121e7 100644 --- a/README.md +++ b/README.md @@ -62,3 +62,15 @@ inspiration. SCL uses Doxygen for documentation. Run `./scripts/build_documentation.sh` to generate the documentation. This is placed in the `doc/` folder. Documentation uses `doxygen`, so make sure that's installed. + +# Citing + +I'd greatly appreciate any work that uses SCL include the below bibtex entry + +``` +@misc{secure-computation-library, + author = {Anders Dalskov}, + title = {{SCL (Secure Computation Library)---utility library for prototyping MPC applications}}, + howpublished = {\url{https://github.com/anderspkd/secure-computation-library}}, +} +``` diff --git a/RELEASE.txt b/RELEASE.txt index f56e41d..80120a4 100644 --- a/RELEASE.txt +++ b/RELEASE.txt @@ -1,3 +1,14 @@ +4.0: Shamir, Feldman, SHA-256 +- Refactor Shamir to allow caching of Lagrange coefficients +- Add support for Feldman Secret Sharing +- Add support for SHA-256 +- Add bibtex blob for citing SCL +- Refactor interface for hash functions +- Refactor interface for Shamir +- bugs: + - Fix negation of 0 in Secp256k1::Field and Secp256k1::Order + - Make serialization and deserialization of curve points behave more sanely + 3.0: More features, build changes - Add method for returning a point as a pair of affine coordinates - Add method to check if a channel has data available diff --git a/examples/01_primitives.cc b/examples/01_primitives.cc index 441e330..a99b3cc 100644 --- a/examples/01_primitives.cc +++ b/examples/01_primitives.cc @@ -42,5 +42,5 @@ int main() { /* The DigestToString can be used to print a hex representation of a digest. */ - std::cout << scl::DigestToString(digest) << "\n"; + std::cout << scl::details::DigestToString(digest) << "\n"; } diff --git a/examples/02_finite_fields.cc b/examples/02_finite_fields.cc index 891acc7..7951e45 100644 --- a/examples/02_finite_fields.cc +++ b/examples/02_finite_fields.cc @@ -18,10 +18,10 @@ * along with this program. If not, see . */ -#include - #include +#include + int main() { /* This defines a "Finite Field" with space for at least 32 bits of * computation. At the moment, SCL supports two primes: One that is 61 bits @@ -67,7 +67,7 @@ int main() { std::cout << a << " ?= " << b << ": " << (a == b) << "\n"; std::cout << a << " ?= " << a << ": " << (a == a) << "\n"; - scl::PRG prg; + auto prg = scl::PRG::Create(); /* Using a PRG (see the PRG example), we can generate random field elements. */ diff --git a/examples/03_secret_sharing.cc b/examples/03_secret_sharing.cc index c5f9dcc..d98b553 100644 --- a/examples/03_secret_sharing.cc +++ b/examples/03_secret_sharing.cc @@ -18,16 +18,17 @@ * along with this program. If not, see . */ -#include -#include - #include #include +#include +#include + int main() { using Fp = scl::Fp<32>; using Vec = scl::Vec; - scl::PRG prg; + + auto prg = scl::PRG::Create(); /* We can easily create an additive secret sharing of some secret value: */ @@ -46,8 +47,8 @@ int main() { * correction. Lets see error detection at work first */ - scl::details::ShamirSSFactory factory( - 1, prg, scl::details::SecurityLevel::CORRECT); + auto factory = + scl::ShamirSSFactory::Create(1, prg, scl::SecurityLevel::CORRECT); /* We create 4 shamir shares with a threshold of 1. */ auto shamir_shares = factory.Share(secret); @@ -56,17 +57,15 @@ int main() { /* Of course, these can be reconstructed. The second parameter is the * threshold. This performs reconstruction with error detection. */ - auto recon = factory.GetInterpolator(); auto shamir_reconstructed = - recon.Reconstruct(shamir_shares, scl::details::SecurityLevel::DETECT); + factory.Recover(shamir_shares, scl::SecurityLevel::DETECT); std::cout << shamir_reconstructed << "\n"; /* If we introduce an error, then reconstruction fails */ shamir_shares[2] = Fp(123); try { - std::cout << recon.Reconstruct(shamir_shares, - scl::details::SecurityLevel::DETECT) + std::cout << factory.Recover(shamir_shares, scl::SecurityLevel::DETECT) << "\n"; } catch (std::logic_error& e) { std::cout << e.what() << "\n"; @@ -75,7 +74,7 @@ int main() { /* On the other hand, we can use the robust reconstruction since the threshold * is low enough. I.e., because 4 >= 3*1 + 1. */ - auto r = recon.Reconstruct(shamir_shares); + auto r = factory.Recover(shamir_shares); std::cout << r << "\n"; /* With a bit of extra work, we can even learn which share had the error. @@ -104,7 +103,7 @@ int main() { */ shamir_shares[1] = Fp(22); try { - recon.Reconstruct(shamir_shares); + factory.Recover(shamir_shares); } catch (std::logic_error& e) { std::cout << e.what() << "\n"; } diff --git a/examples/04_networking.cc b/examples/04_networking.cc index c37acc6..0ba6a82 100644 --- a/examples/04_networking.cc +++ b/examples/04_networking.cc @@ -18,11 +18,10 @@ * along with this program. If not, see . */ -#include - #include -#include "scl/net/tcp_channel.h" +#include +#include scl::NetworkConfig RunServer(int n) { scl::DiscoveryServer server(n); @@ -67,7 +66,7 @@ int main(int argc, char** argv) { auto network = scl::Network::Create(config); - for (std::size_t i = 0; i < 3; ++i) { + for (std::size_t i = 0; i < (std::size_t)n; ++i) { // similar to the TCP channel example, send our ID to everyone: network.Party(i)->Send(config.Id()); unsigned received_id; diff --git a/include/scl/math/ec.h b/include/scl/math/ec.h index caed8d4..057888a 100644 --- a/include/scl/math/ec.h +++ b/include/scl/math/ec.h @@ -26,7 +26,7 @@ #include "scl/math/ec_ops.h" #include "scl/math/ff.h" #include "scl/math/number.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" namespace scl { @@ -44,6 +44,11 @@ class EC { */ using Field = FF; + /** + * @brief A large sub-group of this curve. + */ + using Order = FF; + /** * @brief The size of a curve point in bytes. * @param compressed @@ -62,7 +67,9 @@ class EC { /** * @brief A string indicating which curve this is. */ - constexpr static const char* Name() { return Curve::kName; }; + constexpr static const char* Name() { + return Curve::kName; + }; /** * @brief Get the generator of this curve. @@ -96,7 +103,9 @@ class EC { /** * @brief Create a new point equal to the point at infinity. */ - explicit constexpr EC() { details::CurveSetPointAtInfinity(mValue); }; + explicit constexpr EC() { + details::CurveSetPointAtInfinity(mValue); + }; /** * @brief Add another EC point to this. @@ -173,7 +182,7 @@ class EC { * @param scalar the scalar * @return this. */ - EC& operator*=(const FF& scalar) { + EC& operator*=(const Order& scalar) { details::CurveScalarMultiply(mValue, scalar); return *this; }; @@ -195,8 +204,7 @@ class EC { * @param scalar the scalar * @return the point multiplied with the scalar. */ - friend EC operator*(const EC& point, - const FF& scalar) { + friend EC operator*(const EC& point, const Order& scalar) { EC copy(point); return copy *= scalar; }; diff --git a/include/scl/math/ec_ops.h b/include/scl/math/ec_ops.h index 8a601f4..8849f75 100644 --- a/include/scl/math/ec_ops.h +++ b/include/scl/math/ec_ops.h @@ -56,7 +56,8 @@ void CurveSetGenerator(typename C::ValueType& out); * @param y the y coordinate */ template -void CurveSetAffine(typename C::ValueType& out, const FF& x, +void CurveSetAffine(typename C::ValueType& out, + const FF& x, const FF& y); /** @@ -139,7 +140,8 @@ void CurveFromBytes(typename C::ValueType& out, const unsigned char* src); * @param compress whether to compress the point */ template -void CurveToBytes(unsigned char* dest, const typename C::ValueType& in, +void CurveToBytes(unsigned char* dest, + const typename C::ValueType& in, bool compress); /** diff --git a/include/scl/math/ff.h b/include/scl/math/ff.h index 17739a7..a1e6d53 100644 --- a/include/scl/math/ff.h +++ b/include/scl/math/ff.h @@ -27,7 +27,7 @@ #include "scl/math/ff_ops.h" #include "scl/math/ring.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" namespace scl { @@ -47,17 +47,23 @@ class FF final : details::RingBase> { /** * @brief Size in bytes of a field element. */ - constexpr static std::size_t ByteSize() { return Field::kByteSize; }; + constexpr static std::size_t ByteSize() { + return Field::kByteSize; + }; /** * @brief Actual bit size of an element. */ - constexpr static std::size_t BitSize() { return Field::kBitSize; }; + constexpr static std::size_t BitSize() { + return Field::kBitSize; + }; /** * @brief A short string representation of this field. */ - constexpr static const char* Name() { return Field::kName; }; + constexpr static const char* Name() { + return Field::kName; + }; /** * @brief Read a field element from a buffer. @@ -166,7 +172,9 @@ class FF final : details::RingBase> { * @param other the other element * @return this set to this * other.Inverse(). */ - FF& operator/=(const FF& other) { return operator*=(other.Inverse()); }; + FF& operator/=(const FF& other) { + return operator*=(other.Inverse()); + }; /** * @brief Negates this element. diff --git a/include/scl/math/ff_ops.h b/include/scl/math/ff_ops.h index 8c293fd..0234bc4 100644 --- a/include/scl/math/ff_ops.h +++ b/include/scl/math/ff_ops.h @@ -26,8 +26,6 @@ #include #include -#include "scl/math/str.h" - namespace scl { namespace details { @@ -89,19 +87,14 @@ bool FieldEqual(const typename F::ValueType& in1, /** * @brief Convert a field element to bytes. + * @param dest the field element to convert + * @param src where to store the converted element * * For types that are trivially copyable, this function has a default * implementation based on std::memcpy. - * - * @param dest the field element to convert - * @param src where to store the converted element */ -template , - int> = 0> -void FieldToBytes(unsigned char* dest, const typename F::ValueType& src) { - std::memcpy(dest, &src, sizeof(typename F::ValueType)); -} +template +void FieldToBytes(unsigned char* dest, const typename F::ValueType& src); /** * @brief Convert the content of a buffer to a field element. diff --git a/include/scl/math/mat.h b/include/scl/math/mat.h index c027f7d..4f76f4d 100644 --- a/include/scl/math/mat.h +++ b/include/scl/math/mat.h @@ -30,7 +30,7 @@ #include #include -#include "scl/prg.h" +#include "scl/primitives/prg.h" namespace scl { @@ -73,7 +73,8 @@ class Mat { * @param xs a \p n length vector containing the x values to use * @return a Vandermonde matrix. */ - static Mat Vandermonde(std::size_t n, std::size_t m, + static Mat Vandermonde(std::size_t n, + std::size_t m, const std::vector& xs); /** @@ -119,7 +120,8 @@ class Mat { * @param vec the elements of the matrix * @return a Matrix. */ - static Mat FromVector(std::size_t n, std::size_t m, + static Mat FromVector(std::size_t n, + std::size_t m, const std::vector& vec) { if (vec.size() != n * m) { throw std::invalid_argument("invalid dimensions"); @@ -167,12 +169,16 @@ class Mat { /** * @brief The number of rows of this matrix. */ - std::size_t Rows() const { return mRows; }; + std::size_t Rows() const { + return mRows; + }; /** * @brief The number of columns of this matrix. */ - std::size_t Cols() const { return mCols; }; + std::size_t Cols() const { + return mCols; + }; /** * @brief Provides mutable access to a matrix element. @@ -313,7 +319,9 @@ class Mat { /** * @brief Check if this matrix is square. */ - bool IsSquare() const { return Rows() == Cols(); }; + bool IsSquare() const { + return Rows() == Cols(); + }; /** * @brief Transpose this matrix. @@ -385,7 +393,9 @@ class Mat { /** * @brief The size of a matrix when serialized in bytes. */ - std::size_t ByteSize() const { return Rows() * Cols() * T::ByteSize(); } + std::size_t ByteSize() const { + return Rows() * Cols() * T::ByteSize(); + } private: Mat(std::size_t r, std::size_t c, std::vector v) @@ -445,7 +455,8 @@ Mat Mat::Random(std::size_t n, std::size_t m, PRG& prg) { } template -Mat Mat::Vandermonde(std::size_t n, std::size_t m, +Mat Mat::Vandermonde(std::size_t n, + std::size_t m, const std::vector& xs) { if (xs.size() != n) { throw std::invalid_argument("|xs| != number of rows"); diff --git a/include/scl/math/number.h b/include/scl/math/number.h index 4b536ed..17a4761 100644 --- a/include/scl/math/number.h +++ b/include/scl/math/number.h @@ -21,12 +21,12 @@ #ifndef SCL_MATH_NUMBER_H #define SCL_MATH_NUMBER_H -#include - #include #include -#include "scl/prg.h" +#include + +#include "scl/primitives/prg.h" namespace scl { @@ -99,17 +99,15 @@ class Number { return *this; }; - // This is used for all op-assign operator overloads below. -#define SCL_OP_IMPL(op, arg) \ - *this = *this op(arg); \ - return *this - /** * @brief In-place addition of two numbers. * @param number the other number * @return this */ - Number& operator+=(const Number& number) { SCL_OP_IMPL(+, number); }; + Number& operator+=(const Number& number) { + *this = *this + number; + return *this; + }; /** * @brief Add two numbers. @@ -123,7 +121,10 @@ class Number { * @param number the other number * @return this. */ - Number& operator-=(const Number& number) { SCL_OP_IMPL(-, number); }; + Number& operator-=(const Number& number) { + *this = *this - number; + return *this; + }; /** * @brief Subtract two Numbers. @@ -143,7 +144,10 @@ class Number { * @param number the other Number * @return this. */ - Number& operator*=(const Number& number) { SCL_OP_IMPL(*, number); }; + Number& operator*=(const Number& number) { + *this = *this * number; + return *this; + }; /** * @brief Multiply two Numbers. @@ -157,7 +161,10 @@ class Number { * @param number the other number * @return this. */ - Number& operator/=(const Number& number) { SCL_OP_IMPL(/, number); }; + Number& operator/=(const Number& number) { + *this = *this / number; + return *this; + }; /** * @brief Divide two Numbers. @@ -171,7 +178,10 @@ class Number { * @param shift the amount to left shift * @return this. */ - Number& operator<<=(int shift) { SCL_OP_IMPL(<<, shift); }; + Number& operator<<=(int shift) { + *this = *this << shift; + return *this; + }; /** * @brief Perform a left shift of a Number. @@ -185,7 +195,10 @@ class Number { * @param shift the amount to right shift * @return this. */ - Number& operator>>=(int shift) { SCL_OP_IMPL(>>, shift); }; + Number& operator>>=(int shift) { + *this = *this >> shift; + return *this; + }; /** * @brief Perform a right shift of a Number. @@ -199,7 +212,10 @@ class Number { * @param number the number to xor this with * @return \p this */ - Number& operator^=(const Number& number) { SCL_OP_IMPL(^, number); }; + Number& operator^=(const Number& number) { + *this = *this ^ number; + return *this; + }; /** * @brief Exclusive or of two numbers. @@ -213,7 +229,10 @@ class Number { * @param number * @return */ - Number& operator|=(const Number& number) { SCL_OP_IMPL(|, number); }; + Number& operator|=(const Number& number) { + *this = *this | number; + return *this; + }; /** * @brief operator | @@ -227,9 +246,10 @@ class Number { * @param number * @return */ - Number& operator&=(const Number& number) { SCL_OP_IMPL(&, number); }; - -#undef SCL_OP_IMPL + Number& operator&=(const Number& number) { + *this = *this & number; + return *this; + }; /** * @brief operator & @@ -320,13 +340,17 @@ class Number { * @brief Test if this Number is odd. * @return true if this Number is odd. */ - bool Odd() const { return TestBit(0); }; + bool Odd() const { + return TestBit(0); + }; /** * @brief Test if this Number is even. * @return true if this Number is even. */ - bool Even() const { return !Odd(); }; + bool Even() const { + return !Odd(); + }; /** * @brief Return a string representation of this Number. diff --git a/include/scl/math/ring.h b/include/scl/math/ring.h index a14deef..4224588 100644 --- a/include/scl/math/ring.h +++ b/include/scl/math/ring.h @@ -35,7 +35,7 @@ struct RingBase { /** * @brief Add two elements and return their sum. */ - friend T operator+(const T &lhs, const T &rhs) { + friend T operator+(const T& lhs, const T& rhs) { T temp(lhs); return temp += rhs; }; @@ -43,7 +43,7 @@ struct RingBase { /** * @brief Subtract two elements and return their difference. */ - friend T operator-(const T &lhs, const T &rhs) { + friend T operator-(const T& lhs, const T& rhs) { T temp(lhs); return temp -= rhs; }; @@ -51,7 +51,7 @@ struct RingBase { /** * @brief Return the negation of an element. */ - friend T operator-(const T &elem) { + friend T operator-(const T& elem) { T temp(elem); return temp.Negate(); }; @@ -59,7 +59,7 @@ struct RingBase { /** * @brief Multiply two elements and return their product. */ - friend T operator*(const T &lhs, const T &rhs) { + friend T operator*(const T& lhs, const T& rhs) { T temp(lhs); return temp *= rhs; }; @@ -67,7 +67,7 @@ struct RingBase { /** * @brief Divide two elements and return their quotient. */ - friend T operator/(const T &lhs, const T &rhs) { + friend T operator/(const T& lhs, const T& rhs) { T temp(lhs); return temp /= rhs; }; @@ -75,17 +75,21 @@ struct RingBase { /** * @brief Compare two elements for equality. */ - friend bool operator==(const T &lhs, const T &rhs) { return lhs.Equal(rhs); }; + friend bool operator==(const T& lhs, const T& rhs) { + return lhs.Equal(rhs); + }; /** * @brief Compare two elements for inequality. */ - friend bool operator!=(const T &lhs, const T &rhs) { return !(lhs == rhs); }; + friend bool operator!=(const T& lhs, const T& rhs) { + return !(lhs == rhs); + }; /** * @brief Write a string representation of an element to a stream. */ - friend std::ostream &operator<<(std::ostream &os, const T &r) { + friend std::ostream& operator<<(std::ostream& os, const T& r) { return os << r.ToString(); }; }; diff --git a/include/scl/math/vec.h b/include/scl/math/vec.h index ab5e89a..f527b86 100644 --- a/include/scl/math/vec.h +++ b/include/scl/math/vec.h @@ -29,7 +29,7 @@ #include #include "scl/math/mat.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" namespace scl { @@ -41,8 +41,8 @@ namespace details { * @param xe end of the first iterator * @param yb start of the second iterator */ -template -T UncheckedInnerProd(It xb, It xe, It yb) { +template +T UncheckedInnerProd(I0 xb, I0 xe, I1 yb) { T v; while (xb != xe) { v += *xb++ * *yb++; @@ -161,17 +161,23 @@ class Vec { /** * @brief The size of the Vec. */ - std::size_t Size() const { return mValues.size(); }; + std::size_t Size() const { + return mValues.size(); + }; /** * @brief Mutable access to vector elements. */ - T& operator[](std::size_t idx) { return mValues[idx]; }; + T& operator[](std::size_t idx) { + return mValues[idx]; + }; /** * @brief Read only access to vector elements. */ - T operator[](std::size_t idx) const { return mValues[idx]; }; + T operator[](std::size_t idx) const { + return mValues[idx]; + }; /** * @brief Add two Vec objects entry-wise. @@ -305,22 +311,30 @@ class Vec { /** * @brief Convert this vector into a 1-by-N row matrix. */ - Mat ToRowMatrix() const { return Mat{1, Size(), mValues}; }; + Mat ToRowMatrix() const { + return Mat{1, Size(), mValues}; + }; /** * @brief Convert this vector into a N-by-1 column matrix. */ - Mat ToColumnMatrix() const { return Mat{Size(), 1, mValues}; }; + Mat ToColumnMatrix() const { + return Mat{Size(), 1, mValues}; + }; /** * @brief Convert this Vec object to an std::vector. */ - std::vector& ToStlVector() { return mValues; }; + std::vector& ToStlVector() { + return mValues; + }; /** * @brief Convert this Vec object to a const std::vector. */ - const std::vector& ToStlVector() const { return mValues; }; + const std::vector& ToStlVector() const { + return mValues; + }; /** * @brief Extract a sub-vector @@ -343,7 +357,9 @@ class Vec { * @param end the end index, exclusive * @return a sub-vector. */ - Vec SubVector(std::size_t end) { return SubVector(0, end); }; + Vec SubVector(std::size_t end) { + return SubVector(0, end); + }; /** * @brief Return a string representation of this vector. @@ -366,67 +382,93 @@ class Vec { /** * @brief Returns the number of bytes that Write writes. */ - std::size_t ByteSize() const { return Size() * T::ByteSize(); }; + std::size_t ByteSize() const { + return Size() * T::ByteSize(); + }; /** * @brief Return an iterator pointing to the start of this Vec. */ - iterator begin() { return mValues.begin(); }; + iterator begin() { + return mValues.begin(); + }; /** * @brief Provides a const iterator to the start of this Vec. */ - const_iterator begin() const { return mValues.begin(); }; + const_iterator begin() const { + return mValues.begin(); + }; /** * @brief Provides a const iterator to the start of this Vec. */ - const_iterator cbegin() const { return mValues.cbegin(); }; + const_iterator cbegin() const { + return mValues.cbegin(); + }; /** * @brief Provides an iterator pointing to the end of this Vec. */ - iterator end() { return mValues.end(); }; + iterator end() { + return mValues.end(); + }; /** * @brief Provides a const iterator pointing to the end of this Vec. */ - const_iterator end() const { return mValues.end(); }; + const_iterator end() const { + return mValues.end(); + }; /** * @brief Provides a const iterator pointing to the end of this Vec. */ - const_iterator cend() const { return mValues.cend(); }; + const_iterator cend() const { + return mValues.cend(); + }; /** * @brief Provides a reverse iterator pointing to the end of this Vec. */ - reverse_iterator rbegin() { return mValues.rbegin(); }; + reverse_iterator rbegin() { + return mValues.rbegin(); + }; /** * @brief Provides a reverse const iterator pointing to the end of this Vec. */ - const_reverse_iterator rbegin() const { return mValues.rbegin(); }; + const_reverse_iterator rbegin() const { + return mValues.rbegin(); + }; /** * @brief Provides a reverse const iterator pointing to the end of this Vec. */ - const_reverse_iterator crbegin() const { return mValues.crbegin(); }; + const_reverse_iterator crbegin() const { + return mValues.crbegin(); + }; /** * @brief Provides a reverse iterator pointing to the start of this Vec. */ - reverse_iterator rend() { return mValues.rend(); }; + reverse_iterator rend() { + return mValues.rend(); + }; /** * @brief Provides a reverse const iterator pointing to the start of this Vec. */ - const_reverse_iterator rend() const { return mValues.rend(); }; + const_reverse_iterator rend() const { + return mValues.rend(); + }; /** * @brief Provides a reverse const iterator pointing to the start of this Vec. */ - const_reverse_iterator crend() const { return mValues.crend(); }; + const_reverse_iterator crend() const { + return mValues.crend(); + }; private: void EnsureCompatible(const Vec& other) const { diff --git a/include/scl/math/z2k.h b/include/scl/math/z2k.h index 6610075..5fa5ad2 100644 --- a/include/scl/math/z2k.h +++ b/include/scl/math/z2k.h @@ -25,7 +25,7 @@ #include "scl/math/ring.h" #include "scl/math/z2k_ops.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" namespace scl { @@ -50,22 +50,30 @@ class Z2k final : public details::RingBase> { /** * @brief The bit size of the ring. Identical to BitSize(). */ - constexpr static std::size_t SpecifiedBitSize() { return K; }; + constexpr static std::size_t SpecifiedBitSize() { + return K; + }; /** * @brief The number of bytes needed to store a ring element. */ - constexpr static std::size_t ByteSize() { return (K - 1) / 8 + 1; }; + constexpr static std::size_t ByteSize() { + return (K - 1) / 8 + 1; + }; /** * @brief The bit size of the ring. Identical to SpecifiedBitSize(). */ - constexpr static std::size_t BitSize() { return SpecifiedBitSize(); }; + constexpr static std::size_t BitSize() { + return SpecifiedBitSize(); + }; /** * @brief A short string representation of this ring. */ - constexpr static const char* Name() { return "Z2k"; }; + constexpr static const char* Name() { + return "Z2k"; + }; /** * @brief Read a ring from a buffer. @@ -193,7 +201,9 @@ class Z2k final : public details::RingBase> { * particular, an element x is invertible if x.Lsb() == * 1. That is, if it is odd. */ - unsigned Lsb() const { return details::LsbZ2k(mValue); }; + unsigned Lsb() const { + return details::LsbZ2k(mValue); + }; /** * @brief Check if this element is equal to another element. diff --git a/include/scl/math/z2k_ops.h b/include/scl/math/z2k_ops.h index c767bad..16b1c57 100644 --- a/include/scl/math/z2k_ops.h +++ b/include/scl/math/z2k_ops.h @@ -26,7 +26,7 @@ #include #include -#include "scl/math/str.h" +#include "scl/util/str.h" namespace scl { namespace details { diff --git a/include/scl/net/channel.h b/include/scl/net/channel.h index 1a4becb..b3b0732 100644 --- a/include/scl/net/channel.h +++ b/include/scl/net/channel.h @@ -92,14 +92,17 @@ class Channel { * @return true if this channel has data and false otherwise. * @note the default implementation always returns true. */ - virtual bool HasData() { return true; }; + virtual bool HasData() { + return true; + }; /** * @brief Send a trivially copyable item. * @param src the thing to send */ - template , bool> = true> + template < + typename T, + typename std::enable_if_t, bool> = true> void Send(const T& src) { Send(SCL_CC(&src), sizeof(T)); } @@ -112,8 +115,9 @@ class Channel { * * @param src an STL vector of things to send */ - template , bool> = true> + template < + typename T, + typename std::enable_if_t, bool> = true> void Send(const std::vector& src) { Send(src.size()); Send(SCL_CC(src.data()), sizeof(T) * src.size()); @@ -168,8 +172,9 @@ class Channel { * @brief Receive a trivially copyable item. * @param dst where to store the received item */ - template , bool> = true> + template < + typename T, + typename std::enable_if_t, bool> = true> void Recv(T& dst) { Recv(SCL_C(&dst), sizeof(T)); } @@ -179,8 +184,9 @@ class Channel { * @param dst where to store the received items * @note any existing content in \p dst is overwritten. */ - template , bool> = true> + template < + typename T, + typename std::enable_if_t, bool> = true> void Recv(std::vector& dst) { std::size_t size; Recv(size); diff --git a/include/scl/net/config.h b/include/scl/net/config.h index 8831779..529b675 100644 --- a/include/scl/net/config.h +++ b/include/scl/net/config.h @@ -112,22 +112,30 @@ class NetworkConfig { /** * @brief Gets the identity of this party. */ - int Id() const { return mId; }; + int Id() const { + return mId; + }; /** * @brief Gets the size of the network. */ - std::size_t NetworkSize() const { return mParties.size(); }; + std::size_t NetworkSize() const { + return mParties.size(); + }; /** * @brief Get a list of connection information for parties in this network. */ - std::vector Parties() const { return mParties; }; + std::vector Parties() const { + return mParties; + }; /** * @brief Get information about a party. */ - Party GetParty(unsigned id) const { return mParties[id]; }; + Party GetParty(unsigned id) const { + return mParties[id]; + }; /** * @brief Return a string representation of this network config. diff --git a/include/scl/net/mem_channel.h b/include/scl/net/mem_channel.h index 5e552b0..66faa82 100644 --- a/include/scl/net/mem_channel.h +++ b/include/scl/net/mem_channel.h @@ -80,7 +80,11 @@ class InMemoryChannel final : public Channel { void Send(const unsigned char* src, std::size_t n) override; std::size_t Recv(unsigned char* dst, std::size_t n) override; - bool HasData() override { return mIn->Size() > 0 || !mOverflow.empty(); }; + + bool HasData() override { + return mIn->Size() > 0 || !mOverflow.empty(); + }; + void Close() override{}; private: diff --git a/include/scl/net/network.h b/include/scl/net/network.h index b1ed6f3..e15feea 100644 --- a/include/scl/net/network.h +++ b/include/scl/net/network.h @@ -65,12 +65,16 @@ class Network { * @brief Get a channel to a particular party. * @param id the id of the party */ - Channel* Party(unsigned id) { return mChannels[id].get(); } + Channel* Party(unsigned id) { + return mChannels[id].get(); + } /** * @brief The size of the network. */ - std::size_t Size() const { return mChannels.size(); }; + std::size_t Size() const { + return mChannels.size(); + }; /** * @brief Closes all channels in the network. @@ -153,8 +157,8 @@ Network Network::Create(const scl::NetworkConfig& config) { channels[config.Id()] = scl::details::CreateChannelConnectingToSelf(); - std::thread server(scl::details::SCL_AcceptConnections, std::ref(channels), - config); + std::thread server( + scl::details::SCL_AcceptConnections, std::ref(channels), config); for (std::size_t i = 0; i < static_cast(config.Id()); ++i) { const auto party = config.GetParty(i); diff --git a/include/scl/net/shared_deque.h b/include/scl/net/shared_deque.h index 862c371..55c62a6 100644 --- a/include/scl/net/shared_deque.h +++ b/include/scl/net/shared_deque.h @@ -44,7 +44,7 @@ class SharedDeque { /** * @brief Read the top element from the queue. */ - T &Peek(); + T& Peek(); /** * @brief Remove and return the top element from the queue. @@ -54,12 +54,12 @@ class SharedDeque { /** * @brief Insert an item to the back of the queue. */ - void PushBack(const T &item); + void PushBack(const T& item); /** * @brief Move an item to the back of the queue. */ - void PushBack(T &&item); + void PushBack(T&& item); /** * @brief Number of elements currently in the queue. @@ -82,7 +82,7 @@ void SharedDeque::PopFront() { } template -T &SharedDeque::Peek() { +T& SharedDeque::Peek() { std::unique_lock lock(mMutex); while (mDeck.empty()) { mCond.wait(lock); @@ -102,7 +102,7 @@ T SharedDeque::Pop() { } template -void SharedDeque::PushBack(const T &item) { +void SharedDeque::PushBack(const T& item) { std::unique_lock lock(mMutex); mDeck.push_back(item); lock.unlock(); @@ -110,7 +110,7 @@ void SharedDeque::PushBack(const T &item) { } template -void SharedDeque::PushBack(T &&item) { +void SharedDeque::PushBack(T&& item) { std::unique_lock lock(mMutex); mDeck.push_back(std::move(item)); lock.unlock(); diff --git a/include/scl/net/tcp_channel.h b/include/scl/net/tcp_channel.h index f5af2fe..05e2d24 100644 --- a/include/scl/net/tcp_channel.h +++ b/include/scl/net/tcp_channel.h @@ -44,12 +44,16 @@ class TcpChannel final : public Channel { /** * @brief Destroying a TCP channel closes the connection. */ - ~TcpChannel() { Close(); }; + ~TcpChannel() { + Close(); + }; /** * @brief Tells whether this channel is alive or not. */ - bool Alive() const { return mAlive; }; + bool Alive() const { + return mAlive; + }; void Send(const unsigned char* src, std::size_t n) override; std::size_t Recv(unsigned char* dst, std::size_t n) override; diff --git a/include/scl/net/tcp_utils.h b/include/scl/net/tcp_utils.h index fd75e70..3bf92e3 100644 --- a/include/scl/net/tcp_utils.h +++ b/include/scl/net/tcp_utils.h @@ -21,11 +21,11 @@ #ifndef SCL_NET_TCP_UTILS_H #define SCL_NET_TCP_UTILS_H -#include - #include #include +#include + namespace scl { namespace details { diff --git a/include/scl/net/threaded_sender.h b/include/scl/net/threaded_sender.h index f06e1b5..e9e6710 100644 --- a/include/scl/net/threaded_sender.h +++ b/include/scl/net/threaded_sender.h @@ -56,7 +56,9 @@ class ThreadedSenderChannel final : public Channel { return mChannel.Recv(dst, n); }; - bool HasData() override { return mChannel.HasData(); }; + bool HasData() override { + return mChannel.HasData(); + }; private: TcpChannel mChannel; diff --git a/include/scl/p/simple.h b/include/scl/p/simple.h index 04188de..9f372f2 100644 --- a/include/scl/p/simple.h +++ b/include/scl/p/simple.h @@ -33,12 +33,14 @@ struct ProtocolStep { /** * @brief Evaluate a step of the protocol. */ - auto Run(Ctx& context) { return static_cast(this)->Run(context); }; + auto Run(Ctx& context) { + return static_cast(this)->Run(context); + }; }; /** * @brief The final step of a protocol. - */ +o */ template struct LastProtocolStep { /** @@ -53,7 +55,8 @@ struct LastProtocolStep { * @brief Recursively evaluate all internal steps of a protocol. */ template < - typename S, typename Ctx, + typename S, + typename Ctx, std::enable_if_t, S>, bool> = true> auto Evaluate(S& step, Ctx& context) { auto next = step.Run(context); @@ -63,7 +66,8 @@ auto Evaluate(S& step, Ctx& context) { /** * @brief Evaluate the last step of a protocol. */ -template , S>, bool> = true> auto Evaluate(S& last, Ctx& context) { diff --git a/include/scl/primitives.h b/include/scl/primitives.h index 0235577..15bb05d 100644 --- a/include/scl/primitives.h +++ b/include/scl/primitives.h @@ -21,7 +21,7 @@ #ifndef SCL_PRIMITIVES_H #define SCL_PRIMITIVES_H -#include "scl/hash.h" -#include "scl/prg.h" +#include "scl/primitives/hash.h" +#include "scl/primitives/prg.h" #endif // SCL_PRIMITIVES_H diff --git a/include/scl/primitives/digest.h b/include/scl/primitives/digest.h new file mode 100644 index 0000000..eae103a --- /dev/null +++ b/include/scl/primitives/digest.h @@ -0,0 +1,62 @@ +/** + * @file digest.h + * + * SCL --- Secure Computation Library + * Copyright (C) 2022 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_PRIMITIVES_DIGEST_H +#define SCL_PRIMITIVES_DIGEST_H + +#include +#include +#include +#include + +#include "scl/util/str.h" + +namespace scl { +namespace details { + +/** + * @brief A digest of some bitsize. + * + * This type is effectively std::array. + * + * @tparam Bits the bitsize of the digest + */ +template +struct Digest { + /** + * @brief The actual type of a digest. + */ + using Type = std::array; +}; + +/** + * @brief Convert a digest to a string. + * @param digest the digest + * @return a hex representation of the digest. + */ +template +std::string DigestToString(const D& digest) { + return ToHexString(digest.begin(), digest.end()); +} + +} // namespace details +} // namespace scl + +#endif // SCL_PRIMITIVES_DIGEST_H diff --git a/include/scl/primitives/hash.h b/include/scl/primitives/hash.h new file mode 100644 index 0000000..86278ee --- /dev/null +++ b/include/scl/primitives/hash.h @@ -0,0 +1,41 @@ +/** + * @file hash.h + * + * SCL --- Secure Computation Library + * Copyright (C) 2022 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_PRIMITIVES_HASH_H +#define SCL_PRIMITIVES_HASH_H + +#include + +#include "scl/primitives/sha3.h" + +namespace scl { + +/** + * @brief A default hash function given a digest size. + * + * This type defults to one of the three instantiations of SHA3 that SCL + * provides. + */ +template +using Hash = details::Sha3; + +} // namespace scl + +#endif // SCL_PRIMITIVES_HASH_H diff --git a/include/scl/primitives/iuf_hash.h b/include/scl/primitives/iuf_hash.h new file mode 100644 index 0000000..ee57821 --- /dev/null +++ b/include/scl/primitives/iuf_hash.h @@ -0,0 +1,88 @@ +/** + * @file iuf_hash.h + * + * SCL --- Secure Computation Library + * Copyright (C) 2022 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_PRIMITIVES_IUF_HASH_H +#define SCL_PRIMITIVES_IUF_HASH_H + +#include +#include +#include +#include +#include + +namespace scl { +namespace details { + +/** + * @brief IUF interface for hash functions. + * + *

HashInterface defines an IUF (Initialize-Update-Finalize) style interface + * for a hash function.

+ * + * @tparam H interfaceementation. + */ +template +struct IUFHash { + /** + * @brief Update the hash function with a set of bytes. + * + * @param bytes a pointer to a number of bytes. + * @param n the number of bytes. + * @return the updated Hash object. + */ + IUFHash& Update(const unsigned char* bytes, std::size_t n) { + static_cast(this)->Hash(bytes, n); + return *this; + }; + + /** + * @brief Update the hash function with the content from a byte vector. + * + * @param data a vector of bytes. + * @return the updated Hash object. + */ + IUFHash& Update(const std::vector& data) { + return Update(data.data(), data.size()); + }; + + /** + * @brief Update the hash function with the content from a byte STL array. + * + * @param data the array + * @return the updated Hash object. + */ + template + IUFHash& Update(const std::array& data) { + return Update(data.data(), N); + } + + /** + * @brief Finalize and return the digest. + */ + auto Finalize() { + auto digest = static_cast(this)->Write(); + return digest; + }; +}; + +} // namespace details +} // namespace scl + +#endif // SCL_PRIMITIVES_IUF_HASH_H diff --git a/include/scl/prg.h b/include/scl/primitives/prg.h similarity index 51% rename from include/scl/prg.h rename to include/scl/primitives/prg.h index 5699aa8..6fc94b5 100644 --- a/include/scl/prg.h +++ b/include/scl/primitives/prg.h @@ -18,15 +18,17 @@ * along with this program. If not, see . */ -#ifndef SCL_PRG_H -#define SCL_PRG_H - -#include +#ifndef SCL_PRIMITIVES_PRG_H +#define SCL_PRIMITIVES_PRG_H +#include #include #include #include +#include +#include + /** * @brief 64 bit nonce which is prepended to the counter in the PRG. */ @@ -61,103 +63,108 @@ namespace scl { * as a macro. It defaults to 0x0123456789ABCDEF. */ class PRG { - public: - /** - * @brief The size of an output block. - */ - static constexpr std::size_t BlockSize() { return sizeof(BlockType); }; + private: + using BlockType = __m128i; + static constexpr std::size_t kBlockSize = sizeof(BlockType); + public: /** - * @brief The size of the seed. + * @brief Size of the seed. */ - static constexpr std::size_t SeedSize() { return BlockSize(); }; + static constexpr std::size_t SeedSize() { + return kBlockSize; + }; /** - * @brief Construct a new PRG object with seed of 0. + * @brief Create a new PRG with seed 0. */ - PRG(); + static PRG Create() { + PRG prg(nullptr); + prg.Init(); + return prg; + }; /** - * @brief Construct a new PRG with a given seed. - * + * @brief Create a new PRG with a provided seed. * @param seed the seed. - * - * @pre This constructor reads seed_size() bytes from - * seed so the latter must point to that much allocated space. */ - PRG(const unsigned char *seed); + static PRG Create(const unsigned char* seed) { + PRG prg(seed); + prg.Init(); + return prg; + }; /** - * @brief Reset the PRG to its initial state. + * @brief Reset the PRG. + * + * This method allows resetting a PRG object to its initial state. */ void Reset(); /** - * @brief Generate random data and store it in a supplied location. - * - * @param dest the destination of the generated random bytes. - * @param nbytes how many bytes of random data to generate. - * - * @pre dest must point to nbytes of allocated - * space. - * - * @throws std::runtime_error if allocation of roughly nbytes - * bytes of memory fails. + * @brief Generate random data and store it in a supplied buffer. + * @param buffer the buffer + * @param n how many bytes of random data to generate */ - void Next(unsigned char *dest, std::size_t nbytes); + void Next(unsigned char* buffer, std::size_t n); /** - * @brief Generate random data. - * @param dest the destination of the random data + * @brief Generate random data and store it in a supplied buffer. + * @param buffer the buffer with space pre-allocated + * + * How many bytes of random data to generate is decided based on the output of + * buffer.size(). */ - void Next(std::vector &dest) { - Next(dest.data(), dest.size()); + void Next(std::vector& buffer) { + Next(buffer.data(), buffer.size()); }; /** - * @brief Generate random data. - * @param dest the destination of the random data - * @param nbytes how many bytes to generate. + * @brief Generate random data and store in in a supplied buffer. + * @param buffer the buffer + * @param n how many random bytes to generate + * @throws std::invalid_argument if \p n is greater than + * buffer.size(). + * + * The capacity of \p buffer is not affected in any way by this method and it + * requires that it has room for at least \p n elements. */ - void Next(std::vector &dest, std::size_t nbytes) { - if (dest.size() < nbytes) { + void Next(std::vector& buffer, std::size_t n) { + if (buffer.size() < n) { throw std::invalid_argument("requested more randomness than dest.size()"); } - Next(dest.data(), nbytes); + Next(buffer.data(), n); }; /** - * @brief Generate and return some random data. - * @param nbytes the number of random bytes to generate + * @brief Generate and return random data. + * @param n the number of random bytes to generate * @return the random bytes. */ - std::vector Next(std::size_t nbytes) { - auto buffer = std::make_unique(nbytes); - Next(buffer.get(), nbytes); - return std::vector(buffer.get(), buffer.get() + nbytes); + std::vector Next(std::size_t n) { + auto buffer = std::make_unique(n); + Next(buffer.get(), n); + return std::vector(buffer.get(), buffer.get() + n); }; /** - * @brief The seed of the PRG. + * @brief The seed. */ - const unsigned char *Seed() const { return mSeed; }; - - /** - * @brief The current counter of the PRG. - */ - long Counter() const { return mCounter; }; + std::array Seed() const { + return mSeed; + }; private: - using BlockType = __m128i; - - void Update(void); - void Init(void); + PRG(const unsigned char* seed); - unsigned char mSeed[sizeof(BlockType)] = {0}; + std::array mSeed = {0}; long mCounter = PRG_INITIAL_COUNTER; BlockType mState[11]; + + void Update(); + void Init(); }; } // namespace scl -#endif // SCL_PRG_H +#endif // SCL_PRIMITIVES_PRG_H diff --git a/include/scl/primitives/sha256.h b/include/scl/primitives/sha256.h new file mode 100644 index 0000000..04d87eb --- /dev/null +++ b/include/scl/primitives/sha256.h @@ -0,0 +1,80 @@ +/** + * @file sha256.h + * + * SCL --- Secure Computation Library + * Copyright (C) 2022 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_PRIMITIVES_SHA256_H +#define SCL_PRIMITIVES_SHA256_H + +#include +#include +#include +#include + +#include "scl/primitives/digest.h" +#include "scl/primitives/iuf_hash.h" + +namespace scl { +namespace details { + +/** + * @brief SHA256 hash function. + */ +class Sha256 final : public details::IUFHash { + public: + /** + * @brief The type of a SHA256 digest. + */ + using DigestType = typename details::Digest<256>::Type; + + /** + * @brief Update the hash function with a set of bytes. + * + * @param bytes a pointer to a number of bytes. + * @param nbytes the number of bytes. + * @return the updated Hash object. + */ + void Hash(const unsigned char* bytes, std::size_t nbytes); + + /** + * @brief Finalize and return the digest. + */ + DigestType Write(); + + private: + std::array mChunk; + std::uint32_t mChunkPos = 0; + std::size_t mTotalLen = 0; + std::array mState = {0x6a09e667, + 0xbb67ae85, + 0x3c6ef372, + 0xa54ff53a, + 0x510e527f, + 0x9b05688c, + 0x1f83d9ab, + 0x5be0cd19}; + + void Transform(); + void Pad(); + DigestType WriteDigest(); +}; + +} // namespace details +} // namespace scl + +#endif // SCL_PRIMITIVES_SHA256_H diff --git a/include/scl/hash.h b/include/scl/primitives/sha3.h similarity index 56% rename from include/scl/hash.h rename to include/scl/primitives/sha3.h index 9b25dee..5d7d7f9 100644 --- a/include/scl/hash.h +++ b/include/scl/primitives/sha3.h @@ -1,5 +1,5 @@ /** - * @file hash.h + * @file sha3.h * * SCL --- Secure Computation Library * Copyright (C) 2022 Anders Dalskov @@ -18,75 +18,47 @@ * along with this program. If not, see . */ -#ifndef SCL_HASH_H -#define SCL_HASH_H +#ifndef SCL_PRIMITIVES_SHA3_H +#define SCL_PRIMITIVES_SHA3_H #include +#include #include -#include -#include #include +#include "scl/primitives/digest.h" +#include "scl/primitives/iuf_hash.h" + namespace scl { +namespace details { /** - * @brief A hash function. - * - *

Hash defines an IUF (Initialize-Update-Finalize) style interface for a - * hash function. The current implementation is based on SHA3 and supports - * digest sizes of either 256, 384 or 512 bits.

- * - * @code - * // define a hash function object with 256-bit output. - * using Hash = scl::Hash<256>; - * - * unsigned char data[] = {'d', 'a', 't', 'a'}; - * Hash hash; - * hash.Update(data, 4); - * auto digest = hash.Finalize(); - * @endcode - * + * @brief SHA3 hash function. * @tparam DigestSize the output size in bits. Must be either 256, 384 or 512 */ template -class Hash { +class Sha3 final : public details::IUFHash> { static_assert(DigestSize == 256 || DigestSize == 384 || DigestSize == 512, - "B must be one of 256, 384 or 512"); + "Invalid SHA3 digest size. Must be 256, 384 or 512"); public: /** - * @brief The type of the final digest. - */ - using DigestType = std::array; - - /** - * @brief Initialize the hash function. + * @brief The type of a SHA3 digest. */ - Hash(){}; + using DigestType = typename details::Digest::Type; /** * @brief Update the hash function with a set of bytes. - * - * @param[in] bytes a pointer to a number of bytes. - * @param[in] nbytes the number of bytes. - * @return the updated Hash object. - */ - Hash &Update(const unsigned char *bytes, std::size_t nbytes); - - /** - * @brief Update the hash function with the content from a byte vector. - * - * @param bytes a vector of bytes. + * @param bytes a pointer to a number of bytes. + * @param nbytes the number of bytes. * @return the updated Hash object. */ - Hash &Update(const std::vector &bytes) { - return Update(bytes.data(), bytes.size()); - }; + void Hash(const unsigned char* bytes, std::size_t nbytes); /** * @brief Finalize and return the digest. */ - DigestType Finalize(); + DigestType Write(); private: static const std::size_t kStateSize = 25; @@ -100,40 +72,22 @@ class Hash { unsigned int mWordIndex = 0; }; -static const uint64_t keccakf_rndc[24] = { - 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, - 0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL, - 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL, - 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, - 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, - 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL, - 0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL, - 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL}; - -static const unsigned int keccakf_rotc[24] = {1, 3, 6, 10, 15, 21, 28, 36, - 45, 55, 2, 14, 27, 41, 56, 8, - 25, 43, 62, 18, 39, 61, 20, 44}; - -static const unsigned int keccakf_piln[24] = {10, 7, 11, 17, 18, 3, 5, 16, - 8, 21, 24, 4, 15, 23, 19, 13, - 12, 2, 20, 14, 22, 9, 6, 1}; - /** * @brief Keccak function. * @param state the current state */ void Keccakf(uint64_t state[25]); -template -Hash &Hash::Update(const unsigned char *bytes, std::size_t nbytes) { +template +void Sha3::Hash(const unsigned char* bytes, std::size_t nbytes) { unsigned int old_tail = (8 - mByteIndex) & 7; - const unsigned char *p = bytes; + const unsigned char* p = bytes; if (nbytes < old_tail) { while (nbytes-- > 0) { mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8); } - return *this; + return; } if (old_tail != 0) { @@ -174,12 +128,10 @@ Hash &Hash::Update(const unsigned char *bytes, std::size_t nbytes) { while (tail-- > 0) { mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8); } - - return *this; } -template -auto Hash::Finalize() -> DigestType { +template +auto Sha3::Write() -> Sha3::DigestType { uint64_t t = (uint64_t)(((uint64_t)(0x02 | (1 << 2))) << ((mByteIndex)*8)); mState[mWordIndex] ^= mSaved ^ t; mState[kCuttoff - 1] ^= 0x8000000000000000ULL; @@ -207,21 +159,7 @@ auto Hash::Finalize() -> DigestType { return digest; } -/** - * @brief Convert a digest to a string. - * @param digest the digest - * @return a hex representation of the digest. - */ -template -std::string DigestToString(const D &digest) { - std::stringstream ss; - ss << std::setw(2) << std::setfill('0') << std::hex; - for (const auto &c : digest) { - ss << (int)c; - } - return ss.str(); -} - +} // namespace details } // namespace scl -#endif // SCL_HASH_H +#endif // SCL_PRIMITIVES_SHA3_H diff --git a/include/scl/ss/additive.h b/include/scl/ss/additive.h index 9d935af..0c598fb 100644 --- a/include/scl/ss/additive.h +++ b/include/scl/ss/additive.h @@ -24,7 +24,7 @@ #include #include "scl/math/vec.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" namespace scl { diff --git a/include/scl/ss/feldman.h b/include/scl/ss/feldman.h new file mode 100644 index 0000000..a3e731f --- /dev/null +++ b/include/scl/ss/feldman.h @@ -0,0 +1,159 @@ +/** + * @file feldman.h + * + * SCL --- Secure Computation Library + * Copyright (C) 2022 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_SS_FELDMAN_H +#define SCL_SS_FELDMAN_H + +#include +#include +#include +#include + +#include "scl/math/vec.h" +#include "scl/primitives/prg.h" +#include "scl/ss/shamir.h" + +namespace scl { + +/** + * @brief Class for working with Feldman secret-shares. + * + *

This class allows creation, verification and reconstruction of + * secret-shares from Feldman's verifiable secret-sharing scheme over a suitable + * elliptic curve.

+ * + *

The factory is instantiated with a type \p V which should be \ref EC + * for some curve definition. In particular, V::Order is assumed to + * be an \ref FF type.

+ */ +template +class FeldmanSSFactory { + /** + * @brief The type of secret-shares. + */ + using T = typename V::Order; + + public: + /** + * @brief Create a new FeldmanSSFactory object. + * @param threshold the privacy threshold + * @param prg a prg to use for generating randomness + */ + static FeldmanSSFactory Create(std::size_t threshold, PRG& prg) { + return FeldmanSSFactory( + threshold, + ShamirSSFactory::Create(threshold, prg, SecurityLevel::PASSIVE)); + }; + + /** + * @brief The result of calling Share + */ + struct ShareBundle { + /** + * @brief The secret shares. + */ + Vec shares; + /** + * @brief Commitments of the shares plus the secret. + */ + Vec commitments; + }; + + /** + * @brief Secret share a value. + * @param secret the secret to share + * @param number_of_shares the number of shares to generate + * @return shares and commitments of shares. + */ + ShareBundle Share(const T& secret, std::size_t number_of_shares) const { + auto shares = mBase.Share(secret, number_of_shares); + return {shares, ComputeCommitments(shares)}; + }; + + /** + * @brief Check that a share is consistent with a set of commitments. + */ + bool Verify(const T& share, const Vec& commitments, int party_index) const; + + /** + * @brief Check that a secret is consistent with a set of commitments. + */ + bool Verify(const T& secret, const Vec& commitments) const { + return Verify(secret, commitments, -1); + }; + + /** + * @brief Recover the value corresponding to a particular index + * @param shares the shares to recover + * @param index the index. Defaults to 0 + */ + T Recover(const Vec& shares, int index = 0) const { + return mBase.Recover(shares, index); + }; + + /** + * @brief Recover a share of a particular party + * @param shares the shares + * @param party_index the index of a the party whose share to recover + */ + T RecoverShare(const Vec& shares, int party_index = 0) const { + return mBase.RecoverShare(shares, party_index); + }; + + private: + FeldmanSSFactory(std::size_t threshold, ShamirSSFactory base) + : mThreshold(threshold), mBase(base){}; + + std::size_t mThreshold; + // mutable because calling Recover on this might trigger computation of new + // lagrange coefficients. + mutable ShamirSSFactory mBase; + + Vec ComputeCommitments(const Vec& shares) const; +}; + +template +bool FeldmanSSFactory::Verify(const T& share, + const Vec& commitments, + int party_index) const { + if (commitments.Size() < mThreshold + 1) { + throw std::invalid_argument("insufficient commitments for verification"); + } + // coefficients are indexed one-off. + auto& coeff = mBase.GetLagrangeCoefficients(party_index + 1); + auto v = details::UncheckedInnerProd( + coeff.begin(), coeff.end(), commitments.begin()); + return v == V::Generator() * share; +} + +template +Vec FeldmanSSFactory::ComputeCommitments(const Vec& shares) const { + std::vector c; + c.reserve(mThreshold + 1); + auto gen = V::Generator(); + for (std::size_t i = 0; i < mThreshold + 1; ++i) { + c.emplace_back(shares[i] * gen); + } + return Vec{c}; +} + +} // namespace scl + +#endif // SCL_SS_FELDMAN_H diff --git a/include/scl/ss/poly.h b/include/scl/ss/poly.h index 04340a9..679eb60 100644 --- a/include/scl/ss/poly.h +++ b/include/scl/ss/poly.h @@ -70,12 +70,16 @@ class Polynomial { /** * @brief Access coefficients, with the constant term at position 0. */ - T& operator[](std::size_t idx) { return mCoefficients[idx]; }; + T& operator[](std::size_t idx) { + return mCoefficients[idx]; + }; /** * @brief Access coefficients, with the constant term at position 0. */ - T operator[](std::size_t idx) const { return mCoefficients[idx]; }; + T operator[](std::size_t idx) const { + return mCoefficients[idx]; + }; /** * @brief Add two polynomials. @@ -101,22 +105,30 @@ class Polynomial { /** * @brief Returns true if this is the 0 polynomial. */ - bool IsZero() const { return Degree() == 0 && ConstantTerm() == T(0); }; + bool IsZero() const { + return Degree() == 0 && ConstantTerm() == T(0); + }; /** * @brief Get the constant term of this polynomial. */ - T ConstantTerm() const { return operator[](0); }; + T ConstantTerm() const { + return operator[](0); + }; /** * @brief Get the leading term of this polynomial. */ - T LeadingTerm() const { return operator[](Degree()); }; + T LeadingTerm() const { + return operator[](Degree()); + }; /** * @brief Degree of this polynomial. */ - std::size_t Degree() const { return mCoefficients.Size() - 1; }; + std::size_t Degree() const { + return mCoefficients.Size() - 1; + }; /** * @brief Get a string representation of this polynomial. @@ -135,7 +147,9 @@ class Polynomial { * @brief Get a string representation of this polynomial. * @note Equivalent to ToString("f", "x"). */ - std::string ToString() const { return ToString("f", "x"); }; + std::string ToString() const { + return ToString("f", "x"); + }; /** * @brief Write a string representation of this polynomial to a stream. diff --git a/include/scl/ss/shamir.h b/include/scl/ss/shamir.h index b38a4c2..dc77378 100644 --- a/include/scl/ss/shamir.h +++ b/include/scl/ss/shamir.h @@ -30,7 +30,7 @@ #include "scl/math/la.h" #include "scl/math/vec.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" #include "scl/ss/poly.h" namespace scl { @@ -55,7 +55,8 @@ namespace details { * @return a pair of polynomials. */ template -auto ReconstructShamirRobust(const Vec& shares, const Vec& alphas, +auto ReconstructShamirRobust(const Vec& shares, + const Vec& alphas, std::size_t t) { std::size_t n = 3 * t + 1; if (n > shares.Size()) { @@ -123,115 +124,148 @@ T ReconstructShamirRobust(const Vec& shares, std::size_t t) { ReconstructShamirRobust(shares, Vec::Range(1, shares.Size() + 1), t); return std::get<0>(p).Evaluate(T(0)); } +} // namespace details /** - * @brief Defines some common security levels. + * @brief Defines some commonly used security levels for Shamir secret-sharing. * - * These security levels are used to derive some common defaults. For example, - * SecurityLevel::DETECT is used to indicate that ShamirSSFactory, when - * instantiated with a threshold of \f$t\f$, should generate \f$2t + 1\f$ shares - * if no other value is specified. + * The SecurityLevel enum describes different threat models that Shamir + * secret-sharing is typically used with. Each value should be interpreted + * relative to an implicit threshold \f$t\f$. */ enum class SecurityLevel { /** - * @brief \f$t + 1\f$. + * @brief Passive security. + * + * This level describes a passive adversary threat model. When this level is + * used, only \f$t+1\f$ shares are created, and secrets are recovered with + * standard Lagrange interpolation without any form of checks. */ PASSIVE, /** - * @brief \f$2t + 1\f$. Enough shares to detect errors. + * @brief Active security with error detection. + * + * This level describes an active adversary threat model. When this level is + * used, \f$n = 2t + 1\f$ shares are created, and secrets are recovered by + * using the first \f$t + 1\f$ shares to recover the remaining \f$n - t\f$ + * shares, which are checked for equality. If this check fails, then + * reconstruction could not proceed. Otherwise the secret is recovered as in + * the passive case (in which case the secret is guaranteed to be the right + * one). */ DETECT, /** - * @brief \f$3t + 1\f$. Enough shares to correct errors. + * @brief Active security with error correction. + * + * This level describes an active adversary threat model. When this level is + * used, \f$3t+1\f$ shares are generated. Recovery proceeds via. the + * Berelkamp-Welch error correction algorithm and guarantees that the right + * secret is recovered, provided the input shares contain at most \f$t\f$ + * errors. */ CORRECT }; /** - * @brief Class for reconstructing secrets from Shamir secret-shares. + * @brief Class for working with Shamir secret-shares. * - *

The main purpose of this class is to cache the lagrange coefficients used - * when performing interpolation. These coefficients are computed when they are - * first needed. The only exception is the coefficients needed to reconstruct - * the point at 0, which is usually where the secret is placed.

+ *

This class can be used to create Shamir secret-shares of provided secrets, + * and to recover secrets from Shamir secret-shares. On creation, this class + * takes a threshold \f$t\f$ and a security level. The security level is used to + * determine (1) how many shares to create when secret-sharing a secret, and (2) + * how to recover a secret from a set of secret-shares. For example, if + * SecurityLevel::DETECT is passed as the security level and \f$t = 3\f$, then + * the factory object will create \f$2t + 1 = 7\f$ shares when secret-sharing + * something, and reconstruction uses an algorithm which detects errors in the + * provided shares (i.e., reconstruction in this case, if it succeeds, + * guarantees that the recovered secret is the right one).

* - *

A Reconstructor contains essentially two methods for reconstructing: - *

    - *
  • Reconstruct reconstructs a particular evaluation of the implied - * polynomial.
  • - *
  • ReconstructShare reconstructs a particular share of a party
  • - *
- * Both methods do effectively the same, and the only difference is how the - * index passed to them is interpreted. For the first, the index is interpreted - * as-is. I.e., if we pass an index of 3, then it will recover the point - * \f$f(3)\f$, where \f$f\f$ is the polynomial implied by the shares we provide. - * For the latter method, passing an index of \f$i\f$ implies that we wish to - * interpolate the point \f$f(i + 1)\f$. This is because we assume parties are - * indexed from 0, but their evaluation indices are indexed from 1 (because the - * 0'th index is reserved for the secret).

+ * @tparam T the share type. + * @see SecurityLevel */ template -class Reconstructor { +class ShamirSSFactory { public: /** - * @brief Create a new Reconstructor instance. - * - * This creator method creates a new instance and pre-computes the - * coefficients needed to recover a secret from a set of shamir shares. The - * SecurityLevel given is used to infer the default reconstruction method: - *
    - *
  • SecurityLevel::PASSIVE: use passive reconstruction. I.e., reconstruct - * the value we ask for, using the minimal amount of \f$t + 1\f$ shares - * needed.
  • - * - *
  • SecurityLevel::DETECT: use the first \f$t+1\f$ shares to reconstruct - * the remaning \f$t\f$ shares. If all of these check out, then reconstruct - * the value we seek.
  • - * - *
  • SecurityLevel::CORRECT: use an error correcting - * algorithm to recover the polynomial \f$f\f$ from \f$3t+1\f$ shares, - * assuming that at most \f$t\f$ shares are faulty.
  • - *
- * + * @brief Create a new object for working with Shamir secret-shares. * @param threshold the threshold \f$t\f$ + * @param prg the prg * @param default_security_level the default security level */ - static Reconstructor Create(std::size_t threshold, - SecurityLevel default_security_level); + static ShamirSSFactory Create(std::size_t threshold, + PRG& prg, + SecurityLevel default_security_level); /** - * @brief Interpolate a set of shares according to a given security level + * @brief Create a Shamir secret-sharing of a secret. + * @param secret the secret + * @param number_of_shares the number of shares to generate. If empty, then + * the number of shares is determined by GetDefaultNumberOfShares + * @return a vector of shares. + */ + Vec Share(const T& secret, + std::optional number_of_shares = {}); + + /** + * @brief Recover a secret from a vector of secret-shares. + * + * This method technically recovers a particular coefficient in the original + * polynomial used to generate the input shares. By default, the coefficient + * to recover is the constant term, which is where the secret lies. However, + * this method can also be used to recover the share of another party + * (although, \ref RecoverShare has a clearer syntax). + * * @param shares the shares - * @param security_level the security level + * @param security_level the security level to use * @param index the index to interpolate. Defaults to 0 + * @return the recovered secret. + * @throws std::invalid_argument if not enough shares is given for the + * provided security level. + * @throws std::logic_error if reconstruction fails (depends on the security + * level). */ - T Reconstruct(const Vec& shares, SecurityLevel security_level, - int index = 0) const; + T Recover(const Vec& shares, + SecurityLevel security_level, + int index = 0) const; /** - * @brief Interpolate a set of shares with the default security level + * @brief Recover a secret from a vector of secret-shares. + * + * This works like \ref Recover with the security level set to the one + * passed during creation of the factory. + * * @param shares the shares * @param index the index to interpolate. Defaults to 0 */ - T Reconstruct(const Vec& shares, int index = 0) const { - return Reconstruct(shares, mSecurityLevel, index); + T Recover(const Vec& shares, int index = 0) const { + return Recover(shares, mSecurityLevel, index); }; /** - * @brief Reconstruct the share of a party. + * @brief Recover the share of a particular party. + * + * Because parties' indices are counted from 1 when creating the + * secret-shares, this method is simply a shorthand for + * Reconstruct(shares, level, party_index + 1). + * + * @param shares the shares + * @param level the security level to use + * @param party_index the index of a party + * @return the recovered secret. */ - T ReconstructShare(const Vec& shares, SecurityLevel level, - int party_index = 0) const { - return Reconstruct(shares, level, party_index + 1); + T RecoverShare(const Vec& shares, + SecurityLevel level, + int party_index = 0) const { + return Recover(shares, level, party_index + 1); }; /** - * @brief Reconstruct the share of a party. + * @brief Recover the share of a party. */ - T ReconstructShare(const Vec& shares, int party_index = 0) const { - return Reconstruct(shares, party_index + 1); + T RecoverShare(const Vec& shares, int party_index = 0) const { + return RecoverShare(shares, mSecurityLevel, party_index); }; /** @@ -247,28 +281,48 @@ class Reconstructor { return mLagrangeCoeff[index]; }; + /** + * @brief Get the number of shares to create based on set security level + * + * The returned number is determined based on the \ref scl::SecurityLevel + * provided during construction. Given a threshold \f$t\f$, then the number of + * shares is computed based on the following rules: + */ + std::size_t GetDefaultNumberOfShares() const { + switch (mSecurityLevel) { + case SecurityLevel::PASSIVE: + return mThreshold + 1; + case SecurityLevel::DETECT: + return 2 * mThreshold + 1; + default: // SecurityLevel::ROBUST: + return 3 * mThreshold + 1; + } + }; + private: - Reconstructor(std::size_t threshold, SecurityLevel security_level) - : mThreshold(threshold), mSecurityLevel(security_level){}; + ShamirSSFactory(std::size_t threshold, PRG& prg, SecurityLevel security_level) + : mThreshold(threshold), mPrg(prg), mSecurityLevel(security_level){}; void ComputeLagrangeCoefficients(int index) const; std::size_t mThreshold; + PRG mPrg; SecurityLevel mSecurityLevel; mutable std::unordered_map> mLagrangeCoeff; }; template -Reconstructor Reconstructor::Create( - std::size_t threshold, SecurityLevel default_security_level) { - Reconstructor intr(threshold, default_security_level); - intr.ComputeLagrangeCoefficients(0); - return intr; +ShamirSSFactory ShamirSSFactory::Create( + std::size_t threshold, PRG& prg, SecurityLevel default_security_level) { + ShamirSSFactory ssf(threshold, prg, default_security_level); + ssf.ComputeLagrangeCoefficients(0); + return ssf; } template -T Reconstructor::Reconstruct(const Vec& shares, - SecurityLevel security_level, int index) const { +T ShamirSSFactory::Recover(const Vec& shares, + SecurityLevel security_level, + int index) const { if (security_level == SecurityLevel::CORRECT) { // TODO(anders): Currently only supports index = 0 return details::ReconstructShamirRobust(shares, mThreshold); @@ -302,7 +356,7 @@ T Reconstructor::Reconstruct(const Vec& shares, } template -void Reconstructor::ComputeLagrangeCoefficients(int index) const { +void ShamirSSFactory::ComputeLagrangeCoefficients(int index) const { Vec coeff(mThreshold + 1); const auto x = T(index); for (std::size_t j = 0; j <= mThreshold; ++j) { @@ -319,73 +373,31 @@ void Reconstructor::ComputeLagrangeCoefficients(int index) const { mLagrangeCoeff[index] = coeff; } +namespace details { + /** - * @brief A factory object for creating Shamir secret shares. - * @tparam T the finite field to use. + * @brief Create a polynomial for Shamir SS. + * @param secret the secret + * @param prg a PRG to use for generating the coefficients + * @param degree the degree of the polynomial + * @return a degree \p degree polynomial. */ template -class ShamirSSFactory { - public: - /** - * @brief Create a new Shamir secret share factory - * @param t the privacy threshold to use - * @param prg the PRG to use when generating new shares - * @param security_level the SecurityLevel to use. Default to - * SecurityLevel::DETECT - */ - ShamirSSFactory(std::size_t t, PRG& prg, - SecurityLevel security_level = SecurityLevel::DETECT) - : mThreshold(t), mPrg(prg), mDefaultSecurityLevel(security_level){}; - - /** - * @brief Create a new set of Shamir secret shares of a secret - * @param secret the secret - * @param number_of_shares the number of shares to generate. If empty, then - * the number of shares is determined by GetDefaultNumberOfShares - * @return a vector of shares. - */ - Vec Share(const T& secret, - std::optional number_of_shares = {}); - - /** - * @brief Get the number of shares to create based on set security level - * - * The returned number is determined based on the \ref SecurityLevel - * provided during construction. Given a threshold \f$t\f$, then the number of - * shares is computed based on the following rules: - */ - std::size_t GetDefaultNumberOfShares() const { - switch (mDefaultSecurityLevel) { - case SecurityLevel::PASSIVE: - return mThreshold + 1; - case SecurityLevel::DETECT: - return 2 * mThreshold + 1; - default: // SecurityLevel::ROBUST: - return 3 * mThreshold + 1; - } - }; - - /** - * @brief Get an scl::Interpolator suitable for this factory. - */ - Reconstructor GetInterpolator() const { - return Reconstructor::Create(mThreshold, mDefaultSecurityLevel); - }; +Polynomial CreateSharePolynomial(const T& secret, + PRG& prg, + std::size_t degree) { + auto coeff = Vec::PartialRandom( + degree + 1, [](std::size_t i) { return i > 0; }, prg); + coeff[0] = secret; + return Polynomial::Create(coeff); +} - private: - std::size_t mThreshold; - PRG mPrg; - SecurityLevel mDefaultSecurityLevel; -}; +} // namespace details template Vec ShamirSSFactory::Share(const T& secret, std::optional number_of_shares) { - auto coeff = Vec::PartialRandom( - mThreshold + 1, [](std::size_t i) { return i > 0; }, mPrg); - coeff[0] = secret; - auto p = details::Polynomial::Create(coeff); - + auto p = details::CreateSharePolynomial(secret, mPrg, mThreshold); auto n = number_of_shares.value_or(GetDefaultNumberOfShares()); std::vector shares; shares.reserve(n); @@ -395,7 +407,6 @@ Vec ShamirSSFactory::Share(const T& secret, return Vec(shares); } -} // namespace details } // namespace scl #endif // SCL_SS_SHAMIR_H diff --git a/include/scl/math/str.h b/include/scl/util/str.h similarity index 84% rename from include/scl/math/str.h rename to include/scl/util/str.h index f337729..2e9a73e 100644 --- a/include/scl/math/str.h +++ b/include/scl/util/str.h @@ -18,8 +18,8 @@ * along with this program. If not, see . */ -#ifndef SCL_MATH_STR_H -#define SCL_MATH_STR_H +#ifndef SCL_UTIL_STR_H +#define SCL_UTIL_STR_H #include #include @@ -77,6 +77,22 @@ std::string ToHexString(const T& v) { return ss.str(); } +/** + * @brief Convert a list of bytes to a string. + * @param begin the start of an iterator + * @param end the end of an iterator + * @return a hex representation of the digest. + */ +template +std::string ToHexString(It begin, It end) { + std::stringstream ss; + ss << std::setfill('0') << std::hex; + while (begin != end) { + ss << std::setw(2) << static_cast(*begin++); + } + return ss.str(); +} + /** * @brief ToHexString specialization for __uint128_t. */ @@ -86,4 +102,4 @@ std::string ToHexString(const __uint128_t& v); } // namespace details } // namespace scl -#endif // SCL_MATH_STR_H +#endif // SCL_UTIL_STR_H diff --git a/scripts/check_formatting.sh b/scripts/check_formatting.sh index 0a9aed7..0647985 100755 --- a/scripts/check_formatting.sh +++ b/scripts/check_formatting.sh @@ -1,3 +1,3 @@ #!/usr/bin/bash -find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n --style=Google {} \; +find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n {} \; diff --git a/src/scl/math/mersenne127.cc b/src/scl/math/mersenne127.cc index d024a5a..5cacb14 100644 --- a/src/scl/math/mersenne127.cc +++ b/src/scl/math/mersenne127.cc @@ -25,7 +25,7 @@ #include "./ops_small_fp.h" #include "scl/math/ff_ops.h" -#include "scl/math/str.h" +#include "scl/util/str.h" using u64 = std::uint64_t; using u128 = __uint128_t; @@ -111,6 +111,12 @@ void scl::details::FieldFromBytes(u128& dest, dest = dest % p; } +template <> +void scl::details::FieldToBytes(unsigned char* dest, + const u128& src) { + std::memcpy(dest, &src, sizeof(u128)); +} + template <> std::string scl::details::FieldToString(const u128& in) { return ToHexString(in); diff --git a/src/scl/math/mersenne61.cc b/src/scl/math/mersenne61.cc index 2866028..f916e0f 100644 --- a/src/scl/math/mersenne61.cc +++ b/src/scl/math/mersenne61.cc @@ -26,7 +26,7 @@ #include "./ops_small_fp.h" #include "scl/math/ff_ops.h" -#include "scl/math/str.h" +#include "scl/util/str.h" using u64 = std::uint64_t; using u128 = __uint128_t; @@ -84,6 +84,12 @@ void scl::details::FieldFromBytes(u64& dest, dest = dest % p; } +template <> +void scl::details::FieldToBytes(unsigned char* dest, + const u64& src) { + std::memcpy(dest, &src, sizeof(u64)); +} + template <> std::string scl::details::FieldToString(const u64& in) { return ToHexString(in); diff --git a/src/scl/math/number.cc b/src/scl/math/number.cc index 0134768..148b110 100644 --- a/src/scl/math/number.cc +++ b/src/scl/math/number.cc @@ -25,7 +25,9 @@ #include #include -scl::Number::Number() { mpz_init(mValue); } +scl::Number::Number() { + mpz_init(mValue); +} scl::Number::Number(const Number& number) : Number() { mpz_set(mValue, number.mValue); @@ -35,7 +37,9 @@ scl::Number::Number(Number&& number) noexcept : Number() { mpz_set(mValue, number.mValue); } -scl::Number::~Number() { mpz_clear(mValue); } +scl::Number::~Number() { + mpz_clear(mValue); +} scl::Number scl::Number::Random(std::size_t bits, PRG& prg) { auto len = (bits - 1) / 8 + 2; @@ -59,7 +63,9 @@ scl::Number scl::Number::FromString(const std::string& str) { return num; } // LCOV_EXCL_LINE -scl::Number::Number(int value) : Number() { mpz_set_si(mValue, value); } +scl::Number::Number(int value) : Number() { + mpz_set_si(mValue, value); +} scl::Number scl::Number::operator+(const Number& number) const { scl::Number sum; @@ -139,7 +145,9 @@ int scl::Number::Compare(const Number& number) const { return mpz_cmp(mValue, number.mValue); } -std::size_t scl::Number::BitSize() const { return mpz_sizeinbase(mValue, 2); } +std::size_t scl::Number::BitSize() const { + return mpz_sizeinbase(mValue, 2); +} bool scl::Number::TestBit(std::size_t index) const { return mpz_tstbit(mValue, index); diff --git a/src/scl/math/ops_gmp_ff.cc b/src/scl/math/ops_gmp_ff.cc index 8cc3d18..7ef0aba 100644 --- a/src/scl/math/ops_gmp_ff.cc +++ b/src/scl/math/ops_gmp_ff.cc @@ -20,16 +20,7 @@ #include "./ops_gmp_ff.h" -void scl::details::ReadLimb(mp_limb_t &lmb, const unsigned char *bytes, - std::size_t bits_per_limbs) { - std::size_t c = 0; - lmb = 0; - for (std::size_t i = 0; i < bits_per_limbs; i += 8) { - lmb |= static_cast(bytes[c++]) << i; - } -} - -std::size_t scl::details::FindFirstNonZero(const std::string &s) { +std::size_t scl::details::FindFirstNonZero(const std::string& s) { int n = 0; for (const auto c : s) { if (c != '0') { diff --git a/src/scl/math/ops_gmp_ff.h b/src/scl/math/ops_gmp_ff.h index 704d14a..d712041 100644 --- a/src/scl/math/ops_gmp_ff.h +++ b/src/scl/math/ops_gmp_ff.h @@ -21,14 +21,15 @@ #ifndef SCL_MATH_OPS_GMP_FF_H #define SCL_MATH_OPS_GMP_FF_H -#include - #include #include +#include #include #include -#include "scl/math/str.h" +#include + +#include "scl/util/str.h" namespace scl { namespace details { @@ -138,7 +139,9 @@ void ModSub(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod) { */ template void ModNeg(mp_limb_t* out, const mp_limb_t* mod) { - mpn_sub_n(out, mod, out, N); + mp_limb_t t[N] = {0}; + ModSub(t, out, mod); + SCL_COPY(out, t, N); } /** @@ -150,7 +153,9 @@ void ModNeg(mp_limb_t* out, const mp_limb_t* mod) { * @see Redc */ template -void ModMul(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod, +void ModMul(mp_limb_t* out, + const mp_limb_t* op, + const mp_limb_t* mod, const mp_limb_t* np) { mp_limb_t res[2 * N]; mpn_mul_n(res, out, op, N); @@ -166,7 +171,9 @@ void ModMul(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod, * @param np a constant used for montgomery reduction */ template -void ModSqr(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod, +void ModSqr(mp_limb_t* out, + const mp_limb_t* op, + const mp_limb_t* mod, const mp_limb_t* np) { mp_limb_t res[2 * N]; mpn_sqr(res, op, N); @@ -196,8 +203,11 @@ inline bool TestBit(const mp_limb_t* v, std::size_t pos) { * @param np a constant used for montgomery reduction */ template -void ModExp(mp_limb_t* out, const mp_limb_t* x, const mp_limb_t* e, - const mp_limb_t* mod, const mp_limb_t* np) { +void ModExp(mp_limb_t* out, + const mp_limb_t* x, + const mp_limb_t* e, + const mp_limb_t* mod, + const mp_limb_t* np) { auto n = mpn_sizeinbase(e, N, 2); for (std::size_t i = n; i-- > 0;) { ModSqr(out, out, mod, np); @@ -208,8 +218,11 @@ void ModExp(mp_limb_t* out, const mp_limb_t* x, const mp_limb_t* e, } template -void ModInvFermat(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod, - const mp_limb_t* mod_minus_2, const mp_limb_t* np) { +void ModInvFermat(mp_limb_t* out, + const mp_limb_t* op, + const mp_limb_t* mod, + const mp_limb_t* mod_minus_2, + const mp_limb_t* np) { if (mpn_zero_p(op, N)) { throw std::invalid_argument("0 not invertible modulo prime"); } @@ -226,16 +239,34 @@ int CompareValues(const mp_limb_t* lhs, const mp_limb_t* rhs) { return mpn_cmp(lhs, rhs, N); } -void ReadLimb(mp_limb_t& lmb, const unsigned char* bytes, - std::size_t bits_per_limb); +template +void ValueFromBytes(mp_limb_t* out, + const unsigned char* src, + const mp_limb_t* mod) { + for (int i = N - 1; i >= 0; --i) { + for (int j = BYTES_PER_LIMB - 1; j >= 0; --j) { + out[i] |= static_cast(*src++) << (j * 8); + } + } + + ToMonty(out, mod); +} -/** - * @brief Read a value from a byte array. - */ template -void ValueFromBytes(mp_limb_t* out, const unsigned char* src) { - for (std::size_t i = 0; i < N; ++i) { - ReadLimb(out[i], src + i * BYTES_PER_LIMB, BITS_PER_LIMB); +void ValueToBytes(unsigned char* dest, + const mp_limb_t* src, + const mp_limb_t* mod, + const mp_limb_t* np) { + mp_limb_t padded[2 * N] = {0}; + SCL_COPY(padded, src, N); + Redc(padded, mod, np); + + std::size_t c = 0; + for (int i = N - 1; i >= 0; --i) { + const auto v = padded[i]; + for (int j = BYTES_PER_LIMB - 1; j >= 0; --j) { + dest[c++] = v >> (j * 8); + } } } @@ -245,7 +276,8 @@ std::size_t FindFirstNonZero(const std::string& s); * @brief Print a value. */ template -std::string ToString(const mp_limb_t* val, const mp_limb_t* mod, +std::string ToString(const mp_limb_t* val, + const mp_limb_t* mod, const mp_limb_t* np) { mp_limb_t padded[2 * N] = {0}; SCL_COPY(padded, val, N); diff --git a/src/scl/math/secp256k1_curve.cc b/src/scl/math/secp256k1_curve.cc index afa5b9f..522ed28 100644 --- a/src/scl/math/secp256k1_curve.cc +++ b/src/scl/math/secp256k1_curve.cc @@ -19,6 +19,7 @@ */ #include +#include #include #include "./secp256k1_extras.h" @@ -57,7 +58,8 @@ bool Valid(const Field& x, const Field& y) { } // namespace template <> -void scl::details::CurveSetAffine(Point& out, const Field& x, +void scl::details::CurveSetAffine(Point& out, + const Field& x, const Field& y) { if (Valid(x, y)) { out = {x, y, Field(1)}; @@ -226,11 +228,16 @@ void scl::details::CurveScalarMultiply(Point& out, } } -#define COMPRESSED_FLAG 0x04 +// Flag indicating that the point was serialized as an (X, Y) pair +#define FULL_POINT_FLAG 0x04 +// Flag indicating that the serialized point was the point at infinity. If the +// point was also serialized as a FULL_POINT, then we write the pair (0, 0). #define POINT_AT_INFINITY_FLAG 0x02 +// Flag indicating which of Y or -Y to select in case we serialize the point in +// compressed form. #define SELECT_SMALLER_FLAG 0x01 -#define IS_COMPRESSED(flags) ((flags)&COMPRESSED_FLAG) +#define IS_FULL_POINT(flags) ((flags)&FULL_POINT_FLAG) #define IS_POINT_AT_INFINITY(flags) ((flags)&POINT_AT_INFINITY_FLAG) #define SELECT_SMALLER(flags) ((flags)&SELECT_SMALLER_FLAG) @@ -258,7 +265,11 @@ void scl::details::CurveFromBytes(Point& out, const unsigned char* src) { // send the point-at-infinity. CurveSetPointAtInfinity(out); } else { - if (IS_COMPRESSED(flags)) { + if (IS_FULL_POINT(flags)) { + out[0] = Field::Read(src + 1); + out[1] = Field::Read(src + 1 + Field::ByteSize()); + out[2] = Field::One(); + } else { Field x = Field::Read(src + 1); out[0] = x; @@ -274,27 +285,24 @@ void scl::details::CurveFromBytes(Point& out, const unsigned char* src) { } else { out[1] = select_smaller == 0 ? y : yn; } - } else { - out[0] = Field::Read(src + 1); - out[1] = Field::Read(src + 1 + Field::ByteSize()); - out[2] = Field::One(); } } } -#define MARK_COMPRESSED(buf) (*(buf) |= COMPRESSED_FLAG) +#define MARK_FULL_POINT(buf) (*(buf) |= FULL_POINT_FLAG) #define MARK_POINT_AT_INFINITY(buf) (*(buf) |= POINT_AT_INFINITY_FLAG) #define MARK_SELECT_SMALLER(buf) (*(buf) |= SELECT_SMALLER_FLAG) template <> -void scl::details::CurveToBytes(unsigned char* dest, const Point& in, +void scl::details::CurveToBytes(unsigned char* dest, + const Point& in, bool compress) { // Make sure flag byte is zeroed. *dest = 0; - // if point is compressed, mark it as such. - if (compress) { - MARK_COMPRESSED(dest); + // if point is un-compressed, mark it as such. + if (!compress) { + MARK_FULL_POINT(dest); } if (CurveIsPointAtInfinity(in)) { diff --git a/src/scl/math/secp256k1_field.cc b/src/scl/math/secp256k1_field.cc index c78ddcd..5d3323c 100644 --- a/src/scl/math/secp256k1_field.cc +++ b/src/scl/math/secp256k1_field.cc @@ -101,12 +101,12 @@ bool scl::details::FieldEqual(const Elem& in1, const Elem& in2) { template <> void scl::details::FieldFromBytes(Elem& dest, const unsigned char* src) { - ValueFromBytes(PTR(dest), src); + ValueFromBytes(PTR(dest), src, kPrime); } template <> -std::string scl::details::FieldToString(const Elem& in) { - return ToString(PTR(in), kPrime, kMontyN); +void scl::details::FieldToBytes(unsigned char* dest, const Elem& src) { + ValueToBytes(dest, PTR(src), kPrime, kMontyN); } template <> @@ -115,6 +115,11 @@ void scl::details::FieldFromString(Elem& out, const std::string& src) { FromString(PTR(out), kPrime, src); } +template <> +std::string scl::details::FieldToString(const Elem& in) { + return ToString(PTR(in), kPrime, kMontyN); +} + bool scl::SCL_FF_Extras::IsSmaller( const scl::FF& lhs, const scl::FF& rhs) { diff --git a/src/scl/math/secp256k1_order.cc b/src/scl/math/secp256k1_order.cc index 8a78a93..c5abbac 100644 --- a/src/scl/math/secp256k1_order.cc +++ b/src/scl/math/secp256k1_order.cc @@ -18,11 +18,11 @@ * along with this program. If not, see . */ -#include - #include #include +#include + #include "./ops_gmp_ff.h" #include "./secp256k1_extras.h" #include "scl/math/curves/secp256k1.h" @@ -101,7 +101,12 @@ bool scl::details::FieldEqual(const Elem& in1, const Elem& in2) { template <> void scl::details::FieldFromBytes(Elem& dest, const unsigned char* src) { - ValueFromBytes(PTR(dest), src); + ValueFromBytes(PTR(dest), src, kPrime); +} + +template <> +void scl::details::FieldToBytes(unsigned char* dest, const Elem& src) { + ValueToBytes(dest, PTR(src), kPrime, kMontyN); } template <> diff --git a/src/scl/net/config.cc b/src/scl/net/config.cc index 356752e..36e8364 100644 --- a/src/scl/net/config.cc +++ b/src/scl/net/config.cc @@ -73,7 +73,8 @@ scl::NetworkConfig scl::NetworkConfig::Load(int id, return NetworkConfig(id, info); } -scl::NetworkConfig scl::NetworkConfig::Localhost(int id, int size, +scl::NetworkConfig scl::NetworkConfig::Localhost(int id, + int size, int port_base) { ValidateIdAndSize(id, size); diff --git a/src/scl/net/mem_channel.cc b/src/scl/net/mem_channel.cc index 72f186c..e2caebf 100644 --- a/src/scl/net/mem_channel.cc +++ b/src/scl/net/mem_channel.cc @@ -57,7 +57,8 @@ std::size_t scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) { const auto old_size = mOverflow.size(); mOverflow.reserve(old_size + leftovers); mOverflow.insert(mOverflow.begin() + DIFF_T(old_size), - data.begin() + DIFF_T(to_copy), data.end()); + data.begin() + DIFF_T(to_copy), + data.end()); } } diff --git a/src/scl/net/tcp_utils.cc b/src/scl/net/tcp_utils.cc index bec9b62..50a1f55 100644 --- a/src/scl/net/tcp_utils.cc +++ b/src/scl/net/tcp_utils.cc @@ -20,10 +20,6 @@ #include "scl/net/tcp_utils.h" -#include -#include -#include - #include #include #include @@ -32,6 +28,10 @@ #include #include +#include +#include +#include + int scl::details::CreateServerSocket(int port, int backlog) { int err; int ssock = ::socket(AF_INET, SOCK_STREAM, 0); @@ -117,14 +117,18 @@ int scl::details::ConnectAsClient(const std::string& hostname, int port) { return sock; } -int scl::details::CloseSocket(int socket) { return ::close(socket); } +int scl::details::CloseSocket(int socket) { + return ::close(socket); +} -ssize_t scl::details::ReadFromSocket(int socket, unsigned char* dst, +ssize_t scl::details::ReadFromSocket(int socket, + unsigned char* dst, std::size_t n) { return ::read(socket, dst, n); } -ssize_t scl::details::WriteToSocket(int socket, const unsigned char* src, +ssize_t scl::details::WriteToSocket(int socket, + const unsigned char* src, std::size_t n) { return ::write(socket, src, n); } diff --git a/src/scl/prg.cc b/src/scl/primitives/prg.cc similarity index 51% rename from src/scl/prg.cc rename to src/scl/primitives/prg.cc index d08150f..2e4bb19 100644 --- a/src/scl/prg.cc +++ b/src/scl/primitives/prg.cc @@ -18,17 +18,19 @@ * along with this program. If not, see . */ -#include "scl/prg.h" +#include "scl/primitives/prg.h" #include #include -/* https://github.com/sebastien-riou/aes-brute-force */ +#include -using byte_t = unsigned char; -using block_t = __m128i; +/** + * PRG implementation based on AES-CTR with code from + * https://github.com/sebastien-riou/aes-brute-force + */ -using std::size_t; +#define BLOCK_SIZE sizeof(__m128i) #define DO_ENC_BLOCK(m, k) \ do { \ @@ -45,10 +47,12 @@ using std::size_t; (m) = _mm_aesenclast_si128(m, (k)[10]); \ } while (0) -#define AES_128_key_exp(k, rcon) \ - aes_128_key_expansion(k, _mm_aeskeygenassist_si128(k, rcon)) +#define AES_128_KEY_EXP(k, rcon) \ + Aes128KeyExpansion(k, _mm_aeskeygenassist_si128(k, rcon)) -inline static block_t aes_128_key_expansion(block_t key, block_t keygened) { +namespace { + +auto Aes128KeyExpansion(__m128i key, __m128i keygened) { keygened = _mm_shuffle_epi32(keygened, _MM_SHUFFLE(3, 3, 3, 3)); key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); @@ -56,74 +60,72 @@ inline static block_t aes_128_key_expansion(block_t key, block_t keygened) { return _mm_xor_si128(key, keygened); } -inline static void aes128_load_key(byte_t* enc_key, block_t* key_schedule) { - key_schedule[0] = _mm_loadu_si128((const block_t*)enc_key); - key_schedule[1] = AES_128_key_exp(key_schedule[0], 0x01); - key_schedule[2] = AES_128_key_exp(key_schedule[1], 0x02); - key_schedule[3] = AES_128_key_exp(key_schedule[2], 0x04); - key_schedule[4] = AES_128_key_exp(key_schedule[3], 0x08); - key_schedule[5] = AES_128_key_exp(key_schedule[4], 0x10); - key_schedule[6] = AES_128_key_exp(key_schedule[5], 0x20); - key_schedule[7] = AES_128_key_exp(key_schedule[6], 0x40); - key_schedule[8] = AES_128_key_exp(key_schedule[7], 0x80); - key_schedule[9] = AES_128_key_exp(key_schedule[8], 0x1B); - key_schedule[10] = AES_128_key_exp(key_schedule[9], 0x36); +void Aes128LoadKey(const unsigned char* enc_key, __m128i* key_schedule) { + const auto* k = reinterpret_cast(enc_key); + key_schedule[0] = _mm_loadu_si128(k); + key_schedule[1] = AES_128_KEY_EXP(key_schedule[0], 0x01); + key_schedule[2] = AES_128_KEY_EXP(key_schedule[1], 0x02); + key_schedule[3] = AES_128_KEY_EXP(key_schedule[2], 0x04); + key_schedule[4] = AES_128_KEY_EXP(key_schedule[3], 0x08); + key_schedule[5] = AES_128_KEY_EXP(key_schedule[4], 0x10); + key_schedule[6] = AES_128_KEY_EXP(key_schedule[5], 0x20); + key_schedule[7] = AES_128_KEY_EXP(key_schedule[6], 0x40); + key_schedule[8] = AES_128_KEY_EXP(key_schedule[7], 0x80); + key_schedule[9] = AES_128_KEY_EXP(key_schedule[8], 0x1B); + key_schedule[10] = AES_128_KEY_EXP(key_schedule[9], 0x36); } -inline static void aes128_enc(block_t* key_schedule, byte_t* pt, byte_t* ct) { - block_t m = _mm_loadu_si128((block_t*)pt); +void Aes128Enc(const __m128i* key_schedule, __m128i m, unsigned char* ct) { DO_ENC_BLOCK(m, key_schedule); - _mm_storeu_si128((block_t*)ct, m); + _mm_storeu_si128(reinterpret_cast<__m128i*>(ct), m); +} + +auto create_mask(long counter) { + return _mm_set_epi64x(PRG_NONCE, counter); } -scl::PRG::PRG() { Init(); } +} // namespace scl::PRG::PRG(const unsigned char* seed) { - memcpy(mSeed, seed, SeedSize()); + if (seed != nullptr) { + std::copy(seed, seed + SeedSize(), mSeed.begin()); + } Init(); } -void scl::PRG::Update() { mCounter += 1; } +void scl::PRG::Update() { + mCounter += 1; +} -void scl::PRG::Init() { aes128_load_key(mSeed, mState); } +void scl::PRG::Init() { + Aes128LoadKey(mSeed.data(), mState); +} void scl::PRG::Reset() { Init(); mCounter = PRG_INITIAL_COUNTER; } -static inline auto create_mask(const long counter) { - return _mm_set_epi64x(PRG_NONCE, counter); -} - -void scl::PRG::Next(byte_t* dest, size_t nbytes) { - if (nbytes == 0) { +void scl::PRG::Next(unsigned char* buffer, size_t n) { + if (n == 0) { return; } - size_t nblocks = nbytes / BlockSize(); + auto nblocks = n / BLOCK_SIZE; - if ((nbytes % BlockSize()) != 0) { + if ((n % BLOCK_SIZE) != 0) { nblocks++; } - block_t mask = create_mask(mCounter); - byte_t* out = (byte_t*)malloc(nblocks * BlockSize()); - byte_t* p = out; - - // LCOV_EXCL_START - if (out == nullptr) { - throw std::runtime_error("Could not allocate memory for PRG."); - } - // LCOV_EXCL_STOP - + auto mask = create_mask(mCounter); + auto out = std::make_unique(nblocks * BLOCK_SIZE); + auto* p = out.get(); for (size_t i = 0; i < nblocks; i++) { - aes128_enc(mState, (byte_t*)(&mask), p); + Aes128Enc(mState, mask, p); Update(); mask = create_mask(mCounter); - p += BlockSize(); + p += BLOCK_SIZE; } - memcpy(dest, out, nbytes); - free(out); + std::copy(out.get(), out.get() + n, buffer); } diff --git a/src/scl/primitives/sha256.cc b/src/scl/primitives/sha256.cc new file mode 100644 index 0000000..3f3c57a --- /dev/null +++ b/src/scl/primitives/sha256.cc @@ -0,0 +1,171 @@ +/** + * @file sha256.cc + * + * SCL --- Secure Computation Library + * Copyright (C) 2022 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include "scl/primitives/sha256.h" + +#include +#include + +/** + * SHA-256 implementation based on https://github.com/System-Glitch/SHA256. + */ + +namespace { + +auto RotR(uint32_t x, unsigned n) { + return (x >> n) | (x << (32 - n)); +} + +auto Sig0(uint32_t x) { + return RotR(x, 7) ^ RotR(x, 18) ^ (x >> 3); +} + +auto Sig1(uint32_t x) { + return RotR(x, 17) ^ RotR(x, 19) ^ (x >> 10); +} + +auto Split(std::array& chunk) { + std::array split; + for (std::size_t i = 0, j = 0; i < 16; ++i, j += 4) { + split[i] = (chunk[j] << 24) // + | (chunk[j + 1] << 16) // + | (chunk[j + 2] << 8) // + | (chunk[j + 3]); + } + + for (std::size_t i = 16; i < 64; ++i) { + split[i] = Sig1(split[i - 2]) + Sig0(split[i - 15]); + split[i] += split[i - 7] + split[i - 16]; + } + + return split; +} + +auto Majority(uint32_t x, uint32_t y, uint32_t z) { + return (x & (y | z)) | (y & z); +} + +auto Choose(uint32_t x, uint32_t y, uint32_t z) { + return (x & y) ^ (~x & z); +} + +} // namespace + +void scl::details::Sha256::Transform() { + // round constants. + static constexpr std::array k = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, + 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, + 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, + 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, + 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2}; + + const auto m = Split(mChunk); + auto s = mState; + + for (std::size_t i = 0; i < 64; ++i) { + const auto maj = Majority(s[0], s[1], s[2]); + const auto chs = Choose(s[4], s[5], s[6]); + + const auto xor_a = RotR(s[0], 2) ^ RotR(s[0], 13) ^ RotR(s[0], 22); + const auto xor_e = RotR(s[4], 6) ^ RotR(s[4], 11) ^ RotR(s[4], 25); + + const auto sum = m[i] + k[i] + s[7] + chs + xor_e; + + const auto new_a = xor_a + maj + sum; + const auto new_e = s[3] + sum; + + s[7] = s[6]; + s[6] = s[5]; + s[5] = s[4]; + s[4] = new_e; + s[3] = s[2]; + s[2] = s[1]; + s[1] = s[0]; + s[0] = new_a; + } + + for (std::size_t i = 0; i < 8; ++i) { + mState[i] += s[i]; + } +} + +void scl::details::Sha256::Pad() { + auto i = mChunkPos; + const auto end = mChunkPos < 56U ? 56U : 64U; + + mChunk[i++] = 0x80; + while (i < end) { + mChunk[i++] = 0; + } + + if (mChunkPos >= 56) { + Transform(); + std::fill(mChunk.begin(), mChunk.begin() + 56, 0); + } + + mTotalLen += static_cast(mChunkPos) * 8; + + mChunk[63] = mTotalLen; + mChunk[62] = mTotalLen >> 8; + mChunk[61] = mTotalLen >> 16; + mChunk[60] = mTotalLen >> 24; + mChunk[59] = mTotalLen >> 32; + mChunk[58] = mTotalLen >> 40; + mChunk[57] = mTotalLen >> 48; + mChunk[56] = mTotalLen >> 56; + + Transform(); +} + +scl::details::Sha256::DigestType scl::details::Sha256::WriteDigest() { + scl::details::Sha256::DigestType digest; + + for (std::size_t i = 0; i < 4; ++i) { + for (std::size_t j = 0; j < 8; ++j) { + digest[i + (j * 4)] = (mState[j] >> (24 - i * 8)) & 0xFF; + } + } + + return digest; +} + +void scl::details::Sha256::Hash(const unsigned char* bytes, + std::size_t nbytes) { + for (std::size_t i = 0; i < nbytes; ++i) { + mChunk[mChunkPos++] = bytes[i]; + if (mChunkPos == 64) { + Transform(); + mTotalLen += 512; + mChunkPos = 0; + } + } +} + +scl::details::Sha256::DigestType scl::details::Sha256::Write() { + Pad(); + return WriteDigest(); +} diff --git a/src/scl/hash.cc b/src/scl/primitives/sha3.cc similarity index 55% rename from src/scl/hash.cc rename to src/scl/primitives/sha3.cc index 6c34d9b..d6bbd4b 100644 --- a/src/scl/hash.cc +++ b/src/scl/primitives/sha3.cc @@ -1,5 +1,5 @@ /** - * @file hash.cc + * @file sha3.cc * * SCL --- Secure Computation Library * Copyright (C) 2022 Anders Dalskov @@ -18,13 +18,35 @@ * along with this program. If not, see . */ -#include "scl/hash.h" +#include "scl/primitives/sha3.h" -static inline uint64_t rotl64(uint64_t x, uint64_t y) { +namespace { + +const uint64_t keccakf_rndc[24] = { + 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, + 0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL, + 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL, + 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, + 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, + 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL, + 0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL, + 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL}; + +const unsigned int keccakf_rotc[24] = {1, 3, 6, 10, 15, 21, 28, 36, + 45, 55, 2, 14, 27, 41, 56, 8, + 25, 43, 62, 18, 39, 61, 20, 44}; + +const unsigned int keccakf_piln[24] = {10, 7, 11, 17, 18, 3, 5, 16, + 8, 21, 24, 4, 15, 23, 19, 13, + 12, 2, 20, 14, 22, 9, 6, 1}; + +uint64_t RotLeft64(uint64_t x, uint64_t y) { return (x << y) | (x >> ((sizeof(uint64_t) * 8) - y)); } -void scl::Keccakf(uint64_t state[25]) { +} // namespace + +void scl::details::Keccakf(uint64_t state[25]) { uint64_t t; uint64_t bc[5]; @@ -35,7 +57,7 @@ void scl::Keccakf(uint64_t state[25]) { } for (std::size_t i = 0; i < 5; ++i) { - t = bc[(i + 4) % 5] ^ rotl64(bc[(i + 1) % 5], 1); + t = bc[(i + 4) % 5] ^ RotLeft64(bc[(i + 1) % 5], 1); for (std::size_t j = 0; j < 25; j += 5) { state[j + i] ^= t; } @@ -45,7 +67,7 @@ void scl::Keccakf(uint64_t state[25]) { for (std::size_t i = 0; i < 24; ++i) { const uint64_t v = keccakf_piln[i]; bc[0] = state[v]; - state[v] = rotl64(t, keccakf_rotc[i]); + state[v] = RotLeft64(t, keccakf_rotc[i]); t = bc[0]; } diff --git a/src/scl/math/str.cc b/src/scl/util/str.cc similarity index 97% rename from src/scl/math/str.cc rename to src/scl/util/str.cc index 32bdb74..cf7a5e3 100644 --- a/src/scl/math/str.cc +++ b/src/scl/util/str.cc @@ -18,7 +18,7 @@ * along with this program. If not, see . */ -#include "scl/math/str.h" +#include "scl/util/str.h" #include diff --git a/test/scl/gf7.cc b/test/scl/gf7.cc index 7148f3d..9cd1d7d 100644 --- a/test/scl/gf7.cc +++ b/test/scl/gf7.cc @@ -22,7 +22,7 @@ #include "scl/math/ff_ops.h" -using GF7 = scl::details::GF7; +using GF7 = scl_tests::GaloisField7; template <> void scl::details::FieldConvertIn(unsigned char& out, int v) { @@ -95,6 +95,12 @@ void scl::details::FieldFromBytes(unsigned char& dest, dest = dest % 7; } +template <> +void scl::details::FieldToBytes(unsigned char* dest, + const unsigned char& src) { + *dest = src; +} + template <> std::string scl::details::FieldToString(const unsigned char& in) { std::stringstream ss; diff --git a/test/scl/gf7.h b/test/scl/gf7.h index ec9e9c9..a4e2a1b 100644 --- a/test/scl/gf7.h +++ b/test/scl/gf7.h @@ -24,17 +24,15 @@ #include #include -namespace scl { -namespace details { +namespace scl_tests { -struct GF7 { +struct GaloisField7 { using ValueType = unsigned char; constexpr static const char* kName = "GF(7)"; constexpr static const std::size_t kByteSize = 1; constexpr static const std::size_t kBitSize = 8; }; -} // namespace details -} // namespace scl +} // namespace scl_tests #endif /* TEST_SCL_GF7_H */ diff --git a/test/scl/math/fields.h b/test/scl/math/fields.h new file mode 100644 index 0000000..e0277e9 --- /dev/null +++ b/test/scl/math/fields.h @@ -0,0 +1,43 @@ +/** + * @file fields.h + * + * SCL --- Secure Computation Library + * Copyright (C) 2022 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include "../gf7.h" +#include "scl/math/curves/secp256k1.h" +#include "scl/math/fp.h" + +namespace scl_tests { + +using Mersenne61 = scl::Fp<61>; +using Mersenne127 = scl::Fp<127>; +using GF7 = scl::FF; + +#ifdef SCL_ENABLE_EC_TESTS +using Secp256k1_Field = scl::FF; +using Secp256k1_Order = scl::FF; +#endif + +#ifdef SCL_ENABLE_EC_TESTS +#define FIELD_DEFS \ + Mersenne61, Mersenne127, GF7, Secp256k1_Field, Secp256k1_Order +#else +#define FIELD_DEFS Mersenne61, Mersenne127, GF7 +#endif + +} // namespace scl_tests diff --git a/test/scl/math/test_ff.cc b/test/scl/math/test_ff.cc index 2f22384..f684cc0 100644 --- a/test/scl/math/test_ff.cc +++ b/test/scl/math/test_ff.cc @@ -20,19 +20,12 @@ #include -#include "../gf7.h" -#include "scl/math/curves/secp256k1.h" -#include "scl/math/fp.h" -#include "scl/prg.h" +#include "fields.h" +#include "scl/primitives/prg.h" -using Mersenne61 = scl::Fp<61>; -using Mersenne127 = scl::Fp<127>; -using GF7 = scl::FF; +using namespace scl_tests; -#ifdef SCL_ENABLE_EC_TESTS -using Secp256k1_Field = scl::FF; -using Secp256k1_Order = scl::FF; -#endif +namespace { template T RandomNonZero(scl::PRG& prg) { @@ -60,16 +53,12 @@ GF7 RandomNonZero(scl::PRG& prg) { return a; } -#define REPEAT for (std::size_t i = 0; i < 50; ++i) +} // namespace -#ifdef SCL_ENABLE_EC_TESTS -#define ARG_LIST Mersenne61, Mersenne127, GF7, Secp256k1_Field, Secp256k1_Order -#else -#define ARG_LIST Mersenne61, Mersenne127, GF7 -#endif +#define REPEAT for (std::size_t i = 0; i < 50; ++i) -TEMPLATE_TEST_CASE("FF", "[math]", ARG_LIST) { - scl::PRG prg; +TEMPLATE_TEST_CASE("FF", "[math]", FIELD_DEFS) { + auto prg = scl::PRG::Create(); auto zero = TestType(); SECTION("random") { @@ -106,6 +95,8 @@ TEMPLATE_TEST_CASE("FF", "[math]", ARG_LIST) { REQUIRE(a == a_negated); REQUIRE(a - zero == a); } + + REQUIRE(TestType(0) == -TestType(0)); } SECTION("subtraction") { @@ -136,7 +127,8 @@ TEMPLATE_TEST_CASE("FF", "[math]", ARG_LIST) { SECTION("inverses") { REQUIRE_THROWS_MATCHES( - zero.Inverse(), std::logic_error, + zero.Inverse(), + std::logic_error, Catch::Matchers::Message("0 not invertible modulo prime")); REPEAT { diff --git a/test/scl/math/test_la.cc b/test/scl/math/test_la.cc index 8234d09..f485c5b 100644 --- a/test/scl/math/test_la.cc +++ b/test/scl/math/test_la.cc @@ -25,7 +25,7 @@ #include "scl/math/la.h" #include "scl/math/mat.h" -using F = scl::FF; +using F = scl::FF; using Mat = scl::Mat; using Vec = scl::Vec; @@ -37,10 +37,8 @@ TEST_CASE("LinearAlgebra", "[math]") { // [1 0 1] // [0 1 0] // [0 0 0] - Mat A = Mat::FromVector(3, 3, - {one, zero, one, // - zero, one, zero, // - zero, zero, zero}); + Mat A = Mat::FromVector( + 3, 3, {one, zero, one, zero, one, zero, zero, zero, zero}); REQUIRE(scl::details::GetPivotInColumn(A, 2) == -1); REQUIRE(scl::details::GetPivotInColumn(A, 1) == 1); REQUIRE(scl::details::GetPivotInColumn(A, 0) == 0); @@ -52,27 +50,27 @@ TEST_CASE("LinearAlgebra", "[math]") { } SECTION("FindFirstNonZeroRow") { - Mat A = Mat::FromVector(3, 3, - {one, zero, one, // - zero, one, zero, // - zero, zero, zero}); + Mat A = Mat::FromVector( + 3, 3, {one, zero, one, zero, one, zero, zero, zero, zero}); REQUIRE(scl::details::FindFirstNonZeroRow(A) == 1); A(2, 1) = one; REQUIRE(scl::details::FindFirstNonZeroRow(A) == 2); } SECTION("ExtractSolution") { - Mat A = Mat::FromVector(3, 4, - {one, zero, zero, F(3), // - zero, one, zero, F(5), // - zero, zero, one, F(2)}); + Mat A = Mat::FromVector( + 3, + 4, + {one, zero, zero, F(3), zero, one, zero, F(5), zero, zero, one, F(2)}); auto x = scl::details::ExtractSolution(A); REQUIRE(x.Equals(Vec{F(3), F(5), F(2)})); + // clang-format off Mat B = Mat::FromVector(3, 4, - {F(1), F(3), F(1), F(2), // - F(0), F(0), F(1), F(4), // + {F(1), F(3), F(1), F(2), + F(0), F(0), F(1), F(4), F(0), F(0), F(0), F(0)}); + // clang-format on auto y = scl::details::ExtractSolution(B); REQUIRE(y.Equals(Vec{F(4), F(4), F(0)})); @@ -84,7 +82,7 @@ TEST_CASE("LinearAlgebra", "[math]") { SECTION("RandomSolve") { auto n = 10; - scl::PRG prg; + auto prg = scl::PRG::Create(); Mat A = Mat::Random(n, n, prg); Vec b = Vec::Random(n, prg); @@ -99,7 +97,8 @@ TEST_CASE("LinearAlgebra", "[math]") { Mat A(2, 2); Vec b(3); REQUIRE_THROWS_MATCHES( - scl::details::SolveLinearSystem(x, A, b), std::invalid_argument, + scl::details::SolveLinearSystem(x, A, b), + std::invalid_argument, Catch::Matchers::Message("malformed system of equations")); } @@ -113,7 +112,7 @@ TEST_CASE("LinearAlgebra", "[math]") { SECTION("Inverse") { std::size_t n = 10; - scl::PRG prg; + auto prg = scl::PRG::Create(); Mat A = Mat::Random(n, n, prg); Mat I = Mat::Identity(n); diff --git a/test/scl/math/test_mat.cc b/test/scl/math/test_mat.cc index 4410098..c2cb475 100644 --- a/test/scl/math/test_mat.cc +++ b/test/scl/math/test_mat.cc @@ -26,7 +26,9 @@ using F = scl::Fp<61>; using Mat = scl::Mat; -inline void Populate(Mat& m, const int* values) { +namespace { + +void Populate(Mat& m, const int* values) { for (std::size_t i = 0; i < m.Rows(); i++) { for (std::size_t j = 0; j < m.Cols(); j++) { m(i, j) = F(values[i * m.Cols() + j]); @@ -34,6 +36,8 @@ inline void Populate(Mat& m, const int* values) { } } +} // namespace + TEST_CASE("Matrix", "[math]") { Mat m0(2, 2); int v0[] = {1, 2, 5, 6}; @@ -79,9 +83,11 @@ TEST_CASE("Matrix", "[math]") { // matrices are 0 initialized, so the above matrices are equal REQUIRE(a.Equals(b)); - REQUIRE_THROWS_MATCHES(Mat(0, 1), std::invalid_argument, + REQUIRE_THROWS_MATCHES(Mat(0, 1), + std::invalid_argument, Catch::Matchers::Message("n or m cannot be 0")); - REQUIRE_THROWS_MATCHES(Mat(1, 0), std::invalid_argument, + REQUIRE_THROWS_MATCHES(Mat(1, 0), + std::invalid_argument, Catch::Matchers::Message("n or m cannot be 0")); } @@ -119,7 +125,8 @@ TEST_CASE("Matrix", "[math]") { SECTION("Incompatible") { Mat m2(3, 2); - REQUIRE_THROWS_MATCHES(m2.Add(m0), std::invalid_argument, + REQUIRE_THROWS_MATCHES(m2.Add(m0), + std::invalid_argument, Catch::Matchers::Message("incompatible matrices")); } @@ -147,7 +154,8 @@ TEST_CASE("Matrix", "[math]") { REQUIRE(m5.Equals(m4)); REQUIRE_THROWS_MATCHES( - m3.Multiply(m0), std::invalid_argument, + m3.Multiply(m0), + std::invalid_argument, Catch::Matchers::Message("invalid matrix dimensions for multiply")); } @@ -181,7 +189,7 @@ TEST_CASE("Matrix", "[math]") { } SECTION("Random") { - scl::PRG prg; + auto prg = scl::PRG::Create(); Mat mr = Mat::Random(4, 5, prg); REQUIRE(mr.Rows() == 4); REQUIRE(mr.Cols() == 5); @@ -194,7 +202,7 @@ TEST_CASE("Matrix", "[math]") { REQUIRE(not_zero); // check stability - scl::PRG prg1; + auto prg1 = scl::PRG::Create(); Mat mr1 = Mat::Random(4, 5, prg1); REQUIRE(mr1.Equals(mr)); } @@ -212,7 +220,7 @@ TEST_CASE("Matrix", "[math]") { } SECTION("resize") { - scl::PRG prg; + auto prg = scl::PRG::Create(); Mat m = Mat::Random(2, 4, prg); auto copy = m; copy.Resize(1, 8); @@ -225,12 +233,13 @@ TEST_CASE("Matrix", "[math]") { } } - REQUIRE_THROWS_MATCHES(m.Resize(42, 4), std::invalid_argument, + REQUIRE_THROWS_MATCHES(m.Resize(42, 4), + std::invalid_argument, Catch::Matchers::Message("cannot resize matrix")); } SECTION("IsSquare") { - scl::PRG prg; + auto prg = scl::PRG::Create(); Mat sq = Mat::Random(2, 2, prg); REQUIRE(sq.IsSquare()); Mat nsq = Mat::Random(4, 2, prg); @@ -279,7 +288,8 @@ TEST_CASE("Matrix", "[math]") { REQUIRE(m1(2, 2) == F(64)); xs.emplace_back(F(55)); - REQUIRE_THROWS_MATCHES(Mat::Vandermonde(3, 3, xs), std::invalid_argument, + REQUIRE_THROWS_MATCHES(Mat::Vandermonde(3, 3, xs), + std::invalid_argument, Catch::Matchers::Message("|xs| != number of rows")); } } diff --git a/test/scl/math/test_mersenne127.cc b/test/scl/math/test_mersenne127.cc index 5adc8d3..dd47263 100644 --- a/test/scl/math/test_mersenne127.cc +++ b/test/scl/math/test_mersenne127.cc @@ -19,6 +19,7 @@ */ #include +#include #include "scl/math/fp.h" @@ -31,7 +32,9 @@ TEST_CASE("Mersenne127", "[math]") { Field x(0x7b); Field big = Field::FromString("58797a14d0653d22a05c11c60e1aacf4"); - SECTION("Name") { REQUIRE(std::string(Field::Name()) == "Mersenne127"); } + SECTION("Name") { + REQUIRE(std::string(Field::Name()) == "Mersenne127"); + } SECTION("ToString") { REQUIRE(zero.ToString() == "0"); diff --git a/test/scl/math/test_mersenne61.cc b/test/scl/math/test_mersenne61.cc index 6f84d52..c827e09 100644 --- a/test/scl/math/test_mersenne61.cc +++ b/test/scl/math/test_mersenne61.cc @@ -19,6 +19,7 @@ */ #include +#include #include "scl/math/fp.h" @@ -30,7 +31,9 @@ TEST_CASE("Mersenne61", "[math]") { Field x(0x7b); Field big(0x41621e); - SECTION("Name") { REQUIRE(std::string(Field::Name()) == "Mersenne61"); } + SECTION("Name") { + REQUIRE(std::string(Field::Name()) == "Mersenne61"); + } SECTION("ToString") { REQUIRE(zero.ToString() == "0"); @@ -53,10 +56,12 @@ TEST_CASE("Mersenne61", "[math]") { } SECTION("FromString") { - REQUIRE_THROWS_MATCHES(Field::FromString("012"), std::invalid_argument, + REQUIRE_THROWS_MATCHES(Field::FromString("012"), + std::invalid_argument, Catch::Matchers::Message("odd-length hex string")); REQUIRE_THROWS_MATCHES( - Field::FromString("1g"), std::invalid_argument, + Field::FromString("1g"), + std::invalid_argument, Catch::Matchers::Message("encountered invalid hex character")); auto y = Field::FromString("7b"); REQUIRE(x == y); diff --git a/test/scl/math/test_number.cc b/test/scl/math/test_number.cc index 700ec2d..085f507 100644 --- a/test/scl/math/test_number.cc +++ b/test/scl/math/test_number.cc @@ -22,13 +22,13 @@ #include #include "scl/math/number.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" #define REPEAT_I(I) for (std::size_t i = 0; i < (I); ++i) #define REPEAT REPEAT_I(50) TEST_CASE("Number", "[math]") { - scl::PRG prg; + auto prg = scl::PRG::Create(); SECTION("Construction") { scl::Number n0(27); REQUIRE(n0.ToString() == "Number{1b}"); @@ -129,7 +129,9 @@ TEST_CASE("Number", "[math]") { } } - SECTION("Negation") { REQUIRE(-scl::Number(1234) == scl::Number(-1234)); } + SECTION("Negation") { + REQUIRE(-scl::Number(1234) == scl::Number(-1234)); + } SECTION("Multiplication") { scl::Number a(444); diff --git a/test/scl/math/test_secp256k1.cc b/test/scl/math/test_secp256k1.cc index d9053b5..9432085 100644 --- a/test/scl/math/test_secp256k1.cc +++ b/test/scl/math/test_secp256k1.cc @@ -26,13 +26,15 @@ #include "scl/math/ec_ops.h" #include "scl/math/fp.h" #include "scl/math/number.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" using Curve = scl::EC; using Field = Curve::Field; TEST_CASE("secp256k1_field", "[math]") { - SECTION("name") { REQUIRE(std::string(Field::Name()) == "secp256k1_field"); } + SECTION("name") { + REQUIRE(std::string(Field::Name()) == "secp256k1_field"); + } SECTION("Strings") { REQUIRE(Field(0).ToString() == "0"); @@ -76,7 +78,8 @@ TEST_CASE("secp256k1_field", "[math]") { REQUIRE(as_affine[1] == y); REQUIRE_THROWS_MATCHES( - Curve::FromAffine(Field(0), Field(0)), std::invalid_argument, + Curve::FromAffine(Field(0), Field(0)), + std::invalid_argument, Catch::Matchers::Message("provided (x, y) not on curve")); } @@ -126,7 +129,8 @@ TEST_CASE("secp256k1_field", "[math]") { SECTION("inversion") { REQUIRE_THROWS_MATCHES( - Field(0).Inverse(), std::invalid_argument, + Field(0).Inverse(), + std::invalid_argument, Catch::Matchers::Message("0 not invertible modulo prime")); Field one(1); REQUIRE(one * one.Inverse() == one); @@ -148,13 +152,18 @@ TEST_CASE("secp256k1_field", "[math]") { using Scalar = scl::FF; -static Curve RandomPoint(scl::PRG& prg) { +namespace { + +Curve RandomPoint(scl::PRG& prg) { auto r = scl::Number::Random(100, prg); return Curve::Generator() * r; } +} // namespace TEST_CASE("secp256k1", "[math]") { - SECTION("name") { REQUIRE(std::string(Curve::Name()) == "secp256k1"); } + SECTION("name") { + REQUIRE(std::string(Curve::Name()) == "secp256k1"); + } auto ord = scl::Number::FromString( "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"); @@ -186,7 +195,7 @@ TEST_CASE("secp256k1", "[math]") { REQUIRE(poi.ToString() == "EC{POINT_AT_INFINITY}"); } - scl::PRG prg; + auto prg = scl::PRG::Create(); SECTION("addition") { auto a = RandomPoint(prg); @@ -263,23 +272,23 @@ TEST_CASE("secp256k1", "[math]") { auto a = RandomPoint(prg); auto buffer = std::make_unique(Curve::ByteSize(false)); a.Write(buffer.get(), false); - REQUIRE(buffer[0] == 0); + REQUIRE(buffer[0] == 0x04); auto c = Curve::Read(buffer.get()); REQUIRE(a == c); a.Write(buffer.get(), true); auto d = Curve::Read(buffer.get()); - REQUIRE(buffer[0] == 0x05); + REQUIRE(buffer[0] == 0x01); REQUIRE(a == d); Curve poi; poi.Write(buffer.get(), false); - REQUIRE(buffer[0] == 0x02); + REQUIRE(buffer[0] == 0x06); auto e = Curve::Read(buffer.get()); REQUIRE(e.PointAtInfinity()); poi.Write(buffer.get(), true); - REQUIRE(buffer[0] == (0x02 | 0x04)); + REQUIRE(buffer[0] == 0x02); auto f = Curve::Read(buffer.get()); REQUIRE(f.PointAtInfinity()); @@ -290,7 +299,7 @@ TEST_CASE("secp256k1", "[math]") { "a2eee86009ff") // ); g.Write(buffer.get()); - REQUIRE(buffer[0] == (0x04 | 0x01)); + REQUIRE(buffer[0] == 0x01); auto h = Curve::Read(buffer.get()); REQUIRE(h == g); diff --git a/test/scl/math/test_vec.cc b/test/scl/math/test_vec.cc index a126e39..d3d1049 100644 --- a/test/scl/math/test_vec.cc +++ b/test/scl/math/test_vec.cc @@ -112,7 +112,8 @@ TEST_CASE("Vector", "[math]") { SECTION("Incompatible") { auto v2 = Vec{F(2), F(3)}; REQUIRE(!v2.Equals(v1)); - REQUIRE_THROWS_MATCHES(v2.Add(v1), std::invalid_argument, + REQUIRE_THROWS_MATCHES(v2.Add(v1), + std::invalid_argument, Catch::Matchers::Message("Vec sizes mismatch")); } @@ -122,7 +123,7 @@ TEST_CASE("Vector", "[math]") { } SECTION("Random") { - scl::PRG prg; + auto prg = scl::PRG::Create(); auto r = Vec::Random(3, prg); auto zero = F(); REQUIRE(r.Size() == 3); diff --git a/test/scl/math/test_z2k.cc b/test/scl/math/test_z2k.cc index 954d810..07bf345 100644 --- a/test/scl/math/test_z2k.cc +++ b/test/scl/math/test_z2k.cc @@ -22,14 +22,14 @@ #include #include "scl/math/z2k.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" // use sizes that ensure masking is needed using Z2k1 = scl::Z2k<62>; using Z2k2 = scl::Z2k<123>; TEMPLATE_TEST_CASE("Z2k", "[math]", Z2k1, Z2k2) { - scl::PRG prg; + auto prg = scl::PRG::Create(); auto zero = TestType(); REQUIRE(std::string("Z2k") == TestType::Name()); @@ -89,7 +89,8 @@ TEMPLATE_TEST_CASE("Z2k", "[math]", Z2k1, Z2k2) { auto a_inverse = a.Inverse(); REQUIRE(a * a_inverse == TestType(1)); REQUIRE_THROWS_MATCHES( - zero.Inverse(), std::logic_error, + zero.Inverse(), + std::logic_error, Catch::Matchers::Message("value not invertible modulo 2^K")); } diff --git a/test/scl/net/test_config.cc b/test/scl/net/test_config.cc index c79a45c..5070d2a 100644 --- a/test/scl/net/test_config.cc +++ b/test/scl/net/test_config.cc @@ -53,13 +53,15 @@ TEST_CASE("Config", "[network]") { const auto* invalid_entry = SCL_TEST_DATA_DIR "invalid_entry.txt"; REQUIRE_THROWS_MATCHES( - scl::NetworkConfig::Load(0, invalid_entry), std::invalid_argument, + scl::NetworkConfig::Load(0, invalid_entry), + std::invalid_argument, Catch::Matchers::Message("invalid entry in config file")); const auto* invalid_non_existing_file = ""; REQUIRE_THROWS_MATCHES( scl::NetworkConfig::Load(0, invalid_non_existing_file), - std::invalid_argument, Catch::Matchers::Message("could not open file")); + std::invalid_argument, + Catch::Matchers::Message("could not open file")); } SECTION("All local") { diff --git a/test/scl/net/test_discover.cc b/test/scl/net/test_discover.cc index 7d7d09c..4e039e6 100644 --- a/test/scl/net/test_discover.cc +++ b/test/scl/net/test_discover.cc @@ -30,7 +30,9 @@ namespace { -bool VerifyParty(scl::Party& party, int id, const std::string& hostname, +bool VerifyParty(scl::Party& party, + int id, + const std::string& hostname, int port) { return party.id == id && party.hostname == hostname && party.port == port; } @@ -105,7 +107,8 @@ TEST_CASE("Discovery Server", "[network]") { scl::DiscoveryServer::Ctx ctx{me, fake.my_network}; REQUIRE_THROWS_MATCHES( - prot.Run(ctx), std::logic_error, + prot.Run(ctx), + std::logic_error, Catch::Matchers::Message("received invalid party ID")); } @@ -195,11 +198,13 @@ TEST_CASE("Discovery", "[network]") { SECTION("Too many parties") { REQUIRE_THROWS_MATCHES( - Server(9999, 256), std::invalid_argument, + Server(9999, 256), + std::invalid_argument, Catch::Matchers::Message("number_of_parties exceeds max")); REQUIRE_THROWS_MATCHES( - Server(256), std::invalid_argument, + Server(256), + std::invalid_argument, Catch::Matchers::Message("number_of_parties exceeds max")); } diff --git a/test/scl/net/test_mem_channel.cc b/test/scl/net/test_mem_channel.cc index 2c01b89..7dbf19f 100644 --- a/test/scl/net/test_mem_channel.cc +++ b/test/scl/net/test_mem_channel.cc @@ -25,7 +25,7 @@ #include "scl/math.h" #include "scl/net/mem_channel.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" #include "util.h" void PrintBuf(const unsigned char* b, std::size_t n) { @@ -41,7 +41,7 @@ TEST_CASE("InMemoryChannel", "[network]") { auto chl0 = channels[0]; auto chl1 = channels[1]; - scl::PRG prg; + auto prg = scl::PRG::Create(); unsigned char data_in[200] = {0}; prg.Next(data_in, 200); diff --git a/test/scl/net/test_tcp_channel.cc b/test/scl/net/test_tcp_channel.cc index 0efa63a..9e73011 100644 --- a/test/scl/net/test_tcp_channel.cc +++ b/test/scl/net/test_tcp_channel.cc @@ -23,7 +23,7 @@ #include "scl/net/tcp_channel.h" #include "scl/net/tcp_utils.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" #include "util.h" TEST_CASE("TcpChannel", "[network]") { @@ -79,7 +79,7 @@ TEST_CASE("TcpChannel", "[network]") { clt.join(); srv.join(); - scl::PRG prg; + auto prg = scl::PRG::Create(); unsigned char send[200] = {0}; unsigned char recv[200] = {0}; prg.Next(send, 200); diff --git a/test/scl/net/test_threaded_sender.cc b/test/scl/net/test_threaded_sender.cc index c23b432..125cc2b 100644 --- a/test/scl/net/test_threaded_sender.cc +++ b/test/scl/net/test_threaded_sender.cc @@ -25,7 +25,7 @@ #include "scl/net/tcp_utils.h" #include "scl/net/threaded_sender.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" #include "util.h" TEST_CASE("ThreadedSender", "[network]") { @@ -50,7 +50,7 @@ TEST_CASE("ThreadedSender", "[network]") { clt.join(); srv.join(); - scl::PRG prg; + auto prg = scl::PRG::Create(); unsigned char send[200] = {0}; unsigned char recv[200] = {0}; prg.Next(send, 200); diff --git a/test/scl/net/util.cc b/test/scl/net/util.cc index 6b1bbb7..0bb4941 100644 --- a/test/scl/net/util.cc +++ b/test/scl/net/util.cc @@ -22,9 +22,12 @@ int test_port = SCL_DEFAULT_TEST_PORT; -int scl_tests::GetPort() { return test_port++; } +int scl_tests::GetPort() { + return test_port++; +} -bool scl_tests::BufferEquals(const unsigned char *a, const unsigned char *b, +bool scl_tests::BufferEquals(const unsigned char* a, + const unsigned char* b, int n) { while (n-- > 0 && *a++ == *b++) { ; diff --git a/test/scl/p/test_simple.cc b/test/scl/p/test_simple.cc index 663e3f4..1a0e705 100644 --- a/test/scl/p/test_simple.cc +++ b/test/scl/p/test_simple.cc @@ -87,8 +87,10 @@ class BeaverMul : public scl::ProtocolStep { FF y; }; -static inline std::vector RandomTriple() { - scl::PRG prg; +namespace { + +std::vector RandomTriple() { + auto prg = scl::PRG::Create(); auto a = FF::Random(prg); auto b = FF::Random(prg); auto c = a * b; @@ -100,8 +102,10 @@ static inline std::vector RandomTriple() { return std::vector{{as[0], bs[0], cs[0]}, {as[1], bs[1], cs[1]}}; } +} // namespace + TEST_CASE("protocol") { - scl::PRG prg; + auto prg = scl::PRG::Create(); auto xs = scl::CreateAdditiveShares(FF(42), 2, prg); auto ys = scl::CreateAdditiveShares(FF(11), 2, prg); auto ts = RandomTriple(); diff --git a/test/scl/test_prg.cc b/test/scl/primitives/test_prg.cc similarity index 73% rename from test/scl/test_prg.cc rename to test/scl/primitives/test_prg.cc index e156824..11fc117 100644 --- a/test/scl/test_prg.cc +++ b/test/scl/primitives/test_prg.cc @@ -18,23 +18,14 @@ * along with this program. If not, see . */ +#include #include -#include "scl/prg.h" +#include "scl/primitives/prg.h" -inline bool BufferCmp(const unsigned char* b0, const unsigned char* b1, - unsigned len) { - const auto* p0 = b0; - const auto* p1 = b1; - while (len-- > 0) { - if (*p0++ != *p1++) { - return false; - } - } - return true; -} +namespace { -inline bool BufferLooksRandom(const unsigned char* p, unsigned len) { +bool BufferLooksRandom(const unsigned char* p, unsigned len) { std::vector buckets(256); for (std::size_t i = 0; i < len; i++) { @@ -43,16 +34,17 @@ inline bool BufferLooksRandom(const unsigned char* p, unsigned len) { bool all_in_interval = true; for (std::size_t i = 0; i < 256; i++) { - auto p = 100 * ((float)buckets[i] / (float)len); + auto p = 100 * (static_cast(buckets[i]) / static_cast(len)); all_in_interval &= p >= 0.2 || p <= 6.0; } return all_in_interval; } +} // namespace + TEST_CASE("PRG", "[misc]") { - scl::PRG prg; + auto prg = scl::PRG::Create(); - REQUIRE(scl::PRG::BlockSize() == 16); REQUIRE(scl::PRG::SeedSize() == 16); SECTION("SanityCheck") { @@ -65,22 +57,19 @@ TEST_CASE("PRG", "[misc]") { SECTION("Stable") { unsigned char seed[scl::PRG::SeedSize()] = "1234567890abcde"; - scl::PRG prg0(seed); - scl::PRG prg1(seed); - - REQUIRE(prg0.Counter() == prg1.Counter()); - auto counter_before = prg1.Counter(); - REQUIRE(BufferCmp(prg0.Seed(), seed, scl::PRG::SeedSize())); + auto prg0 = scl::PRG::Create(seed); + auto prg1 = scl::PRG::Create(seed); auto rand0 = prg0.Next(100); auto rand1 = prg1.Next(100); REQUIRE(rand0 == rand1); - REQUIRE(counter_before != prg1.Counter()); prg0.Reset(); auto rand00 = prg0.Next(100); REQUIRE(rand00 == rand0); + auto rand10 = prg1.Next(100); + REQUIRE(rand00 != rand10); } SECTION("Fill") { @@ -93,7 +82,8 @@ TEST_CASE("PRG", "[misc]") { REQUIRE(last_is_zero); REQUIRE_THROWS_MATCHES( - prg.Next(buffer, 101), std::invalid_argument, + prg.Next(buffer, 101), + std::invalid_argument, Catch::Matchers::Message("requested more randomness than dest.size()")); prg.Next(buffer); diff --git a/test/scl/primitives/test_sha256.cc b/test/scl/primitives/test_sha256.cc new file mode 100644 index 0000000..fad4069 --- /dev/null +++ b/test/scl/primitives/test_sha256.cc @@ -0,0 +1,89 @@ +/** + * @file test_sha256.cc + * + * SCL --- Secure Computation Library + * Copyright (C) 2022 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include +#include + +#include "scl/math/curves/secp256k1.h" +#include "scl/math/ec.h" +#include "scl/primitives/digest.h" +#include "scl/primitives/sha256.h" + +const static std::array SHA256_empty = { + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, + 0xc8, 0x99, 0x6f, 0xb9, 0x24, 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, + 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55}; + +const static std::array SHA256_abc = { + 0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, 0x41, 0x41, 0x40, + 0xde, 0x5d, 0xae, 0x22, 0x23, 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, + 0x7a, 0x9c, 0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, 0x15, 0xad}; + +using Sha256 = scl::details::Sha256; + +TEST_CASE("Sha256") { + SECTION("SHA256 abc") { + Sha256 hash; + hash.Update({'a', 'b', 'c'}); + auto digest = hash.Finalize(); + REQUIRE(digest.size() == 32); + REQUIRE(digest == SHA256_abc); + } + + SECTION("SHA256 empty") { + Sha256 hash; + auto digest = hash.Finalize(); + REQUIRE(digest.size() == 32); + REQUIRE(digest == SHA256_empty); + } + + SECTION("SHA256 chunked") { + Sha256 hash; + hash.Update({'a', 'b'}); + hash.Update({'c'}); + auto digest = hash.Finalize(); + REQUIRE(digest.size() == 32); + REQUIRE(digest == SHA256_abc); + } + + SECTION("Hash of curve point") { + // Reference test showing that serialization + hashing is the same as + // bouncycastle in Java. + + using Curve = scl::EC; + auto pk = Curve::Generator() * scl::Number::FromString("a"); + + const auto n = Curve::ByteSize(false); + unsigned char buf[n] = {0}; + pk.Write(buf, false); + + Sha256 hash; + hash.Update(buf, n); + + auto d = hash.Finalize(); + + std::array target = { + 0xde, 0xc1, 0x6a, 0xc2, 0x78, 0x99, 0xeb, 0xdf, 0x76, 0x0e, 0xaf, + 0x0a, 0x9f, 0x30, 0x95, 0xd1, 0x6a, 0x55, 0xea, 0x59, 0xef, 0x2a, + 0xe1, 0x8e, 0x9d, 0x22, 0x33, 0xd6, 0xbe, 0x82, 0x58, 0x38}; + + REQUIRE(d == target); + } +} diff --git a/test/scl/test_hash.cc b/test/scl/primitives/test_sha3.cc similarity index 80% rename from test/scl/test_hash.cc rename to test/scl/primitives/test_sha3.cc index 30eb57f..0eb1710 100644 --- a/test/scl/test_hash.cc +++ b/test/scl/primitives/test_sha3.cc @@ -1,5 +1,5 @@ /** - * @file test_hash.cc + * @file test_sha3.cc * * SCL --- Secure Computation Library * Copyright (C) 2022 Anders Dalskov @@ -20,30 +20,31 @@ #include -#include "scl/hash.h" +#include "scl/primitives/digest.h" +#include "scl/primitives/hash.h" -const static std::array SHA3_256_empty = { +static const scl::details::Digest<256>::Type SHA3_256_empty = { 0xa7, 0xff, 0xc6, 0xf8, 0xbf, 0x1e, 0xd7, 0x66, 0x51, 0xc1, 0x47, 0x56, 0xa0, 0x61, 0xd6, 0x62, 0xf5, 0x80, 0xff, 0x4d, 0xe4, 0x3b, 0x49, 0xfa, 0x82, 0xd8, 0x0a, 0x4b, 0x80, 0xf8, 0x43, 0x4a}; -static const std::array SHA3_256_abc = { +static const scl::details::Digest<256>::Type SHA3_256_abc = { 0x3a, 0x98, 0x5d, 0xa7, 0x4f, 0xe2, 0x25, 0xb2, 0x04, 0x5c, 0x17, 0x2d, 0x6b, 0xd3, 0x90, 0xbd, 0x85, 0x5f, 0x08, 0x6e, 0x3e, 0x9d, 0x52, 0x5b, 0x46, 0xbf, 0xe2, 0x45, 0x11, 0x43, 0x15, 0x32}; -static const std::array SHA3_256_0xa3_200_times = { +static const scl::details::Digest<256>::Type SHA3_256_0xa3_200_times = { 0x79, 0xf3, 0x8a, 0xde, 0xc5, 0xc2, 0x03, 0x07, 0xa9, 0x8e, 0xf7, 0x6e, 0x83, 0x24, 0xaf, 0xbf, 0xd4, 0x6c, 0xfd, 0x81, 0xb2, 0x2e, 0x39, 0x73, 0xc6, 0x5f, 0xa1, 0xbd, 0x9d, 0xe3, 0x17, 0x87}; -static const std::array SHA3_384_0xa3_200_times = { +static const scl::details::Digest<384>::Type SHA3_384_0xa3_200_times = { 0x18, 0x81, 0xde, 0x2c, 0xa7, 0xe4, 0x1e, 0xf9, 0x5d, 0xc4, 0x73, 0x2b, 0x8f, 0x5f, 0x00, 0x2b, 0x18, 0x9c, 0xc1, 0xe4, 0x2b, 0x74, 0x16, 0x8e, 0xd1, 0x73, 0x26, 0x49, 0xce, 0x1d, 0xbc, 0xdd, 0x76, 0x19, 0x7a, 0x31, 0xfd, 0x55, 0xee, 0x98, 0x9f, 0x2d, 0x70, 0x50, 0xdd, 0x47, 0x3e, 0x8f}; -std::array SHA3_512_0xa3_200_times = { +static const std::array SHA3_512_0xa3_200_times = { 0xe7, 0x6d, 0xfa, 0xd2, 0x20, 0x84, 0xa8, 0xb1, 0x46, 0x7f, 0xcf, 0x2f, 0xfa, 0x58, 0x36, 0x1b, 0xec, 0x76, 0x28, 0xed, 0xf5, 0xf3, 0xfd, 0xc0, 0xe4, 0x80, 0x5d, 0xc4, 0x8c, 0xae, 0xec, 0xa8, 0x1b, @@ -51,7 +52,7 @@ std::array SHA3_512_0xa3_200_times = { 0x9a, 0x2d, 0xf4, 0x6b, 0xe5, 0x89, 0xc5, 0x1c, 0xa1, 0xa4, 0xa8, 0x41, 0x6d, 0xf6, 0x54, 0x5a, 0x1c, 0xe8, 0xba, 0x00}; -TEST_CASE("Hash", "[misc]") { +TEST_CASE("Sha3", "[misc]") { SECTION("SHA3-256 empty") { scl::Hash<256> hash; auto digest = hash.Finalize(); @@ -60,7 +61,8 @@ TEST_CASE("Hash", "[misc]") { SECTION("SHA3-256 abc") { scl::Hash<256> hash; - auto digest = hash.Update((const unsigned char *)"abc", 3).Finalize(); + unsigned char abc[] = "abc"; + auto digest = hash.Update(abc, 3).Finalize(); REQUIRE(digest == SHA3_256_abc); } @@ -91,6 +93,7 @@ TEST_CASE("Hash", "[misc]") { scl::Hash<384> hash0; auto digest = hash0.Update(buf, 200).Finalize(); + REQUIRE(digest.size() == 48); REQUIRE(digest == SHA3_384_0xa3_200_times); scl::Hash<384> hash1; @@ -109,6 +112,7 @@ TEST_CASE("Hash", "[misc]") { scl::Hash<512> hash0; auto digest = hash0.Update(buf, 200).Finalize(); + REQUIRE(digest.size() == 64); REQUIRE(digest == SHA3_512_0xa3_200_times); scl::Hash<512> hash1; @@ -124,9 +128,17 @@ TEST_CASE("Hash", "[misc]") { auto ref = hash_ref.Update(ref_buf, 12).Finalize(); scl::Hash<256> hash1; - std::vector v = {'h', 'e', 'l', 'l', 'o', ',', - ' ', 'w', 'o', 'r', 'l', 'd'}; + std::vector v = { + 'h', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd'}; auto from_vec = hash1.Update(v).Finalize(); REQUIRE(ref == from_vec); } + + SECTION("Hash array") { + unsigned char abc[] = "abc"; + std::array abc_arr = {'a', 'b', 'c'}; + auto ref = scl::Hash<256>{}.Update(abc, 3).Finalize(); + auto act = scl::Hash<256>{}.Update(abc_arr).Finalize(); + REQUIRE(ref == act); + } } diff --git a/test/scl/ss/test_additive.cc b/test/scl/ss/test_additive.cc index cd905a9..fda0c0e 100644 --- a/test/scl/ss/test_additive.cc +++ b/test/scl/ss/test_additive.cc @@ -21,12 +21,12 @@ #include #include "scl/math.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" #include "scl/ss/additive.h" TEST_CASE("AdditiveSS", "[ss]") { using FF = scl::Fp<61>; - scl::PRG prg; + auto prg = scl::PRG::Create(); auto secret = FF(12345); @@ -41,6 +41,7 @@ TEST_CASE("AdditiveSS", "[ss]") { REQUIRE(scl::ReconstructAdditive(sum) == secret + x); REQUIRE_THROWS_MATCHES( - scl::CreateAdditiveShares(secret, 0, prg), std::invalid_argument, + scl::CreateAdditiveShares(secret, 0, prg), + std::invalid_argument, Catch::Matchers::Message("cannot create shares for 0 people")); } diff --git a/test/scl/ss/test_feldman.cc b/test/scl/ss/test_feldman.cc new file mode 100644 index 0000000..47027ef --- /dev/null +++ b/test/scl/ss/test_feldman.cc @@ -0,0 +1,45 @@ +/** + * @file test_feldman.cc + * + * SCL --- Secure Computation Library + * Copyright (C) 2022 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include + +#include "scl/math/curves/secp256k1.h" +#include "scl/math/ec.h" +#include "scl/primitives/prg.h" +#include "scl/ss/feldman.h" + +TEST_CASE("Feldman", "[ss]") { + using EC = scl::EC; + using FF = EC::Order; + + auto prg = scl::PRG::Create(); + std::size_t t = 4; + + SECTION("Share") { + auto factory = scl::FeldmanSSFactory::Create(t, prg); + auto secret = FF(123); + auto sb = factory.Share(secret, 24); + REQUIRE(sb.shares.Size() == 24); + REQUIRE(sb.commitments.Size() == t + 1); + REQUIRE(factory.Verify(sb.shares[22], sb.commitments, 22)); + REQUIRE(factory.Verify(secret, sb.commitments)); + REQUIRE(factory.Recover(sb.shares) == secret); + } +} diff --git a/test/scl/ss/test_poly.cc b/test/scl/ss/test_poly.cc index e9296a1..c80440b 100644 --- a/test/scl/ss/test_poly.cc +++ b/test/scl/ss/test_poly.cc @@ -19,16 +19,21 @@ */ #include +#include +#include "../math/fields.h" #include "scl/math.h" -#include "scl/prg.h" +#include "scl/math/curves/secp256k1.h" +#include "scl/primitives/prg.h" #include "scl/ss/poly.h" -using FF = scl::Fp<61>; -using Poly = scl::details::Polynomial; -using Vec = scl::Vec; +using namespace scl_tests; + +TEMPLATE_TEST_CASE("Polynomials", "[ss][math]", FIELD_DEFS) { + using FF = TestType; + using Poly = scl::details::Polynomial; + using Vec = scl::Vec; -TEST_CASE("Polynomials", "[ss][math]") { SECTION("DefaultConstruct") { Poly p; REQUIRE(p.Degree() == 0); @@ -149,10 +154,15 @@ TEST_CASE("Polynomials", "[ss][math]") { for (std::size_t i = 0; i < x.Degree(); ++i) { REQUIRE(x[i] == q[i]); } + + Poly z; + REQUIRE_THROWS_MATCHES(p.Divide(z), + std::invalid_argument, + Catch::Matchers::Message("division by 0")); } SECTION("DivideRandom") { - scl::PRG prg; + auto prg = scl::PRG::Create(); auto c0 = Vec::Random(10, prg); auto c1 = Vec::Random(9, prg); auto a = Poly::Create(c0); diff --git a/test/scl/ss/test_shamir.cc b/test/scl/ss/test_shamir.cc index 9a44801..07b45d4 100644 --- a/test/scl/ss/test_shamir.cc +++ b/test/scl/ss/test_shamir.cc @@ -23,7 +23,7 @@ #include "../gf7.h" #include "scl/math.h" -#include "scl/prg.h" +#include "scl/primitives/prg.h" #include "scl/ss/shamir.h" TEST_CASE("Shamir", "[ss]") { @@ -32,48 +32,46 @@ TEST_CASE("Shamir", "[ss]") { const std::size_t t = 2; - SECTION("Reconstruct") { - scl::PRG prg; - scl::details::ShamirSSFactory factory( - t, prg, scl::details::SecurityLevel::PASSIVE); - auto intr = factory.GetInterpolator(); - + SECTION("Recover") { + auto prg = scl::PRG::Create(); + auto factory = + scl::ShamirSSFactory::Create(t, prg, scl::SecurityLevel::PASSIVE); auto secret = FF(123); auto shares = factory.Share(secret); - auto s = intr.Reconstruct(shares); + auto s = factory.Recover(shares); REQUIRE(s == secret); REQUIRE_THROWS_MATCHES( - intr.Reconstruct(shares.SubVector(1)), std::invalid_argument, + factory.Recover(shares.SubVector(1)), + std::invalid_argument, Catch::Matchers::Message("not enough shares to reconstruct")); } SECTION("Detection") { - scl::PRG prg; - scl::details::ShamirSSFactory factory( - t, prg, scl::details::SecurityLevel::DETECT); - auto intr = scl::details::Reconstructor::Create( - t, scl::details::SecurityLevel::DETECT); + auto prg = scl::PRG::Create(); + auto factory = + scl::ShamirSSFactory::Create(t, prg, scl::SecurityLevel::DETECT); auto secret = FF(555); auto shares = factory.Share(secret); REQUIRE(shares.Size() == 2 * t + 1); - REQUIRE(intr.Reconstruct(shares) == secret); + REQUIRE(factory.Recover(shares) == secret); REQUIRE_THROWS_MATCHES( - intr.Reconstruct(shares.SubVector(2)), std::invalid_argument, + factory.Recover(shares.SubVector(2)), + std::invalid_argument, Catch::Matchers::Message("not enough shares to reconstruct")); - auto ss = intr.ReconstructShare(shares, 2); + auto ss = factory.RecoverShare(shares, 2); REQUIRE(ss == shares[2]); - REQUIRE(intr.Reconstruct(shares, 3) == intr.ReconstructShare(shares, 2)); + REQUIRE(factory.Recover(shares, 3) == factory.RecoverShare(shares, 2)); } SECTION("Robust") { - scl::PRG prg; - scl::details::ShamirSSFactory factory( - t, prg, scl::details::SecurityLevel::CORRECT); + auto prg = scl::PRG::Create(); + auto factory = + scl::ShamirSSFactory::Create(t, prg, scl::SecurityLevel::CORRECT); // no errors auto secret = FF(123); @@ -83,9 +81,7 @@ TEST_CASE("Shamir", "[ss]") { REQUIRE(reconstructed == secret); // can also reconstruct with an interpolator - auto intr = scl::details::Reconstructor::Create( - t, scl::details::SecurityLevel::CORRECT); - REQUIRE(intr.Reconstruct(shares) == secret); + REQUIRE(factory.Recover(shares) == secret); // one error shares[0] = FF(63212); @@ -100,7 +96,8 @@ TEST_CASE("Shamir", "[ss]") { // three errors -- that's one too many shares[1] = FF(123); REQUIRE_THROWS_MATCHES( - scl::details::ReconstructShamirRobust(shares, t), std::logic_error, + scl::details::ReconstructShamirRobust(shares, t), + std::logic_error, Catch::Matchers::Message("could not correct shares")); REQUIRE_THROWS_MATCHES( @@ -119,7 +116,7 @@ TEST_CASE("Shamir", "[ss]") { TEST_CASE("BerlekampWelch", "[ss][math]") { // https://en.wikipedia.org/wiki/Berlekamp%E2%80%93Welch_algorithm#Example - using FF = scl::FF; + using FF = scl::FF; using Vec = scl::Vec; Vec bs = {FF(1), FF(5), FF(3), FF(6), FF(3), FF(2), FF(2)};