From d50e6c07314adf9a1a5d3f6aaf65f1ca0d3c513f Mon Sep 17 00:00:00 2001 From: Anders Dalskov Date: Tue, 8 Nov 2022 00:13:54 +0100 Subject: [PATCH] Version 3.0.0 --- .github/workflows/Build.yml | 7 +- .github/workflows/Checks.yml | 55 ++- .github/workflows/Test.yml | 42 ++- CMakeLists.txt | 6 +- RELEASE.txt | 14 + examples/03_secret_sharing.cc | 22 +- include/scl/hash.h | 22 +- include/scl/math/ec.h | 8 + include/scl/math/ec_ops.h | 13 +- include/scl/math/ff_ops.h | 2 +- include/scl/math/la.h | 40 ++- include/scl/math/mat.h | 71 ++-- include/scl/math/number.h | 35 +- include/scl/math/str.h | 4 +- include/scl/math/vec.h | 117 ++++++- include/scl/math/z2k_ops.h | 5 +- include/scl/net/channel.h | 29 +- include/scl/net/config.h | 14 +- include/scl/net/discovery/client.h | 19 +- include/scl/net/discovery/server.h | 7 +- include/scl/net/mem_channel.h | 14 +- include/scl/net/network.h | 7 +- include/scl/net/tcp_channel.h | 3 +- include/scl/net/tcp_utils.h | 8 +- include/scl/net/threaded_sender.h | 4 +- include/scl/prg.h | 3 +- include/scl/ss/additive.h | 4 +- include/scl/ss/poly.h | 57 +-- include/scl/ss/shamir.h | 499 +++++++++++++++++---------- src/scl/hash.cc | 14 +- src/scl/math/mersenne127.cc | 20 +- src/scl/math/mersenne61.cc | 8 +- src/scl/math/number.cc | 16 +- src/scl/math/ops_gmp_ff.h | 31 +- src/scl/math/ops_small_fp.h | 17 +- src/scl/math/secp256k1_curve.cc | 126 +++---- src/scl/math/secp256k1_extras.h | 4 + src/scl/math/secp256k1_field.cc | 17 +- src/scl/math/secp256k1_order.cc | 29 +- src/scl/math/str.cc | 6 +- src/scl/net/config.cc | 38 +- src/scl/net/discovery/client.cc | 12 +- src/scl/net/discovery/server.cc | 28 +- src/scl/net/mem_channel.cc | 13 +- src/scl/net/network.cc | 4 +- src/scl/net/tcp_channel.cc | 25 +- src/scl/net/tcp_utils.cc | 24 +- src/scl/prg.cc | 38 +- test/scl/gf7.cc | 9 +- test/scl/gf7.h | 6 +- test/scl/math/test_ff.cc | 10 +- test/scl/math/test_mat.cc | 23 +- test/scl/math/test_secp256k1.cc | 50 ++- test/scl/math/test_vec.cc | 15 + test/scl/math/test_z2k.cc | 2 +- test/scl/net/test_config.cc | 10 +- test/scl/net/test_discover.cc | 23 +- test/scl/net/test_mem_channel.cc | 23 +- test/scl/net/test_network.cc | 4 +- test/scl/net/test_tcp_channel.cc | 17 +- test/scl/net/test_threaded_sender.cc | 17 +- test/scl/net/util.cc | 3 +- test/scl/net/util.h | 7 +- test/scl/p/test_simple.cc | 7 +- test/scl/ss/test_shamir.cc | 111 +++--- test/scl/test_hash.cc | 24 +- test/scl/test_prg.cc | 17 +- 67 files changed, 1280 insertions(+), 699 deletions(-) diff --git a/.github/workflows/Build.yml b/.github/workflows/Build.yml index b24c786..a478e9d 100644 --- a/.github/workflows/Build.yml +++ b/.github/workflows/Build.yml @@ -1,9 +1,14 @@ name: Build -on: [push] +on: + push: + branches: + - '*' + - '!master' jobs: build: + name: Release runs-on: ubuntu-latest steps: diff --git a/.github/workflows/Checks.yml b/.github/workflows/Checks.yml index 6a08abd..05adccc 100644 --- a/.github/workflows/Checks.yml +++ b/.github/workflows/Checks.yml @@ -1,30 +1,49 @@ name: Checks -on: [push] +on: + push: + branches: + - '*' + - '!master' jobs: - build: + documentation: + name: Documentation runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Setup + run: sudo apt-get install -y doxygen + + - name: Documentation + shell: bash + run: ./scripts/build_documentation.sh + headers: + name: Header files + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v2 - - name: Setup - run: sudo apt-get install -y doxygen clang-format-12 + - name: Copyright + run: ./scripts/check_copyright_headers.py - - name: Documentation - shell: bash - run: ./scripts/build_documentation.sh + - name: Header Guards + run: ./scripts/check_header_guards.py - - name: Copyright - run: ./scripts/check_copyright_headers.py + style: + name: Code style + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 - - name: Header Guards - run: ./scripts/check_header_guards.py + - name: Setup + run: sudo apt-get install -y clang-format-12 - - name: Style - shell: bash - run: | - find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n --style=Google {} \; &> checks.txt - cat checks.txt - test ! -s checks.txt + - name: Check + shell: bash + run: | + find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n --style=Google {} \; &> checks.txt + cat checks.txt + test ! -s checks.txt diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 4b0332e..c52ab45 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -1,38 +1,39 @@ name: Test -on: [push] +on: + push: + branches: + - '*' + - '!master' env: BUILD_TYPE: Debug jobs: build: + name: Coverage and Linting runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Setup catch2 + - name: Setup run: | - sudo apt-get install -y lcov + sudo apt-get install -y lcov bear curl -L https://github.com/catchorg/Catch2/archive/v2.13.0.tar.gz -o c.tar.gz tar xvf c.tar.gz cd Catch2-2.13.0/ - cmake -Bbuild -H. -DBUILD_TESTING=OFF - sudo cmake --build build/ --target install + cmake -B catch -DBUILD_TESTING=OFF + cmake --build catch + sudo cmake --install catch - - name: Create build directory - run: cmake -E make_directory ${{runner.workspace}}/build - - - name: Configure CMake - shell: bash - working-directory: ${{runner.workspace}}/build - run: cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=$BUILD_TYPE + - name: CMake + run: cmake -B ${{runner.workspace}}/build -DCMAKE_BUILD_TYPE=$BUILD_TYPE . - name: Build working-directory: ${{runner.workspace}}/build shell: bash - run: cmake --build . --config $BUILD_TYPE + run: bear make -s -j4 - name: Test working-directory: ${{runner.workspace}}/build @@ -40,12 +41,17 @@ jobs: run: ctest -C $BUILD_TYPE - name: Coverage - working-directory: ${{runner.workspace}}/build shell: bash run: | - make coverage - lcov --summary coverage.info >> summary.txt + cmake --build ${{runner.workspace}}/build --target coverage + lcov --summary ${{runner.workspace}}/build/coverage.info >> ${{runner.workspace}}/summary.txt + ./scripts/check_coverage.py ${{runner.workspace}}/summary.txt - - name: Check + - name: Lint shell: bash - run: ./scripts/check_coverage.py ${{runner.workspace}}/build/summary.txt + run: | + find include/ src/ test/ -type f \( -iname "*.h" -o -iname "*.cc" \) \ + -exec clang-tidy -p ${{runner.workspace}}/build/compile_commands.json --quiet {} \; 1>> lint.txt 2>/dev/null + cat lint.txt + test ! -s lint.txt + diff --git a/CMakeLists.txt b/CMakeLists.txt index 810c4f6..25be133 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,7 @@ cmake_minimum_required( VERSION 3.14 ) -project( scl VERSION 2.1.0 DESCRIPTION "Secure Computation Library" ) +project( scl VERSION 3.0.0 DESCRIPTION "Secure Computation Library" ) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) @@ -120,7 +120,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") add_compile_definitions(SCL_ENABLE_EC_TESTS) endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -fsanitize=address") find_package(Catch2 REQUIRED) include(CTest) include(Catch) @@ -150,3 +150,5 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") EXCLUDE "/usr/include/*" "test/*" "/usr/lib/*" "/usr/local/*") endif() + +message(STATUS "CXX_FLAGS=" ${CMAKE_CXX_FLAGS}) diff --git a/RELEASE.txt b/RELEASE.txt index afc0177..f56e41d 100644 --- a/RELEASE.txt +++ b/RELEASE.txt @@ -1,3 +1,17 @@ +3.0: More features, build changes +- Add method for returning a point as a pair of affine coordinates +- Add method to check if a channel has data available +- Allow sending and receiving STL vectors without specifying the size +- Extend Vec with a SubVector, operator== and operator!= methods +- Begin Shamir code refactor and move all of it into details namespace +- bugs: + - fix scalar multiplication for secp256k1_order + - fix compilation error on g++12 +- build: + - build tests with -fsanitize=address + - disable actions for master branch + - add clang-tidy action + 2.1: More Finite Fields - Provide a FF implementation for computations modulo the order of Secp256k1 - Extend EC with support for scalar multiplications with scalars from a finite diff --git a/examples/03_secret_sharing.cc b/examples/03_secret_sharing.cc index 55a6d55..c5f9dcc 100644 --- a/examples/03_secret_sharing.cc +++ b/examples/03_secret_sharing.cc @@ -46,22 +46,28 @@ int main() { * correction. Lets see error detection at work first */ + scl::details::ShamirSSFactory factory( + 1, prg, scl::details::SecurityLevel::CORRECT); /* We create 4 shamir shares with a threshold of 1. */ - auto shamir_shares = scl::CreateShamirShares(secret, 4, 1, prg); + auto shamir_shares = factory.Share(secret); std::cout << shamir_shares << "\n"; /* Of course, these can be reconstructed. The second parameter is the * threshold. This performs reconstruction with error detection. */ - auto shamir_reconstructed = scl::ReconstructShamir(shamir_shares, 1); + auto recon = factory.GetInterpolator(); + auto shamir_reconstructed = + recon.Reconstruct(shamir_shares, scl::details::SecurityLevel::DETECT); std::cout << shamir_reconstructed << "\n"; /* If we introduce an error, then reconstruction fails */ shamir_shares[2] = Fp(123); try { - std::cout << scl::ReconstructShamir(shamir_shares, 1) << "\n"; + std::cout << recon.Reconstruct(shamir_shares, + scl::details::SecurityLevel::DETECT) + << "\n"; } catch (std::logic_error& e) { std::cout << e.what() << "\n"; } @@ -69,7 +75,7 @@ int main() { /* On the other hand, we can use the robust reconstruction since the threshold * is low enough. I.e., because 4 >= 3*1 + 1. */ - auto r = scl::ReconstructShamirRobust(shamir_shares, 1); + auto r = recon.Reconstruct(shamir_shares); std::cout << r << "\n"; /* With a bit of extra work, we can even learn which share had the error. @@ -79,7 +85,7 @@ int main() { * default these are just the field elements 1 through 4. */ Vec alphas = {Fp(1), Fp(2), Fp(3), Fp(4)}; - auto pe = scl::ReconstructShamirRobust(shamir_shares, alphas, 1); + auto pe = scl::details::ReconstructShamirRobust(shamir_shares, alphas, 1); /* pe is a pair of polynomials. The first is the original polynomial used for * generating the shares and the second is a polynomial whose roots tell which @@ -87,18 +93,18 @@ int main() { * * The secret is embedded in the constant term. */ - std::cout << pe[0].Evaluate(Fp(0)) << "\n"; + std::cout << std::get<0>(pe).Evaluate(Fp(0)) << "\n"; /* This will be 0, indicating that the share corresponding to party 3 had an * error. */ - std::cout << pe[1].Evaluate(Fp(3)) << "\n"; + std::cout << std::get<1>(pe).Evaluate(Fp(3)) << "\n"; /* Lastly, if there's too many errors, then correction is not possible */ shamir_shares[1] = Fp(22); try { - scl::ReconstructShamirRobust(shamir_shares, 1); + recon.Reconstruct(shamir_shares); } catch (std::logic_error& e) { std::cout << e.what() << "\n"; } diff --git a/include/scl/hash.h b/include/scl/hash.h index fa1416f..9b25dee 100644 --- a/include/scl/hash.h +++ b/include/scl/hash.h @@ -130,13 +130,17 @@ Hash &Hash::Update(const unsigned char *bytes, std::size_t nbytes) { const unsigned char *p = bytes; if (nbytes < old_tail) { - while (nbytes--) mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8); + while (nbytes-- > 0) { + mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8); + } return *this; } - if (old_tail) { + if (old_tail != 0) { nbytes -= old_tail; - while (old_tail--) mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8); + while (old_tail-- != 0) { + mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8); + } mState[mWordIndex] ^= mSaved; mByteIndex = 0; @@ -167,7 +171,9 @@ Hash &Hash::Update(const unsigned char *bytes, std::size_t nbytes) { p += sizeof(uint64_t); } - while (tail--) mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8); + while (tail-- > 0) { + mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8); + } return *this; } @@ -194,7 +200,9 @@ auto Hash::Finalize() -> DigestType { // truncate DigestType digest = {0}; - for (std::size_t i = 0; i < digest.size(); ++i) digest[i] = mStateBytes[i]; + for (std::size_t i = 0; i < digest.size(); ++i) { + digest[i] = mStateBytes[i]; + } return digest; } @@ -208,7 +216,9 @@ template std::string DigestToString(const D &digest) { std::stringstream ss; ss << std::setw(2) << std::setfill('0') << std::hex; - for (const auto &c : digest) ss << (int)c; + for (const auto &c : digest) { + ss << (int)c; + } return ss.str(); } diff --git a/include/scl/math/ec.h b/include/scl/math/ec.h index c33c196..caed8d4 100644 --- a/include/scl/math/ec.h +++ b/include/scl/math/ec.h @@ -271,6 +271,14 @@ class EC { return details::CurveIsPointAtInfinity(mValue); }; + /** + * @brief Return this point as a pair of affine coordinates. + * @return this point as a pair of affine coordinates. + */ + std::array ToAffine() const { + return details::CurveToAffine(mValue); + }; + /** * @brief Output this point as a string. */ diff --git a/include/scl/math/ec_ops.h b/include/scl/math/ec_ops.h index 1d1f433..8a601f4 100644 --- a/include/scl/math/ec_ops.h +++ b/include/scl/math/ec_ops.h @@ -59,6 +59,15 @@ template void CurveSetAffine(typename C::ValueType& out, const FF& x, const FF& y); +/** + * @brief Convert a point to a pair of affine coordinates. + * @param point the point to convert. + * @return a set of affine coordinates. + */ +template +std::array, 2> CurveToAffine( + const typename C::ValueType& point); + /** * @brief Add two elliptic curve points in-place. * @param out the first point and output @@ -135,11 +144,11 @@ void CurveToBytes(unsigned char* dest, const typename C::ValueType& in, /** * @brief Convert an elliptic curve point to a string - * @param in the point + * @param point the point * @return an STL string representation of \p in. */ template -std::string CurveToString(const typename C::ValueType& in); +std::string CurveToString(const typename C::ValueType& point); } // namespace details } // namespace scl diff --git a/include/scl/math/ff_ops.h b/include/scl/math/ff_ops.h index 7e64750..8c293fd 100644 --- a/include/scl/math/ff_ops.h +++ b/include/scl/math/ff_ops.h @@ -37,7 +37,7 @@ namespace details { * @param value the integer to convert */ template -void FieldConvertIn(typename F::ValueType& out, const int value); +void FieldConvertIn(typename F::ValueType& out, int value); /** * @brief Add two field elements in-place. diff --git a/include/scl/math/la.h b/include/scl/math/la.h index 4651b72..21ad7a7 100644 --- a/include/scl/math/la.h +++ b/include/scl/math/la.h @@ -53,7 +53,9 @@ void SwapRows(Mat& A, std::size_t k, std::size_t h) { */ template void MultiplyRow(Mat& A, std::size_t row, const T& m) { - for (std::size_t j = 0; j < A.Cols(); ++j) A(row, j) *= m; + for (std::size_t j = 0; j < A.Cols(); ++j) { + A(row, j) *= m; + } } /** @@ -65,7 +67,9 @@ void MultiplyRow(Mat& A, std::size_t row, const T& m) { */ template void AddRows(Mat& A, std::size_t dst, std::size_t op, const T& m) { - for (std::size_t j = 0; j < A.Cols(); ++j) A(dst, j) += A(op, j) * m; + for (std::size_t j = 0; j < A.Cols(); ++j) { + A(dst, j) += A(op, j) * m; + } } /** @@ -83,7 +87,9 @@ void RowReduceInPlace(Mat& A) { while (r < n && c < m) { // find pivot in current column auto pivot = r; - while (pivot < n && A(pivot, c) == zero) pivot++; + while (pivot < n && A(pivot, c) == zero) { + pivot++; + } if (pivot == n) { // this column was all 0, so go to next one @@ -97,10 +103,14 @@ void RowReduceInPlace(Mat& A) { // finally, for each row that is not r, subtract a multiple of row r. for (std::size_t k = 0; k < n; ++k) { - if (k == r) continue; + if (k == r) { + continue; + } // skip row if leading coefficient of that row is 0. auto t = A(k, c); - if (t != zero) AddRows(A, k, r, -t); + if (t != zero) { + AddRows(A, k, r, -t); + } } r++; c++; @@ -122,7 +132,9 @@ int GetPivotInColumn(const Mat& A, int col) { while (i-- > 0) { if (A(i, col) != zero) { for (int k = 0; k < col - 1; ++k) { - if (A(i, k) != zero) return -1; + if (A(i, k) != zero) { + return -1; + } } return i; } @@ -152,7 +164,9 @@ std::size_t FindFirstNonZeroRow(const Mat& A) { break; } } - if (non_zero) break; + if (non_zero) { + break; + } } return nzr; } @@ -185,7 +199,9 @@ Vec ExtractSolution(const Mat& A) { x[c] = T{1}; } else { T sum; - for (std::size_t j = p + 1; j < n; ++j) sum += A(i, j) * x[j]; + for (std::size_t j = p + 1; j < n; ++j) { + sum += A(i, j) * x[j]; + } x[c] = A(i, m - 1) - sum; i--; } @@ -218,9 +234,13 @@ bool HasSolution(const Mat& A, bool unique_only) { // the last column (the augmentation). I.e., when row(A', i) == 0, but // row(A, i) != 0. if (unique_only) { - if (all_zero) return false; + if (all_zero) { + return false; + } } else { - if (all_zero && A(i, m - 1) != zero) return false; + if (all_zero && A(i, m - 1) != zero) { + return false; + } } } return true; diff --git a/include/scl/math/mat.h b/include/scl/math/mat.h index a505426..c027f7d 100644 --- a/include/scl/math/mat.h +++ b/include/scl/math/mat.h @@ -85,7 +85,9 @@ class Mat { static Mat Vandermonde(std::size_t n, std::size_t m) { std::vector xs; xs.reserve(n); - for (std::size_t i = 0; i < n; ++i) xs.emplace_back(T(i + 1)); + for (std::size_t i = 0; i < n; ++i) { + xs.emplace_back(T(i + 1)); + } return Mat::Vandermonde(n, m, xs); } @@ -119,7 +121,9 @@ class Mat { */ static Mat FromVector(std::size_t n, std::size_t m, const std::vector& vec) { - if (vec.size() != n * m) throw std::invalid_argument("invalid dimensions"); + if (vec.size() != n * m) { + throw std::invalid_argument("invalid dimensions"); + } return Mat(n, m, vec); }; @@ -128,7 +132,9 @@ class Mat { */ static Mat Identity(std::size_t n) { Mat I(n); - for (std::size_t i = 0; i < n; ++i) I(i, i) = T(1); + for (std::size_t i = 0; i < n; ++i) { + I(i, i) = T(1); + } return I; } // LCOV_EXCL_LINE @@ -143,7 +149,9 @@ class Mat { * @param m the number of columns */ explicit Mat(std::size_t n, std::size_t m) { - if (!(n && m)) throw std::invalid_argument("n or m cannot be 0"); + if (n == 0 || m == 0) { + throw std::invalid_argument("n or m cannot be 0"); + } std::vector v(n * m); mRows = n; mCols = m; @@ -296,7 +304,9 @@ class Mat { * @return this scaled by \p scalar. */ Mat& ScalarMultiplyInPlace(const T& scalar) { - for (auto& v : mValues) v *= scalar; + for (auto& v : mValues) { + v *= scalar; + } return *this; }; @@ -317,8 +327,9 @@ class Mat { * @param cols the new column count */ Mat& Resize(std::size_t rows, std::size_t cols) { - if (rows * cols != Rows() * Cols()) + if (rows * cols != Rows() * Cols()) { throw std::invalid_argument("cannot resize matrix"); + } mRows = rows; mCols = cols; return *this; @@ -345,11 +356,14 @@ class Mat { * @brief Test if this matrix is equal to another. */ bool Equals(const Mat& other) const { - if (Rows() != other.Rows() || Cols() != other.Cols()) return false; + if (Rows() != other.Rows() || Cols() != other.Cols()) { + return false; + } bool equal = true; - for (std::size_t i = 0; i < mValues.size(); i++) + for (std::size_t i = 0; i < mValues.size(); i++) { equal &= mValues[i] == other.mValues[i]; + } return equal; }; @@ -378,8 +392,9 @@ class Mat { : mRows(r), mCols(c), mValues(v){}; void EnsureCompatible(const Mat& other) { - if (mRows != other.mRows || mCols != other.mCols) + if (mRows != other.mRows || mCols != other.mCols) { throw std::invalid_argument("incompatible matrices"); + } }; std::size_t mRows; @@ -391,7 +406,7 @@ class Mat { template Mat Mat::Read(std::size_t n, std::size_t m, const unsigned char* src) { - auto ptr = src; + const auto* ptr = src; auto total = n * m; // write all elements now that we know we'll not exceed the maximum read size. @@ -406,10 +421,9 @@ Mat Mat::Read(std::size_t n, std::size_t m, const unsigned char* src) { template void Mat::Write(unsigned char* dest) const { - auto ptr = dest; for (const auto& v : mValues) { - v.Write(ptr); - ptr += T::ByteSize(); + v.Write(dest); + dest += T::ByteSize(); } } @@ -421,10 +435,11 @@ Mat Mat::Random(std::size_t n, std::size_t m, PRG& prg) { std::size_t buffer_size = nelements * T::ByteSize(); auto buffer = std::make_unique(buffer_size); - auto ptr = buffer.get(); + auto* ptr = buffer.get(); prg.Next(buffer.get(), buffer_size); - for (std::size_t i = 0; i < nelements; i++) + for (std::size_t i = 0; i < nelements; i++) { elements.emplace_back(T::Read(ptr + i * T::ByteSize())); + } return Mat(n, m, elements); } @@ -482,8 +497,9 @@ Mat Mat::HyperInvertible(std::size_t n, std::size_t m) { template Mat Mat::Multiply(const Mat& other) const { - if (Cols() != other.Rows()) + if (Cols() != other.Rows()) { throw std::invalid_argument("invalid matrix dimensions for multiply"); + } const auto n = Rows(); const auto p = Cols(); const auto m = other.Cols(); @@ -512,15 +528,18 @@ Mat Mat::Transpose() const { template bool Mat::IsIdentity() const { - if (!IsSquare()) return false; + if (!IsSquare()) { + return false; + } bool is_ident = true; for (std::size_t i = 0; i < Rows(); ++i) { for (std::size_t j = 0; j < Cols(); ++j) { - if (i == j) + if (i == j) { is_ident &= operator()(i, j) == T{1}; - else + } else { is_ident &= operator()(i, j) == T{0}; + } } } return is_ident; @@ -536,13 +555,15 @@ std::string Mat::ToString() const { const auto n = Rows(); const auto m = Cols(); - if (!(n && m)) return "[ EMPTY_MATRIX ]"; + if (!(n && m)) { + return "[ EMPTY_MATRIX ]"; + } // convert all elements to strings and find the widest string in each column // since that will be used to properly align the final output. std::vector elements; elements.reserve(n * m); - std::vector fills; + std::vector fills; fills.reserve(m); for (std::size_t j = 0; j < m; j++) { auto first = operator()(0, j).ToString(); @@ -551,7 +572,9 @@ std::string Mat::ToString() const { for (std::size_t i = 1; i < n; i++) { auto next = operator()(i, j).ToString(); auto next_fill = next.size(); - if (next_fill > fill) fill = next_fill; + if (next_fill > fill) { + fill = next_fill; + } elements.push_back(next); } fills.push_back(fill + 1); @@ -566,7 +589,9 @@ std::string Mat::ToString() const { << " "; } ss << "]"; - if (i < n - 1) ss << "\n"; + if (i < n - 1) { + ss << "\n"; + } } return ss.str(); } diff --git a/include/scl/math/number.h b/include/scl/math/number.h index 45cc0da..4b536ed 100644 --- a/include/scl/math/number.h +++ b/include/scl/math/number.h @@ -76,7 +76,7 @@ class Number { * @brief Move constructor for a Number. * @param number the Number that is moved */ - Number(Number&& number); + Number(Number&& number) noexcept; /** * @brief Copy assignment from a Number. @@ -94,7 +94,7 @@ class Number { * @param number the Number that is moved * @return this */ - Number& operator=(Number&& number) { + Number& operator=(Number&& number) noexcept { swap(*this, number); return *this; }; @@ -257,42 +257,47 @@ class Number { */ int Compare(const Number& number) const; -#define SCL_CMP_FUN_NN(op) \ - friend bool operator op(const Number& lhs, const Number& rhs) { \ - return lhs.Compare(rhs) op 0; \ - } - /** * @brief Equality of two numbers. */ - SCL_CMP_FUN_NN(==); + friend bool operator==(const Number& lhs, const Number& rhs) { + return lhs.Compare(rhs) == 0; + }; /** * @brief In-equality of two numbers. */ - SCL_CMP_FUN_NN(!=); + friend bool operator!=(const Number& lhs, const Number& rhs) { + return lhs.Compare(rhs) != 0; + }; /** * @brief Strictly less-than of two numbers. */ - SCL_CMP_FUN_NN(<); + friend bool operator<(const Number& lhs, const Number& rhs) { + return lhs.Compare(rhs) < 0; + }; /** * @brief Less-than-or-equal of two numbers. */ - SCL_CMP_FUN_NN(<=); + friend bool operator<=(const Number& lhs, const Number& rhs) { + return lhs.Compare(rhs) <= 0; + }; /** * @brief Strictly greater-than of two numbers. */ - SCL_CMP_FUN_NN(>); + friend bool operator>(const Number& lhs, const Number& rhs) { + return lhs.Compare(rhs) > 0; + }; /** * @brief Greater-than-or-equal of two numbers. */ - SCL_CMP_FUN_NN(>=); - -#undef SCL_CMP_FUN_NN + friend bool operator>=(const Number& lhs, const Number& rhs) { + return lhs.Compare(rhs) >= 0; + }; /** * @brief Get the size of this Number in bits. diff --git a/include/scl/math/str.h b/include/scl/math/str.h index 4d7f25e..f337729 100644 --- a/include/scl/math/str.h +++ b/include/scl/math/str.h @@ -50,7 +50,9 @@ namespace details { template T FromHexString(const std::string& s) { auto n = s.size(); - if (n % 2) throw std::invalid_argument("odd-length hex string"); + if (n % 2) { + throw std::invalid_argument("odd-length hex string"); + } T t = 0; for (std::size_t i = 0; i < n; i += 2) { char c0 = s[i]; diff --git a/include/scl/math/vec.h b/include/scl/math/vec.h index 6b33fcd..ab5e89a 100644 --- a/include/scl/math/vec.h +++ b/include/scl/math/vec.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -32,6 +33,25 @@ namespace scl { +namespace details { + +/** + * @brief Computes an unchecked inner product between two iterators. + * @param xb start of the first iterator + * @param xe end of the first iterator + * @param yb start of the second iterator + */ +template +T UncheckedInnerProd(It xb, It xe, It yb) { + T v; + while (xb != xe) { + v += *xb++ * *yb++; + } + return v; +} + +} // namespace details + /** * @brief Vector. * @@ -98,6 +118,14 @@ class Vec { */ static Vec Random(std::size_t n, PRG& prg); + /** + * @brief Create a vector with values in a range. + * @param start the start value, inclusive + * @param end the end value, exclusive + * @return a vector with values [start, start + 1, ..., end - 1]. + */ + static Vec Range(std::size_t start, std::size_t end); + /** * @brief Default constructor that creates an empty Vec. */ @@ -212,10 +240,7 @@ class Vec { */ T Dot(const Vec& other) const { EnsureCompatible(other); - T result; - for (std::size_t i = 0; i < Size(); i++) - result += mValues[i] * other.mValues[i]; - return result; + return details::UncheckedInnerProd(begin(), end(), other.begin()); }; /** @@ -224,7 +249,9 @@ class Vec { */ T Sum() const { T sum; - for (const auto& v : mValues) sum += v; + for (const auto& v : mValues) { + sum += v; + } return sum; }; @@ -236,7 +263,9 @@ class Vec { Vec ScalarMultiply(const T& scalar) const { std::vector r; r.reserve(Size()); - for (const auto& v : mValues) r.emplace_back(scalar * v); + for (const auto& v : mValues) { + r.emplace_back(scalar * v); + } return Vec(r); }; @@ -246,7 +275,9 @@ class Vec { * @return a scaled version of this vector. */ Vec& ScalarMultiplyInPlace(const T& scalar) { - for (auto& v : mValues) v *= scalar; + for (auto& v : mValues) { + v *= scalar; + } return *this; }; @@ -257,6 +288,20 @@ class Vec { */ bool Equals(const Vec& other) const; + /** + * @brief Operator == overload for Vec. + */ + friend bool operator==(const Vec& left, const Vec& right) { + return left.Equals(right); + }; + + /** + * @brief Operator != overload for Vec. + */ + friend bool operator!=(const Vec& left, const Vec& right) { + return !(left == right); + }; + /** * @brief Convert this vector into a 1-by-N row matrix. */ @@ -277,6 +322,29 @@ class Vec { */ const std::vector& ToStlVector() const { return mValues; }; + /** + * @brief Extract a sub-vector + * @param start the start index, inclusive + * @param end the end index, exclusive + * @return a sub-vector. + */ + Vec SubVector(std::size_t start, std::size_t end) { + if (start > end) { + throw std::logic_error("invalid range"); + } + return Vec(begin() + start, begin() + end); + } + + /** + * @brief Extract a sub-vector. + * + * This method is equivalent to Vec#SubVector(0, end). + * + * @param end the end index, exclusive + * @return a sub-vector. + */ + Vec SubVector(std::size_t end) { return SubVector(0, end); }; + /** * @brief Return a string representation of this vector. */ @@ -362,8 +430,9 @@ class Vec { private: void EnsureCompatible(const Vec& other) const { - if (Size() != other.Size()) + if (Size() != other.Size()) { throw std::invalid_argument("Vec sizes mismatch"); + } }; std::vector mValues; @@ -396,6 +465,23 @@ Vec Vec::PartialRandom(std::size_t n, Pred predicate, PRG& prg) { return Vec(v); } +template +Vec Vec::Range(std::size_t start, std::size_t end) { + if (start > end) { + throw std::invalid_argument("invalid range"); + } + if (start == end) { + return Vec{}; + } + + std::vector v; + v.reserve(end - start); + for (std::size_t i = start; i < end; ++i) { + v.emplace_back(T{(int)i}); + } + return Vec(v); +} + template Vec Vec::Random(std::size_t n, PRG& prg) { return Vec::PartialRandom( @@ -459,24 +545,27 @@ bool Vec::Equals(const Vec& other) const { template std::string Vec::ToString() const { + std::string str; if (Size()) { std::stringstream ss; ss << "["; std::size_t i = 0; - for (; i < Size() - 1; i++) ss << mValues[i] << ", "; + for (; i < Size() - 1; i++) { + ss << mValues[i] << ", "; + } ss << mValues[i] << "]"; - return ss.str(); + str = ss.str(); } else { - return "[ EMPTY_VECTOR ]"; + str = "[ EMPTY_VECTOR ]"; } + return str; } template void Vec::Write(unsigned char* dest) const { - unsigned char* p = dest; for (const auto& v : mValues) { - v.Write(p); - p += T::ByteSize(); + v.Write(dest); + dest += T::ByteSize(); } } diff --git a/include/scl/math/z2k_ops.h b/include/scl/math/z2k_ops.h index 5d7e618..c767bad 100644 --- a/include/scl/math/z2k_ops.h +++ b/include/scl/math/z2k_ops.h @@ -82,8 +82,9 @@ unsigned LsbZ2k(T& v) { */ template = true> void InvertZ2k(T& v) { - if (!LsbZ2k(v)) + if (!LsbZ2k(v)) { throw std::invalid_argument("value not invertible modulo 2^K"); + } std::size_t bits = 5; T z = ((v * 3) ^ 2); @@ -95,7 +96,7 @@ void InvertZ2k(T& v) { v = z; } -#define SCL_MASK(T, K) ((static_cast(1) << K) - 1) +#define SCL_MASK(T, K) ((static_cast(1) << (K)) - 1) /** * @brief Compute equality modulo a power of 2. diff --git a/include/scl/net/channel.h b/include/scl/net/channel.h index 75c9d4d..1a4becb 100644 --- a/include/scl/net/channel.h +++ b/include/scl/net/channel.h @@ -51,8 +51,8 @@ namespace scl { #define MAX_MAT_READ_SIZE 1 << 25 #endif -#define SCL_CC(x) reinterpret_cast(x) -#define SCL_C(x) reinterpret_cast(x) +#define SCL_CC(x) reinterpret_cast((x)) +#define SCL_C(x) reinterpret_cast((x)) /** * @brief Abstract channel for communicating between two peers. @@ -85,7 +85,14 @@ class Channel { * @param n how much data to receive * @return how many bytes were received. */ - virtual int Recv(unsigned char* dst, std::size_t n) = 0; + virtual std::size_t Recv(unsigned char* dst, std::size_t n) = 0; + + /** + * @brief Check if there is something to receive on this channel. + * @return true if this channel has data and false otherwise. + * @note the default implementation always returns true. + */ + virtual bool HasData() { return true; }; /** * @brief Send a trivially copyable item. @@ -108,6 +115,7 @@ class Channel { template , bool> = true> void Send(const std::vector& src) { + Send(src.size()); Send(SCL_CC(src.data()), sizeof(T) * src.size()); } @@ -168,15 +176,16 @@ class Channel { /** * @brief Receive a vector of trivially copyable items. - * - * dst.size() determines how many bytes to receive. - * * @param dst where to store the received items + * @note any existing content in \p dst is overwritten. */ template , bool> = true> void Recv(std::vector& dst) { - Recv(SCL_C(dst.data()), sizeof(T) * dst.size()); + std::size_t size; + Recv(size); + dst.resize(size); + Recv(SCL_C(dst.data()), sizeof(T) * size); } /** @@ -188,8 +197,9 @@ class Channel { template void Recv(Vec& vec) { auto vec_size = RecvSize(); - if (vec_size > MAX_VEC_READ_SIZE) + if (vec_size > MAX_VEC_READ_SIZE) { throw std::logic_error("received vector exceeds size limit"); + } auto n = vec_size * T::ByteSize(); auto buf = std::make_unique(n); Recv(SCL_C(buf.get()), n); @@ -206,8 +216,9 @@ class Channel { void Recv(Mat& mat) { auto rows = RecvSize(); auto cols = RecvSize(); - if (rows * cols > MAX_MAT_READ_SIZE) + if (rows * cols > MAX_MAT_READ_SIZE) { throw std::logic_error("received matrix exceeds size limit"); + } auto n = rows * cols * T::ByteSize(); auto buf = std::make_unique(n); Recv(SCL_C(buf.get()), n); diff --git a/include/scl/net/config.h b/include/scl/net/config.h index fd36d54..8831779 100644 --- a/include/scl/net/config.h +++ b/include/scl/net/config.h @@ -42,7 +42,7 @@ struct Party { /** * @brief The id of this party. */ - unsigned id; + int id; /** * @brief The hostname. @@ -68,7 +68,7 @@ class NetworkConfig { * @param id the identity of this party * @param filename the filename */ - static NetworkConfig Load(unsigned id, std::string filename); + static NetworkConfig Load(int id, const std::string& filename); /** * @brief Create a network config where all parties are running locally. @@ -82,14 +82,14 @@ class NetworkConfig { * @param size the size of the network * @param port_base the base port */ - static NetworkConfig Localhost(unsigned id, std::size_t size, int port_base); + static NetworkConfig Localhost(int id, int size, int port_base); /** * @brief Create a network config where all parties are running locally. * @param id the identity of this party * @param size the size of the network */ - static NetworkConfig Localhost(unsigned id, std::size_t size) { + static NetworkConfig Localhost(int id, int size) { return NetworkConfig::Localhost(id, size, DEFAULT_PORT_OFFSET); }; @@ -98,7 +98,7 @@ class NetworkConfig { * @param id the id of the local party * @param parties a list of parties */ - NetworkConfig(unsigned id, std::vector parties) + NetworkConfig(int id, const std::vector& parties) : mId(id), mParties(parties) { Validate(); }; @@ -112,7 +112,7 @@ class NetworkConfig { /** * @brief Gets the identity of this party. */ - unsigned Id() const { return mId; }; + int Id() const { return mId; }; /** * @brief Gets the size of the network. @@ -137,7 +137,7 @@ class NetworkConfig { private: void Validate(); - unsigned mId; + int mId; std::vector mParties; }; diff --git a/include/scl/net/discovery/client.h b/include/scl/net/discovery/client.h index ed81866..50e5bb1 100644 --- a/include/scl/net/discovery/client.h +++ b/include/scl/net/discovery/client.h @@ -41,14 +41,14 @@ class DiscoveryClient { * @param discovery_port the port of the discovery server * @param discovery_hostname the hostname of the discovery server */ - DiscoveryClient(std::string discovery_hostname, int discovery_port) + DiscoveryClient(const std::string& discovery_hostname, int discovery_port) : mHostname(discovery_hostname), mPort(discovery_port){}; /** * @brief Create a new client in a discovery protocol. * @param discovery_hostname the hostname of the discovery server */ - DiscoveryClient(std::string discovery_hostname) + DiscoveryClient(const std::string& discovery_hostname) : DiscoveryClient(discovery_hostname, DEFAULT_DISCOVERY_PORT){}; /** @@ -57,7 +57,7 @@ class DiscoveryClient { * @param port the port of this party * @return A network configuration. */ - NetworkConfig Run(unsigned id, int port); + NetworkConfig Run(int id, int port) const; class SendIdAndPort; class ReceiveNetworkConfig; @@ -77,15 +77,16 @@ class DiscoveryClient::SendIdAndPort /** * @brief Constructor. */ - SendIdAndPort(unsigned id, int port) : mId(id), mPort(port){}; + SendIdAndPort(int id, int port) : mId(id), mPort(port){}; /** * @brief Run this protocol step. */ - DiscoveryClient::ReceiveNetworkConfig Run(std::shared_ptr ctx); + DiscoveryClient::ReceiveNetworkConfig Run( + const std::shared_ptr& ctx) const; private: - unsigned mId; + int mId; int mPort; }; @@ -99,15 +100,15 @@ class DiscoveryClient::ReceiveNetworkConfig /** * @brief Constructor. */ - ReceiveNetworkConfig(unsigned id) : mId(id){}; + ReceiveNetworkConfig(int id) : mId(id){}; /** * @brief Finalize the discovery protocol. */ - NetworkConfig Finalize(std::shared_ptr ctx); + NetworkConfig Finalize(const std::shared_ptr& ctx) const; private: - unsigned mId; + int mId; }; } // namespace scl diff --git a/include/scl/net/discovery/server.h b/include/scl/net/discovery/server.h index 371c6d9..b02e22c 100644 --- a/include/scl/net/discovery/server.h +++ b/include/scl/net/discovery/server.h @@ -53,8 +53,9 @@ class DiscoveryServer { */ DiscoveryServer(int discovery_port, std::size_t number_of_parties) : mPort(discovery_port), mNumberOfParties(number_of_parties) { - if (number_of_parties > MAX_DISCOVER_PARTIES) + if (number_of_parties > MAX_DISCOVER_PARTIES) { throw std::invalid_argument("number_of_parties exceeds max"); + } }; /** @@ -69,7 +70,7 @@ class DiscoveryServer { * @param me ID, port and hostname information for this party * @return A network configuration. */ - NetworkConfig Run(const Party& me); + NetworkConfig Run(const Party& me) const; class CollectIdsAndPorts; class SendNetworkConfig; @@ -107,7 +108,7 @@ class DiscoveryServer::CollectIdsAndPorts /** * @brief Constructor. */ - CollectIdsAndPorts(std::vector hostnames) + CollectIdsAndPorts(const std::vector& hostnames) : mHostnames(hostnames){}; /** diff --git a/include/scl/net/mem_channel.h b/include/scl/net/mem_channel.h index a054b08..5e552b0 100644 --- a/include/scl/net/mem_channel.h +++ b/include/scl/net/mem_channel.h @@ -46,8 +46,9 @@ class InMemoryChannel final : public Channel { static std::array, 2> CreatePaired() { auto buf0 = std::make_shared(); auto buf1 = std::make_shared(); - return {std::make_shared(buf0, buf1), - std::make_shared(buf1, buf0)}; + auto chl0 = std::make_shared(buf0, buf1); + auto chl1 = std::make_shared(buf1, buf0); + return {chl0, chl1}; }; /** @@ -65,18 +66,21 @@ class InMemoryChannel final : public Channel { */ InMemoryChannel(std::shared_ptr in_buffer, std::shared_ptr out_buffer) - : mIn(in_buffer), mOut(out_buffer){}; + : mIn(std::move(in_buffer)), mOut(std::move(out_buffer)){}; /** * @brief Flush the incomming buffer; */ void Flush() { - while (mIn->Size()) mIn->PopFront(); + while (mIn->Size() > 0) { + mIn->PopFront(); + } mOverflow.clear(); }; void Send(const unsigned char* src, std::size_t n) override; - int Recv(unsigned char* dst, std::size_t n) override; + std::size_t Recv(unsigned char* dst, std::size_t n) override; + bool HasData() override { return mIn->Size() > 0 || !mOverflow.empty(); }; void Close() override{}; private: diff --git a/include/scl/net/network.h b/include/scl/net/network.h index 85b0ed2..b1ed6f3 100644 --- a/include/scl/net/network.h +++ b/include/scl/net/network.h @@ -21,6 +21,7 @@ #ifndef SCL_NET_NETWORK_H #define SCL_NET_NETWORK_H +#include #include #include #include @@ -75,7 +76,9 @@ class Network { * @brief Closes all channels in the network. */ void Close() { - for (auto c : mChannels) c->Close(); + for (auto& c : mChannels) { + c->Close(); + } }; private: @@ -153,7 +156,7 @@ Network Network::Create(const scl::NetworkConfig& config) { std::thread server(scl::details::SCL_AcceptConnections, std::ref(channels), config); - for (std::size_t i = 0; i < config.Id(); ++i) { + for (std::size_t i = 0; i < static_cast(config.Id()); ++i) { const auto party = config.GetParty(i); auto socket = scl::details::ConnectAsClient(party.hostname, party.port); std::shared_ptr channel = std::make_shared(socket); diff --git a/include/scl/net/tcp_channel.h b/include/scl/net/tcp_channel.h index 1bcbb85..f5af2fe 100644 --- a/include/scl/net/tcp_channel.h +++ b/include/scl/net/tcp_channel.h @@ -52,7 +52,8 @@ class TcpChannel final : public Channel { bool Alive() const { return mAlive; }; void Send(const unsigned char* src, std::size_t n) override; - int Recv(unsigned char* dst, std::size_t n) override; + std::size_t Recv(unsigned char* dst, std::size_t n) override; + bool HasData() override; void Close() override; private: diff --git a/include/scl/net/tcp_utils.h b/include/scl/net/tcp_utils.h index f10684a..fd75e70 100644 --- a/include/scl/net/tcp_utils.h +++ b/include/scl/net/tcp_utils.h @@ -64,7 +64,7 @@ AcceptedConnection AcceptConnection(int server_socket); /** * @brief Extra the hostname of an accepted connection. */ -std::string GetAddress(AcceptedConnection connection); +std::string GetAddress(const AcceptedConnection& connection); /** * @brief Connect in client mode. @@ -72,7 +72,7 @@ std::string GetAddress(AcceptedConnection connection); * @param port the port of the server * @return a socket. */ -int ConnectAsClient(std::string hostname, int port); +int ConnectAsClient(const std::string& hostname, int port); /** * @brief Close a socket. @@ -82,12 +82,12 @@ int CloseSocket(int socket); /** * @brief Read from a socket. */ -int ReadFromSocket(int socket, unsigned char* dst, std::size_t n); +ssize_t ReadFromSocket(int socket, unsigned char* dst, std::size_t n); /** * @brief Write to a socket. */ -int WriteToSocket(int socket, const unsigned char* src, std::size_t n); +ssize_t WriteToSocket(int socket, const unsigned char* src, std::size_t n); } // namespace details } // namespace scl diff --git a/include/scl/net/threaded_sender.h b/include/scl/net/threaded_sender.h index 55b8c84..f06e1b5 100644 --- a/include/scl/net/threaded_sender.h +++ b/include/scl/net/threaded_sender.h @@ -52,10 +52,12 @@ class ThreadedSenderChannel final : public Channel { mSendBuffer.PushBack({src, src + n}); }; - int Recv(unsigned char* dst, std::size_t n) override { + std::size_t Recv(unsigned char* dst, std::size_t n) override { return mChannel.Recv(dst, n); }; + bool HasData() override { return mChannel.HasData(); }; + private: TcpChannel mChannel; details::SharedDeque> mSendBuffer; diff --git a/include/scl/prg.h b/include/scl/prg.h index e633512..5699aa8 100644 --- a/include/scl/prg.h +++ b/include/scl/prg.h @@ -120,8 +120,9 @@ class PRG { * @param nbytes how many bytes to generate. */ void Next(std::vector &dest, std::size_t nbytes) { - if (dest.size() < nbytes) + if (dest.size() < nbytes) { throw std::invalid_argument("requested more randomness than dest.size()"); + } Next(dest.data(), nbytes); }; diff --git a/include/scl/ss/additive.h b/include/scl/ss/additive.h index 7976551..9d935af 100644 --- a/include/scl/ss/additive.h +++ b/include/scl/ss/additive.h @@ -38,7 +38,9 @@ namespace scl { */ template Vec CreateAdditiveShares(const T& secret, std::size_t n, PRG& prg) { - if (!n) throw std::invalid_argument("cannot create shares for 0 people"); + if (!n) { + throw std::invalid_argument("cannot create shares for 0 people"); + } Vec shares = Vec::PartialRandom( n, [](std::size_t i) { return i > 0; }, prg); diff --git a/include/scl/ss/poly.h b/include/scl/ss/poly.h index 61d2f05..04340a9 100644 --- a/include/scl/ss/poly.h +++ b/include/scl/ss/poly.h @@ -61,7 +61,9 @@ class Polynomial { auto it = mCoefficients.rbegin(); auto end = mCoefficients.rend(); auto y = *it++; - while (it != end) y = *it++ + y * x; + while (it != end) { + y = *it++ + y * x; + } return y; }; @@ -78,23 +80,23 @@ class Polynomial { /** * @brief Add two polynomials. */ - Polynomial Add(const Polynomial& p) const; + Polynomial Add(const Polynomial& q) const; /** * @brief Subtraction two polynomials. */ - Polynomial Subtract(const Polynomial& p) const; + Polynomial Subtract(const Polynomial& q) const; /** * @brief Multiply two polynomials. */ - Polynomial Multiply(const Polynomial& p) const; + Polynomial Multiply(const Polynomial& q) const; /** * @brief Divide two polynomials. * @return A pair \f$(q, r)\f$ such that \f$\mathtt{this} = p * q + r\f$. */ - std::array Divide(const Polynomial& p) const; + std::array Divide(const Polynomial& q) const; /** * @brief Returns true if this is the 0 polynomial. @@ -155,14 +157,18 @@ Polynomial Polynomial::Create(const Vec& coefficients) { auto cutoff = coefficients.Size(); T zero; for (; it != end; ++it) { - if (*it != zero) break; + if (*it != zero) { + break; + } --cutoff; } const auto c = Vec(coefficients.begin(), coefficients.begin() + cutoff); - if (!c.Size()) + + if (!c.Size()) { return Polynomial{}; - else - return Polynomial{c}; + } + + return Polynomial{c}; } /** @@ -175,7 +181,9 @@ template Vec PadCoefficients(const Polynomial& p, std::size_t n) { Vec c(n); for (std::size_t i = 0; i < n; ++i) { - if (i <= p.Degree()) c[i] = p[i]; + if (i <= p.Degree()) { + c[i] = p[i]; + } } return c; } // LCOV_EXCL_LINE @@ -226,28 +234,33 @@ Polynomial DivideLeadingTerms(const Polynomial& p, template std::array, 2> Polynomial::Divide( - const Polynomial& d) const { - if (d.IsZero()) throw std::invalid_argument("division by 0"); + const Polynomial& q) const { + if (q.IsZero()) { + throw std::invalid_argument("division by 0"); + } // https://en.wikipedia.org/wiki/Polynomial_long_division#Pseudocode - Polynomial q; + Polynomial p; Polynomial r = *this; - while (!r.IsZero() && r.Degree() >= d.Degree()) { - const auto t = DivideLeadingTerms(r, d); - q = q.Add(t); - r = r.Subtract(t.Multiply(d)); + while (!r.IsZero() && r.Degree() >= q.Degree()) { + const auto t = DivideLeadingTerms(r, q); + p = p.Add(t); + r = r.Subtract(t.Multiply(q)); } - return {q, r}; + return {p, r}; } template -std::string Polynomial::ToString(const char* pn, const char* vn) const { +std::string Polynomial::ToString(const char* polynomial_name, + const char* variable_name) const { std::stringstream ss; - ss << pn << "(" << vn << ") = " << mCoefficients[0]; + ss << polynomial_name << "(" << variable_name << ") = " << mCoefficients[0]; for (std::size_t i = 1; i < mCoefficients.Size(); i++) { - ss << " + " << mCoefficients[i] << vn; - if (i > 1) ss << "^" << i; + ss << " + " << mCoefficients[i] << variable_name; + if (i > 1) { + ss << "^" << i; + } } return ss.str(); } diff --git a/include/scl/ss/shamir.h b/include/scl/ss/shamir.h index 367c7b5..b38a4c2 100644 --- a/include/scl/ss/shamir.h +++ b/include/scl/ss/shamir.h @@ -22,8 +22,11 @@ #define SCL_SS_SHAMIR_H #include -#include +#include #include +#include +#include +#include #include "scl/math/la.h" #include "scl/math/vec.h" @@ -34,200 +37,42 @@ namespace scl { namespace details { /** - * @brief Create a random polynomial suitable for create Shamir secret-shares. - * @param secret the secret to embed in the constant term of the polynomial - * @param t the threshold that this polynomial should support - * @param prg a PRG used to generate random coefficients - * @return a Polynomial that can be used to generate degree t shares. - */ -template -details::Polynomial CreateShamirSharePolynomial(const T& secret, - std::size_t t, PRG& prg) { - if (!t) throw std::invalid_argument("threshold cannot be 0"); - auto coeff = Vec::PartialRandom( - t + 1, [](std::size_t i) { return i > 0; }, prg); - coeff[0] = secret; - return details::Polynomial::Create(coeff); -} - -/** - * @brief Interpolate a polynomial given a list of evaluations (f(x), x). - * @param ys the evaluation results - * @param xs the evaluation points - * @param k the number of points to use - * @param x the point at which to interpolate - * @param offset an offset into the \p ys and \p xs - * @return f(\p x) where f was interpolated from \p k of the provided points. - */ -template -T InterpolateAt(const Vec& ys, const Vec& xs, std::size_t k, const T& x, - std::size_t offset) { - T z; - for (std::size_t j = 0; j < k; ++j) { - T ell(1); - auto xj = xs[offset + j]; - for (std::size_t m = 0; m < k; ++m) { - if (m == j) continue; - auto xm = xs[offset + m]; - ell *= (x - xm) / (xj - xm); - } - z += ys[offset + j] * ell; - } - return z; -} -} // namespace details - -/** - * @brief Create a collection of shamir shares given a polynomial and points. - * @param sharing_polynomial the polynomial used to generate the shares - * @param alphas the evaluation points - * @tparam T a finite field - * @return A vector of shamir secret-shares. - */ -template -Vec CreateShamirShares(const details::Polynomial& sharing_polynomial, - const Vec& alphas) { - auto n = alphas.Size(); - Vec shares(n); - for (std::size_t i = 0; i < n; i++) { - shares[i] = sharing_polynomial.Evaluate(alphas[i]); - } - return shares; -} // LCOV_EXCL_LINE - -/** - * @brief Return the list of canonical evaluation points - * @param n the number of points - * @tparam T a finite field - * @return the list [1, 2, ..., n] of finite field elements. - */ -template -Vec CanonicalAlphas(std::size_t n) { - Vec alphas(n); - for (std::size_t i = 0; i < n; i++) alphas[i] = T{(int)(i + 1)}; - return alphas; -} // LCOV_EXCL_LINE - -/** - * @brief Create a shamir secret-sharing. - * @param secret the secret - * @param n the number of shares to create - * @param t the privacy threshold - * @param prg a PRG for randomness - * @return a vector of Shamir secret-shares. - */ -template -Vec CreateShamirShares(const T& secret, std::size_t n, std::size_t t, - PRG& prg) { - return CreateShamirShares( - details::CreateShamirSharePolynomial(secret, t, prg), - CanonicalAlphas(n)); -} - -/** - * @brief Reconstruct a shamir shared secret with passive security. - * @param shares the shars - * @param alphas the alphas - * @param pos the position of the secret - * @param t the threshold - * @return the reconstructed value. - * @throws std::invalid_argument if less than \f$t+1\f$ shares were given - * @throws std::invalid_argument if less than \f$t+1\f$ alphas were given - */ -template -T ReconstructShamirPassive(const Vec& shares, const Vec& alphas, - const T& pos, std::size_t t) { - if (t + 1 > shares.Size()) - throw std::invalid_argument("not enough shares to reconstruct"); - if (t + 1 > alphas.Size()) - throw std::invalid_argument("not enough alphas to reconstruct"); - return InterpolateAt(shares, alphas, t + 1, pos, 0); -} - -/** - * @brief Reconstruct a Shamir shared secret with passive security. - * - * This method makes no guarantee regarding the reconstructed value. In - * particular, this method should only be used in a passive security model. + * @brief Robust reconstruction. * - * The shares is assumed to have been generated using alphas as output by \ref - * scl::CanonicalAlphas, and the secret is assumed to have been placed on the - * constant term of the sharing polynomial. - * - * @param shares the shares - * @param t the threshold - * @return the reconstructed secret. - * @see ReconstructShamirPassive - */ -template -T ReconstructShamirPassive(const Vec& shares, std::size_t t) { - return ReconstructShamirPassive(shares, CanonicalAlphas(shares.Size()), - T(0), t); -} - -/** - * @brief Reconstruct a Shamir shared secret with error detection. + *

This function performs a robust Shamir secret-share reconstruction. Given + * at least \f$3t + 1\f$ pairs \f$(f(x),x)\f$, where each \f$f(x)\f$ is a share + * and \f$x\f$ a corresponding evaluation alpha, finds the degree \f$t\f$ + * polynomial \f$f\f$ that passes through all supplied shares.

* - * This method will attempt to reconstruct a degree \f$t\f$ shared secret from - * at least \f$2t+1\f$ provided points. In case the provided shares do not all - * lie on a degree \f$t\f$ polynomial, an exception is thrown. + *

The return value is a pair of polynomials \f$(f,e)\f$ where \f$f\f$ is the + * reconstructed polynomial and \f$e\f$ a polynomial whose roots indicate which + * (if any) of the input shares were invalid. I.e., \f$e(i)=0\f$ if + * share[i] had an error.

* - * @param shares the shares - * @param alphas the alphas - * @param pos the position of the secret - * @param t the threshold - * @return the correct reconstructed value. - * @throws std::invalid_argument if less than \f$2t+1\f$ shares were given - * @throws std::invalid_argument if less than \f$2t+1\f$ alphas were given - * @throws std::logic_error if one of the shares contained an error - */ -template -T ReconstructShamir(const Vec& shares, const Vec& alphas, const T& pos, - std::size_t t) { - if (2 * t + 1 > shares.Size()) - throw std::invalid_argument( - "not enough shares to reconstruct with error detection"); - if (2 * t + 1 > alphas.Size()) - throw std::invalid_argument( - "not enough alphas to reconstruct with error detection"); - - for (std::size_t k = t + 1; k < 2 * t + 1; ++k) { - auto s = InterpolateAt(shares, alphas, t + 1, alphas[k], 0); - if (s != shares[k]) - throw std::logic_error("error detected during reconstruction"); - } - return InterpolateAt(shares, alphas, t + 1, pos, 0); -} - -/** - * @brief Reconstruct a Shamir shared secret with error detection. - * @see ReconstructShamir - */ -template -T ReconstructShamir(const Vec& shares, std::size_t t) { - return ReconstructShamir(shares, CanonicalAlphas(shares.Size()), T(0), t); -} - -/** - * @brief Reconstruct a Shamir shared secret with error correction. + * @param shares the shares to use for reconstruction + * @param alphas the evaluation alphas + * @param t the degree of the polynomial to reconstruct + * @return a pair of polynomials. */ template -std::array, 2> ReconstructShamirRobust( - const Vec& shares, const Vec& alphas, std::size_t t) { +auto ReconstructShamirRobust(const Vec& shares, const Vec& alphas, + std::size_t t) { std::size_t n = 3 * t + 1; - if (n > shares.Size()) + if (n > shares.Size()) { throw std::invalid_argument( "not enough shares to reconstruct with error correction"); - if (n > alphas.Size()) + } + if (n > alphas.Size()) { throw std::invalid_argument( "not enough alphas to reconstruct with error correction"); + } Mat A(n); Vec b(n); Vec x(n); int e; for (std::size_t k = 0; k <= t; ++k) { - e = t - k; + e = t - k; // NOLINT for (std::size_t i = 0; i < n; ++i) { b[i] = -shares[i]; @@ -243,7 +88,9 @@ std::array, 2> ReconstructShamirRobust( } } - if (SolveLinearSystem(x, A, b)) break; + if (SolveLinearSystem(x, A, b)) { + break; + } } Vec cE{x.begin(), x.begin() + e + 1}; @@ -252,21 +99,303 @@ std::array, 2> ReconstructShamirRobust( auto E = details::Polynomial::Create(cE); auto Q = details::Polynomial::Create(Vec{x.begin() + e, x.end()}); auto qr = Q.Divide(E); - if (qr[1].IsZero()) return {qr[0], E}; - throw std::logic_error("could not correct shares"); + if (!qr[1].IsZero()) { + throw std::logic_error("could not correct shares"); + } + + return std::make_pair(qr[0], E); } /** - * @brief Reconstruct a Shamir shared secret with error correction. + * @brief Robust reconstruction. + * + * This function is a short-hand for ReconstructShamirRobust(shares, + * Vec::Range(1, shares.Size() + 1), t). + * + * @param shares the shares + * @param t the degree of the polynomial to reconstruct + * @return \f$f(0)\f$ where \f$f\f$ is the reconstructed polynomial. */ template T ReconstructShamirRobust(const Vec& shares, std::size_t t) { auto p = - ReconstructShamirRobust(shares, CanonicalAlphas(shares.Size()), t); - return p[0].Evaluate(T(0)); + ReconstructShamirRobust(shares, Vec::Range(1, shares.Size() + 1), t); + return std::get<0>(p).Evaluate(T(0)); +} + +/** + * @brief Defines some common security levels. + * + * These security levels are used to derive some common defaults. For example, + * SecurityLevel::DETECT is used to indicate that ShamirSSFactory, when + * instantiated with a threshold of \f$t\f$, should generate \f$2t + 1\f$ shares + * if no other value is specified. + */ +enum class SecurityLevel { + /** + * @brief \f$t + 1\f$. + */ + PASSIVE, + + /** + * @brief \f$2t + 1\f$. Enough shares to detect errors. + */ + DETECT, + + /** + * @brief \f$3t + 1\f$. Enough shares to correct errors. + */ + CORRECT +}; + +/** + * @brief Class for reconstructing secrets from Shamir secret-shares. + * + *

The main purpose of this class is to cache the lagrange coefficients used + * when performing interpolation. These coefficients are computed when they are + * first needed. The only exception is the coefficients needed to reconstruct + * the point at 0, which is usually where the secret is placed.

+ * + *

A Reconstructor contains essentially two methods for reconstructing: + *

    + *
  • Reconstruct reconstructs a particular evaluation of the implied + * polynomial.
  • + *
  • ReconstructShare reconstructs a particular share of a party
  • + *
+ * Both methods do effectively the same, and the only difference is how the + * index passed to them is interpreted. For the first, the index is interpreted + * as-is. I.e., if we pass an index of 3, then it will recover the point + * \f$f(3)\f$, where \f$f\f$ is the polynomial implied by the shares we provide. + * For the latter method, passing an index of \f$i\f$ implies that we wish to + * interpolate the point \f$f(i + 1)\f$. This is because we assume parties are + * indexed from 0, but their evaluation indices are indexed from 1 (because the + * 0'th index is reserved for the secret).

+ */ +template +class Reconstructor { + public: + /** + * @brief Create a new Reconstructor instance. + * + * This creator method creates a new instance and pre-computes the + * coefficients needed to recover a secret from a set of shamir shares. The + * SecurityLevel given is used to infer the default reconstruction method: + *
    + *
  • SecurityLevel::PASSIVE: use passive reconstruction. I.e., reconstruct + * the value we ask for, using the minimal amount of \f$t + 1\f$ shares + * needed.
  • + * + *
  • SecurityLevel::DETECT: use the first \f$t+1\f$ shares to reconstruct + * the remaning \f$t\f$ shares. If all of these check out, then reconstruct + * the value we seek.
  • + * + *
  • SecurityLevel::CORRECT: use an error correcting + * algorithm to recover the polynomial \f$f\f$ from \f$3t+1\f$ shares, + * assuming that at most \f$t\f$ shares are faulty.
  • + *
+ * + * @param threshold the threshold \f$t\f$ + * @param default_security_level the default security level + */ + static Reconstructor Create(std::size_t threshold, + SecurityLevel default_security_level); + + /** + * @brief Interpolate a set of shares according to a given security level + * @param shares the shares + * @param security_level the security level + * @param index the index to interpolate. Defaults to 0 + */ + T Reconstruct(const Vec& shares, SecurityLevel security_level, + int index = 0) const; + + /** + * @brief Interpolate a set of shares with the default security level + * @param shares the shares + * @param index the index to interpolate. Defaults to 0 + */ + T Reconstruct(const Vec& shares, int index = 0) const { + return Reconstruct(shares, mSecurityLevel, index); + }; + + /** + * @brief Reconstruct the share of a party. + */ + T ReconstructShare(const Vec& shares, SecurityLevel level, + int party_index = 0) const { + return Reconstruct(shares, level, party_index + 1); + }; + + /** + * @brief Reconstruct the share of a party. + */ + T ReconstructShare(const Vec& shares, int party_index = 0) const { + return Reconstruct(shares, party_index + 1); + }; + + /** + * @brief Get the lagrange coefficients for interpolating a particular index. + * + * This method modifies the internal cache in case coefficients for the + * requested index does not exist. + */ + const Vec& GetLagrangeCoefficients(int index) const { + if (mLagrangeCoeff.count(index) == 0) { + ComputeLagrangeCoefficients(index); + } + return mLagrangeCoeff[index]; + }; + + private: + Reconstructor(std::size_t threshold, SecurityLevel security_level) + : mThreshold(threshold), mSecurityLevel(security_level){}; + + void ComputeLagrangeCoefficients(int index) const; + + std::size_t mThreshold; + SecurityLevel mSecurityLevel; + mutable std::unordered_map> mLagrangeCoeff; +}; + +template +Reconstructor Reconstructor::Create( + std::size_t threshold, SecurityLevel default_security_level) { + Reconstructor intr(threshold, default_security_level); + intr.ComputeLagrangeCoefficients(0); + return intr; } +template +T Reconstructor::Reconstruct(const Vec& shares, + SecurityLevel security_level, int index) const { + if (security_level == SecurityLevel::CORRECT) { + // TODO(anders): Currently only supports index = 0 + return details::ReconstructShamirRobust(shares, mThreshold); + } + + const auto min_size = mThreshold + 1; + if (shares.Size() < min_size) { + throw std::invalid_argument("not enough shares to reconstruct"); + } + + const auto& coeff = GetLagrangeCoefficients(index); + auto x = details::UncheckedInnerProd( + shares.begin(), shares.begin() + min_size, coeff.begin()); + + if (security_level == SecurityLevel::DETECT) { + const auto n = 2 * mThreshold + 1; + if (shares.Size() < n) { + throw std::invalid_argument("not enough shares to detect errors"); + } + + for (std::size_t i = min_size; i < n; ++i) { + const auto& ccoeff = GetLagrangeCoefficients(i + 1); + const auto c = details::UncheckedInnerProd( + shares.begin(), shares.begin() + min_size, ccoeff.begin()); + if (c != shares[i]) { + throw std::logic_error("error detected during reconstruction"); + } + } + } + return x; +} + +template +void Reconstructor::ComputeLagrangeCoefficients(int index) const { + Vec coeff(mThreshold + 1); + const auto x = T(index); + for (std::size_t j = 0; j <= mThreshold; ++j) { + auto ell = T::One(); + const auto xj = T(j + 1); + for (std::size_t m = 0; m <= mThreshold; ++m) { + if (j != m) { + const auto xm = T(m + 1); + ell *= (x - xm) / (xj - xm); + } + } + coeff[j] = ell; + } + mLagrangeCoeff[index] = coeff; +} + +/** + * @brief A factory object for creating Shamir secret shares. + * @tparam T the finite field to use. + */ +template +class ShamirSSFactory { + public: + /** + * @brief Create a new Shamir secret share factory + * @param t the privacy threshold to use + * @param prg the PRG to use when generating new shares + * @param security_level the SecurityLevel to use. Default to + * SecurityLevel::DETECT + */ + ShamirSSFactory(std::size_t t, PRG& prg, + SecurityLevel security_level = SecurityLevel::DETECT) + : mThreshold(t), mPrg(prg), mDefaultSecurityLevel(security_level){}; + + /** + * @brief Create a new set of Shamir secret shares of a secret + * @param secret the secret + * @param number_of_shares the number of shares to generate. If empty, then + * the number of shares is determined by GetDefaultNumberOfShares + * @return a vector of shares. + */ + Vec Share(const T& secret, + std::optional number_of_shares = {}); + + /** + * @brief Get the number of shares to create based on set security level + * + * The returned number is determined based on the \ref SecurityLevel + * provided during construction. Given a threshold \f$t\f$, then the number of + * shares is computed based on the following rules: + */ + std::size_t GetDefaultNumberOfShares() const { + switch (mDefaultSecurityLevel) { + case SecurityLevel::PASSIVE: + return mThreshold + 1; + case SecurityLevel::DETECT: + return 2 * mThreshold + 1; + default: // SecurityLevel::ROBUST: + return 3 * mThreshold + 1; + } + }; + + /** + * @brief Get an scl::Interpolator suitable for this factory. + */ + Reconstructor GetInterpolator() const { + return Reconstructor::Create(mThreshold, mDefaultSecurityLevel); + }; + + private: + std::size_t mThreshold; + PRG mPrg; + SecurityLevel mDefaultSecurityLevel; +}; + +template +Vec ShamirSSFactory::Share(const T& secret, + std::optional number_of_shares) { + auto coeff = Vec::PartialRandom( + mThreshold + 1, [](std::size_t i) { return i > 0; }, mPrg); + coeff[0] = secret; + auto p = details::Polynomial::Create(coeff); + + auto n = number_of_shares.value_or(GetDefaultNumberOfShares()); + std::vector shares; + shares.reserve(n); + for (std::size_t i = 0; i < n; ++i) { + shares.emplace_back(p.Evaluate(T(i + 1))); + } + return Vec(shares); +} + +} // namespace details } // namespace scl #endif // SCL_SS_SHAMIR_H diff --git a/src/scl/hash.cc b/src/scl/hash.cc index 5706436..6c34d9b 100644 --- a/src/scl/hash.cc +++ b/src/scl/hash.cc @@ -29,13 +29,16 @@ void scl::Keccakf(uint64_t state[25]) { uint64_t bc[5]; for (std::size_t round = 0; round < 24; ++round) { - for (std::size_t i = 0; i < 5; ++i) + for (std::size_t i = 0; i < 5; ++i) { bc[i] = state[i] ^ state[i + 5] ^ state[i + 10] ^ state[i + 15] ^ state[i + 20]; + } for (std::size_t i = 0; i < 5; ++i) { t = bc[(i + 4) % 5] ^ rotl64(bc[(i + 1) % 5], 1); - for (std::size_t j = 0; j < 25; j += 5) state[j + i] ^= t; + for (std::size_t j = 0; j < 25; j += 5) { + state[j + i] ^= t; + } } t = state[1]; @@ -47,9 +50,12 @@ void scl::Keccakf(uint64_t state[25]) { } for (std::size_t j = 0; j < 25; j += 5) { - for (std::size_t i = 0; i < 5; ++i) bc[i] = state[j + i]; - for (std::size_t i = 0; i < 5; ++i) + for (std::size_t i = 0; i < 5; ++i) { + bc[i] = state[j + i]; + } + for (std::size_t i = 0; i < 5; ++i) { state[j + i] ^= (~bc[(i + 1) % 5]) & bc[(i + 2) % 5]; + } } state[0] ^= keccakf_rndc[round]; diff --git a/src/scl/math/mersenne127.cc b/src/scl/math/mersenne127.cc index 3b9962e..d024a5a 100644 --- a/src/scl/math/mersenne127.cc +++ b/src/scl/math/mersenne127.cc @@ -57,8 +57,10 @@ struct u256 { // https://cp-algorithms.com/algebra/montgomery_multiplication.html u256 MultiplyFull(const u128 x, const u128 y) { - u64 a = x >> 64, b = x; - u64 c = y >> 64, d = y; + u64 a = x >> 64; + u64 b = x; + u64 c = y >> 64; + u64 d = y; // (a*2^64 + b) * (c*2^64 + d) = // (a*c) * 2^128 + (a*d + b*c)*2^64 + (b*d) u128 ac = (u128)a * c; @@ -66,9 +68,9 @@ u256 MultiplyFull(const u128 x, const u128 y) { u128 bc = (u128)b * c; u128 bd = (u128)b * d; - u128 carry = (u128)(u64)ad + (u128)(u64)bc + (bd >> 64u); - u128 high = ac + (ad >> 64u) + (bc >> 64u) + (carry >> 64u); - u128 low = (ad << 64u) + (bc << 64u) + bd; + u128 carry = (u128)(u64)ad + (u128)(u64)bc + (bd >> 64U); + u128 high = ac + (ad >> 64U) + (bc >> 64U) + (carry >> 64U); + u128 low = (ad << 64U) + (bc << 64U) + bd; return {high, low}; } @@ -115,8 +117,8 @@ std::string scl::details::FieldToString(const u128& in) { } template <> -void scl::details::FieldFromString(u128& dest, - const std::string& str) { - dest = FromHexString(str); - dest = dest % p; +void scl::details::FieldFromString(u128& out, + const std::string& src) { + out = FromHexString(src); + out = out % p; } diff --git a/src/scl/math/mersenne61.cc b/src/scl/math/mersenne61.cc index 2a72759..2866028 100644 --- a/src/scl/math/mersenne61.cc +++ b/src/scl/math/mersenne61.cc @@ -90,8 +90,8 @@ std::string scl::details::FieldToString(const u64& in) { } template <> -void scl::details::FieldFromString(u64& dest, - const std::string& str) { - dest = FromHexString(str); - dest = dest % p; +void scl::details::FieldFromString(u64& out, + const std::string& src) { + out = FromHexString(src); + out = out % p; } diff --git a/src/scl/math/number.cc b/src/scl/math/number.cc index 8a46cf6..0134768 100644 --- a/src/scl/math/number.cc +++ b/src/scl/math/number.cc @@ -31,7 +31,7 @@ scl::Number::Number(const Number& number) : Number() { mpz_set(mValue, number.mValue); } -scl::Number::Number(Number&& number) : Number() { +scl::Number::Number(Number&& number) noexcept : Number() { mpz_set(mValue, number.mValue); } @@ -47,7 +47,7 @@ scl::Number scl::Number::Random(std::size_t bits, PRG& prg) { scl::Number r; mpz_import(r.mValue, len - 1, 1, 1, 0, 0, data.get() + 1); - if (data[0] & 1) { + if ((data[0] & 1) != 0) { mpz_neg(r.mValue, r.mValue); } return r; @@ -92,23 +92,23 @@ scl::Number scl::Number::operator/(const Number& number) const { } // LCOV_EXCL_LINE scl::Number scl::Number::operator<<(int shift) const { + scl::Number shifted; if (shift < 0) { - return operator>>(-shift); + shifted = operator>>(-shift); } else { - scl::Number shifted; mpz_mul_2exp(shifted.mValue, mValue, shift); - return shifted; } + return shifted; } scl::Number scl::Number::operator>>(int shift) const { + scl::Number shifted; if (shift < 0) { - return operator<<(-shift); + shifted = operator<<(-shift); } else { - scl::Number shifted; mpz_tdiv_q_2exp(shifted.mValue, mValue, shift); - return shifted; } + return shifted; } scl::Number scl::Number::operator^(const Number& number) const { diff --git a/src/scl/math/ops_gmp_ff.h b/src/scl/math/ops_gmp_ff.h index 4d947ee..704d14a 100644 --- a/src/scl/math/ops_gmp_ff.h +++ b/src/scl/math/ops_gmp_ff.h @@ -36,11 +36,11 @@ namespace details { #define BITS_PER_LIMB static_cast(mp_bits_per_limb) #define BYTES_PER_LIMB sizeof(mp_limb_t) -#define SCL_COPY(out, in, size) \ - do { \ - for (std::size_t i = 0; i < size; ++i) { \ - *((out) + i) = *((in) + i); \ - } \ +#define SCL_COPY(out, in, size) \ + do { \ + for (std::size_t i = 0; i < (size); ++i) { \ + *((out) + i) = *((in) + i); \ + } \ } while (0) /** @@ -198,8 +198,8 @@ inline bool TestBit(const mp_limb_t* v, std::size_t pos) { template void ModExp(mp_limb_t* out, const mp_limb_t* x, const mp_limb_t* e, const mp_limb_t* mod, const mp_limb_t* np) { - auto n = mpn_scan1(e, N * BITS_PER_LIMB); - for (int i = n - 1; i >= 0; --i) { + auto n = mpn_sizeinbase(e, N, 2); + for (std::size_t i = n; i-- > 0;) { ModSqr(out, out, mod, np); if (TestBit(e, i)) { ModMul(out, x, mod, np); @@ -262,25 +262,29 @@ std::string ToString(const mp_limb_t* val, const mp_limb_t* mod, } } auto s = ss.str(); + // trim leading 0s auto n = FindFirstNonZero(s); if (n > 0) { s = s.substr(n, s.length() - 1); } + if (s.length()) { return s; - } else { - return "0"; } + + return "0"; } template void FromString(mp_limb_t* out, const mp_limb_t* mod, const std::string& str) { if (str.length()) { - auto n = str.length(); - if (n > 64) { + auto n_ = str.length(); + if (n_ > 64) { throw std::invalid_argument("hex string too large to parse"); } + // to silence conversion errors. Safe to do because n_ is pretty small. + int n = static_cast(n_); std::string s = str; if (n % 2) { @@ -288,10 +292,10 @@ void FromString(mp_limb_t* out, const mp_limb_t* mod, const std::string& str) { n++; } - const auto m = 2 * BYTES_PER_LIMB; + const auto m = static_cast(2 * BYTES_PER_LIMB); int c = (n - 1) / m; auto beg = s.begin(); - for (std::size_t i = 0; i < n && c >= 0; i += m) { + for (int i = 0; i < n && c >= 0; i += m) { auto end = std::min(n, i + m); out[c--] = FromHexString(std::string(beg + i, beg + end)); } @@ -299,7 +303,6 @@ void FromString(mp_limb_t* out, const mp_limb_t* mod, const std::string& str) { } } -#undef SCL_COPY #undef BITS_PER_LIMB #undef BYTES_PER_LIMB diff --git a/src/scl/math/ops_small_fp.h b/src/scl/math/ops_small_fp.h index 3eeb39b..0c5b2f2 100644 --- a/src/scl/math/ops_small_fp.h +++ b/src/scl/math/ops_small_fp.h @@ -32,7 +32,9 @@ namespace details { template void ModAdd(T& t, const T& v, const T& m) { t = t + v; - if (t >= m) t = t - m; + if (t >= m) { + t = t - m; + } } /** @@ -40,10 +42,11 @@ void ModAdd(T& t, const T& v, const T& m) { */ template void ModSub(T& t, const T& v, const T& m) { - if (v > t) + if (v > t) { t = t + m - v; - else + } else { t = t - v; + } } /** @@ -51,7 +54,9 @@ void ModSub(T& t, const T& v, const T& m) { */ template void ModNeg(T& t, const T& m) { - if (t) t = m - t; + if (t) { + t = m - t; + } } /** @@ -62,8 +67,8 @@ void ModInv(T& t, const T& v, const T& m) { #define SCL_PARALLEL_ASSIGN(v1, v2, q) \ do { \ const auto __temp = v2; \ - v2 = v1 - q * __temp; \ - v1 = __temp; \ + (v2) = (v1) - (q)*__temp; \ + (v1) = __temp; \ } while (0) if (v == 0) { diff --git a/src/scl/math/secp256k1_curve.cc b/src/scl/math/secp256k1_curve.cc index c5d35ca..afa5b9f 100644 --- a/src/scl/math/secp256k1_curve.cc +++ b/src/scl/math/secp256k1_curve.cc @@ -18,6 +18,7 @@ * along with this program. If not, see . */ +#include #include #include "./secp256k1_extras.h" @@ -34,9 +35,9 @@ using Point = Curve::ValueType; #define POINT_AT_INFINITY Point{{Field{0}, Field{1}, Field{0}}} // clang-format on -#define _X(point) (point)[0] -#define _Y(point) (point)[1] -#define _Z(point) (point)[2] +#define GET_X(point) (point)[0] +#define GET_Y(point) (point)[1] +#define GET_Z(point) (point)[2] static const Field kCurveB(7); @@ -66,42 +67,37 @@ void scl::details::CurveSetAffine(Point& out, const Field& x, } template <> -bool scl::details::CurveEqual(const Point& in1, const Point& in2) { - const auto Z1 = _Z(in1); - const auto Z2 = _Z(in2); - // (X1, Y1, Z1) eqv (X2, Y2, Z2) <==> (X1 * Z2, Y1 * Z2) == (X2 * Z1, Y2 * Z2) - return _X(in1) * Z2 == _X(in2) * Z1 && _Y(in1) * Z2 == _Y(in2) * Z1; +std::array scl::details::CurveToAffine(const Point& point) { + const auto Z = GET_Z(point); + return {GET_X(point) / Z, GET_Y(point) / Z}; } template <> -bool scl::details::CurveIsPointAtInfinity(const Point& out) { - return CurveEqual(out, POINT_AT_INFINITY); +bool scl::details::CurveEqual(const Point& in1, const Point& in2) { + const auto Z1 = GET_Z(in1); + const auto Z2 = GET_Z(in2); + // (X1, Y1, Z1) eqv (X2, Y2, Z2) <==> (X1 * Z2, Y1 * Z2) == (X2 * Z1, Y2 * Z2) + return GET_X(in1) * Z2 == GET_X(in2) * Z1 && + GET_Y(in1) * Z2 == GET_Y(in2) * Z1; } -struct AffinePoint { - Field x; - Field y; -}; - -namespace { - -AffinePoint ToAffine(const Point& point) { - const auto Z = _Z(point); - return {_X(point) / Z, _Y(point) / Z}; +template <> +bool scl::details::CurveIsPointAtInfinity(const Point& point) { + return CurveEqual(point, POINT_AT_INFINITY); } -} // namespace - template <> std::string scl::details::CurveToString(const Point& point) { + std::string str; if (CurveIsPointAtInfinity(point)) { - return "EC{POINT_AT_INFINITY}"; + str = "EC{POINT_AT_INFINITY}"; } else { - auto ap = ToAffine(point); + auto ap = CurveToAffine(point); std::stringstream ss; - ss << "EC{" << ap.x << ", " << ap.y << "}"; - return ss.str(); + ss << "EC{" << ap[0] << ", " << ap[1] << "}"; + str = ss.str(); } + return str; } template <> @@ -119,12 +115,12 @@ void scl::details::CurveSetGenerator(Point& out) { template <> void scl::details::CurveDouble(Point& out) { if (!CurveIsPointAtInfinity(out)) { - if (_Y(out) == Field::Zero()) { + if (GET_Y(out) == Field::Zero()) { CurveSetPointAtInfinity(out); } else if (!CurveIsPointAtInfinity(out)) { - const auto X = _X(out); - const auto Y = _Y(out); - const auto Z = _Z(out); + const auto X = GET_X(out); + const auto Y = GET_Y(out); + const auto Z = GET_Z(out); const auto W = Field(3) * X * X; const auto S = Y * Z; @@ -141,16 +137,16 @@ void scl::details::CurveDouble(Point& out) { } template <> -void scl::details::CurveAdd(Point& out, const Point& op) { +void scl::details::CurveAdd(Point& out, const Point& in) { if (CurveIsPointAtInfinity(out)) { - out = op; - } else if (!CurveIsPointAtInfinity(op)) { - const auto X1 = _X(out); - const auto Y1 = _Y(out); - const auto Z1 = _Z(out); - const auto X2 = _X(op); - const auto Y2 = _Y(op); - const auto Z2 = _Z(op); + out = in; + } else if (!CurveIsPointAtInfinity(in)) { + const auto X1 = GET_X(out); + const auto Y1 = GET_Y(out); + const auto Z1 = GET_Z(out); + const auto X2 = GET_X(in); + const auto Y2 = GET_Y(in); + const auto Z2 = GET_Z(in); const auto U1 = Y2 * Z1; const auto U2 = Y1 * Z2; @@ -180,16 +176,16 @@ void scl::details::CurveAdd(Point& out, const Point& op) { template <> void scl::details::CurveNegate(Point& out) { - if (_Y(out) == Field::Zero()) { + if (GET_Y(out) == Field::Zero()) { CurveSetPointAtInfinity(out); } else { - _Y(out).Negate(); + GET_Y(out).Negate(); } } template <> -void scl::details::CurveSubtract(Point& out, const Point& op) { - Point copy(op); +void scl::details::CurveSubtract(Point& out, const Point& in) { + Point copy(in); CurveNegate(copy); CurveAdd(out, copy); } @@ -201,7 +197,8 @@ void scl::details::CurveScalarMultiply(Point& out, const auto n = scalar.BitSize(); Point res; CurveSetPointAtInfinity(res); - for (int i = n - 1; i >= 0; --i) { + // equivalent to for (int i = n - 1; i >= 0; i--) + for (auto i = n; i-- > 0;) { CurveDouble(res); if (scalar.TestBit(i)) { CurveAdd(res, out); @@ -215,12 +212,13 @@ template <> void scl::details::CurveScalarMultiply(Point& out, const FF& scalar) { if (!CurveIsPointAtInfinity(out)) { - const auto n = scl::SCL_FF_Extras::HigestSetBit(scalar); + auto x = scl::SCL_FF_Extras::FromMonty(scalar); + const auto n = scl::SCL_FF_Extras::HigestSetBit(x); Point res; CurveSetPointAtInfinity(res); - for (int i = n - 1; i >= 0; --i) { + for (auto i = n; i-- > 0;) { CurveDouble(res); - if (scl::SCL_FF_Extras::TestBit(scalar, i)) { + if (scl::SCL_FF_Extras::TestBit(x, i)) { CurveAdd(res, out); } } @@ -232,9 +230,9 @@ void scl::details::CurveScalarMultiply(Point& out, #define POINT_AT_INFINITY_FLAG 0x02 #define SELECT_SMALLER_FLAG 0x01 -#define IS_COMPRESSED(flags) (flags & COMPRESSED_FLAG) -#define IS_POINT_AT_INFINITY(flags) (flags & POINT_AT_INFINITY_FLAG) -#define SELECT_SMALLER(flags) (flags & SELECT_SMALLER_FLAG) +#define IS_COMPRESSED(flags) ((flags)&COMPRESSED_FLAG) +#define IS_POINT_AT_INFINITY(flags) ((flags)&POINT_AT_INFINITY_FLAG) +#define SELECT_SMALLER(flags) ((flags)&SELECT_SMALLER_FLAG) namespace { @@ -272,9 +270,9 @@ void scl::details::CurveFromBytes(Point& out, const unsigned char* src) { auto smaller = IsSmaller(y, yn); auto select_smaller = SELECT_SMALLER(flags); if (smaller) { - out[1] = select_smaller ? y : yn; + out[1] = select_smaller == 0 ? yn : y; } else { - out[1] = select_smaller ? yn : y; + out[1] = select_smaller == 0 ? y : yn; } } else { out[0] = Field::Read(src + 1); @@ -284,41 +282,43 @@ void scl::details::CurveFromBytes(Point& out, const unsigned char* src) { } } -#define MARK_COMPRESSED(buf) (*buf |= COMPRESSED_FLAG) -#define MARK_POINT_AT_INFINITY(buf) (*buf |= POINT_AT_INFINITY_FLAG) -#define MARK_SELECT_SMALLER(buf) (*buf |= SELECT_SMALLER_FLAG) +#define MARK_COMPRESSED(buf) (*(buf) |= COMPRESSED_FLAG) +#define MARK_POINT_AT_INFINITY(buf) (*(buf) |= POINT_AT_INFINITY_FLAG) +#define MARK_SELECT_SMALLER(buf) (*(buf) |= SELECT_SMALLER_FLAG) template <> -void scl::details::CurveToBytes(unsigned char* dest, const Point& point, +void scl::details::CurveToBytes(unsigned char* dest, const Point& in, bool compress) { // Make sure flag byte is zeroed. *dest = 0; // if point is compressed, mark it as such. - if (compress) MARK_COMPRESSED(dest); + if (compress) { + MARK_COMPRESSED(dest); + } - if (CurveIsPointAtInfinity(point)) { + if (CurveIsPointAtInfinity(in)) { MARK_POINT_AT_INFINITY(dest); // zero rest of the buffer to ensure we can always safely send the right // amount of bytes. std::memset(dest + 1, 0, compress ? 32 : 64); } else { - const auto ap = ToAffine(point); + const auto ap = CurveToAffine(in); // if compression is used, we indicate a bit indicating which of {y, -y} is // the smaller, and the only write the x coordinate. Otherwise we write both // x and y. if (compress) { // include a flag which indicates which of {y, -y} is the smaller. - const auto y = ap.y; + const auto y = ap[1]; const auto yn = y.Negated(); if (IsSmaller(y, yn)) { MARK_SELECT_SMALLER(dest); } - ap.x.Write(dest + 1); + ap[0].Write(dest + 1); } else { - ap.x.Write(dest + 1); - ap.y.Write(dest + 1 + Field::ByteSize()); + ap[0].Write(dest + 1); + ap[1].Write(dest + 1 + Field::ByteSize()); } } } diff --git a/src/scl/math/secp256k1_extras.h b/src/scl/math/secp256k1_extras.h index 40fbffa..c6d5844 100644 --- a/src/scl/math/secp256k1_extras.h +++ b/src/scl/math/secp256k1_extras.h @@ -42,6 +42,10 @@ struct SCL_FF_Extras { template <> struct SCL_FF_Extras { + // Convert an element out of montgomery representation + static scl::FF FromMonty( + const scl::FF& element); + // Get position of the highest set bit static std::size_t HigestSetBit( const scl::FF& element); diff --git a/src/scl/math/secp256k1_field.cc b/src/scl/math/secp256k1_field.cc index 20d4506..c78ddcd 100644 --- a/src/scl/math/secp256k1_field.cc +++ b/src/scl/math/secp256k1_field.cc @@ -18,6 +18,7 @@ * along with this program. If not, see . */ +#include #include #include "./ops_gmp_ff.h" @@ -94,24 +95,24 @@ void scl::details::FieldInvert(Elem& out) { } template <> -bool scl::details::FieldEqual(const Elem& first, const Elem& second) { - return CompareValues(PTR(first), PTR(second)) == 0; +bool scl::details::FieldEqual(const Elem& in1, const Elem& in2) { + return CompareValues(PTR(in1), PTR(in2)) == 0; } template <> -void scl::details::FieldFromBytes(Elem& out, const unsigned char* src) { - ValueFromBytes(PTR(out), src); +void scl::details::FieldFromBytes(Elem& dest, const unsigned char* src) { + ValueFromBytes(PTR(dest), src); } template <> -std::string scl::details::FieldToString(const Elem& element) { - return ToString(PTR(element), kPrime, kMontyN); +std::string scl::details::FieldToString(const Elem& in) { + return ToString(PTR(in), kPrime, kMontyN); } template <> -void scl::details::FieldFromString(Elem& out, const std::string& str) { +void scl::details::FieldFromString(Elem& out, const std::string& src) { out = {0}; - FromString(PTR(out), kPrime, str); + FromString(PTR(out), kPrime, src); } bool scl::SCL_FF_Extras::IsSmaller( diff --git a/src/scl/math/secp256k1_order.cc b/src/scl/math/secp256k1_order.cc index b616b0f..8a78a93 100644 --- a/src/scl/math/secp256k1_order.cc +++ b/src/scl/math/secp256k1_order.cc @@ -20,6 +20,7 @@ #include +#include #include #include "./ops_gmp_ff.h" @@ -94,24 +95,24 @@ void scl::details::FieldInvert(Elem& out) { } template <> -bool scl::details::FieldEqual(const Elem& first, const Elem& second) { - return CompareValues(PTR(first), PTR(second)) == 0; +bool scl::details::FieldEqual(const Elem& in1, const Elem& in2) { + return CompareValues(PTR(in1), PTR(in2)) == 0; } template <> -void scl::details::FieldFromBytes(Elem& out, const unsigned char* src) { - ValueFromBytes(PTR(out), src); +void scl::details::FieldFromBytes(Elem& dest, const unsigned char* src) { + ValueFromBytes(PTR(dest), src); } template <> -std::string scl::details::FieldToString(const Elem& element) { - return ToString(PTR(element), kPrime, kMontyN); +std::string scl::details::FieldToString(const Elem& in) { + return ToString(PTR(in), kPrime, kMontyN); } template <> -void scl::details::FieldFromString(Elem& out, const std::string& str) { +void scl::details::FieldFromString(Elem& out, const std::string& src) { out = {0}; - FromString(PTR(out), kPrime, str); + FromString(PTR(out), kPrime, src); } std::size_t scl::SCL_FF_Extras::HigestSetBit( @@ -127,4 +128,16 @@ bool scl::SCL_FF_Extras::TestBit(const scl::FF& element, return ((element.mValue[limb] >> limb_pos) & 1) == 1; } +scl::FF scl::SCL_FF_Extras::FromMonty( + const scl::FF& element) { + mp_limb_t padded[2 * NUM_LIMBS] = {0}; + SCL_COPY(padded, PTR(element.mValue), NUM_LIMBS); + details::Redc(padded, kPrime, kMontyN); + + scl::FF r; + SCL_COPY(PTR(r.mValue), padded, NUM_LIMBS); + + return r; +} + #undef ONE diff --git a/src/scl/math/str.cc b/src/scl/math/str.cc index 8e6c227..32bdb74 100644 --- a/src/scl/math/str.cc +++ b/src/scl/math/str.cc @@ -24,8 +24,9 @@ template <> std::string scl::details::ToHexString(const __uint128_t& v) { + std::string str; if (v == 0) { - return "0"; + str = "0"; } else { std::stringstream ss; auto top = static_cast(v >> 64); @@ -35,6 +36,7 @@ std::string scl::details::ToHexString(const __uint128_t& v) { ss << top; } ss << bot; - return ss.str(); + str = ss.str(); } + return str; } diff --git a/src/scl/net/config.cc b/src/scl/net/config.cc index bddc0bd..356752e 100644 --- a/src/scl/net/config.cc +++ b/src/scl/net/config.cc @@ -26,15 +26,22 @@ #include #include -static inline void ValidateIdAndSize(unsigned id, std::size_t n) { - if (!n) { +namespace { + +void ValidateIdAndSize(unsigned id, std::size_t n) { + if (n == 0) { throw std::invalid_argument("n cannot be zero"); - } else if (n <= id) { + } + + if (n <= id) { throw std::invalid_argument("invalid id"); } } -scl::NetworkConfig scl::NetworkConfig::Load(unsigned id, std::string filename) { +} // namespace + +scl::NetworkConfig scl::NetworkConfig::Load(int id, + const std::string& filename) { std::ifstream file(filename); if (!file.is_open()) { @@ -45,14 +52,17 @@ scl::NetworkConfig scl::NetworkConfig::Load(unsigned id, std::string filename) { std::vector info; while (std::getline(file, line)) { - auto a = line.find(','); - auto b = line.rfind(','); + auto a_ = line.find(','); + auto b_ = line.rfind(','); - if (a == std::string::npos || a == b) { + if (a_ == std::string::npos || a_ == b_) { throw std::invalid_argument("invalid entry in config file"); } - auto id = (unsigned)std::stoul(std::string(line.begin(), line.begin() + a)); + auto a = static_cast(a_); + auto b = static_cast(b_); + + auto id = std::stoi(std::string(line.begin(), line.begin() + a)); auto hostname = std::string(line.begin() + a + 1, line.begin() + b); auto port = std::stoi(std::string(line.begin() + b + 1, line.end())); info.emplace_back(Party{id, hostname, port}); @@ -63,14 +73,14 @@ scl::NetworkConfig scl::NetworkConfig::Load(unsigned id, std::string filename) { return NetworkConfig(id, info); } -scl::NetworkConfig scl::NetworkConfig::Localhost(unsigned id, std::size_t n, +scl::NetworkConfig scl::NetworkConfig::Localhost(int id, int size, int port_base) { - ValidateIdAndSize(id, n); + ValidateIdAndSize(id, size); std::vector info; - for (std::size_t i = 0; i < n; ++i) { + for (int i = 0; i < size; ++i) { int port = port_base + i; - info.emplace_back(Party{(unsigned)i, "127.0.0.1", port}); + info.emplace_back(Party{i, "127.0.0.1", port}); } return NetworkConfig(id, info); @@ -93,13 +103,13 @@ std::string scl::NetworkConfig::ToString() const { void scl::NetworkConfig::Validate() { auto n = NetworkSize(); - if (Id() >= n) { + if (static_cast(Id()) >= n) { throw std::invalid_argument("my ID is invalid in config"); } for (std::size_t i = 0; i < n; ++i) { auto pi = mParties[i]; - if (pi.id >= n) { + if (static_cast(pi.id) >= n) { throw std::invalid_argument("invalid ID in config"); } for (std::size_t j = i + 1; j < n; ++j) { diff --git a/src/scl/net/discovery/client.cc b/src/scl/net/discovery/client.cc index 468f0f2..baf6b55 100644 --- a/src/scl/net/discovery/client.cc +++ b/src/scl/net/discovery/client.cc @@ -27,7 +27,7 @@ using Client = scl::DiscoveryClient; -scl::NetworkConfig Client::Run(unsigned id, int port) { +scl::NetworkConfig Client::Run(int id, int port) const { auto socket = scl::details::ConnectAsClient(mHostname, mPort); std::shared_ptr server = std::make_shared(socket); @@ -37,13 +37,15 @@ scl::NetworkConfig Client::Run(unsigned id, int port) { } Client::ReceiveNetworkConfig Client::SendIdAndPort::Run( - std::shared_ptr ctx) { + const std::shared_ptr& ctx) const { ctx->Send(mId); ctx->Send(mPort); return Client::ReceiveNetworkConfig{mId}; } -static inline std::string ReceiveHostname(std::shared_ptr ctx) { +namespace { + +std::string ReceiveHostname(const std::shared_ptr& ctx) { std::size_t len; ctx->Recv(len); auto buf = std::make_unique(len); @@ -51,8 +53,10 @@ static inline std::string ReceiveHostname(std::shared_ptr ctx) { return std::string(buf.get(), buf.get() + len); } +} // namespace + scl::NetworkConfig Client::ReceiveNetworkConfig::Finalize( - std::shared_ptr ctx) { + const std::shared_ptr& ctx) const { std::size_t number_of_parties; ctx->Recv(number_of_parties); diff --git a/src/scl/net/discovery/server.cc b/src/scl/net/discovery/server.cc index ec1c3ec..6d3ef6d 100644 --- a/src/scl/net/discovery/server.cc +++ b/src/scl/net/discovery/server.cc @@ -20,6 +20,7 @@ #include "scl/net/discovery/server.h" +#include #include #include @@ -28,14 +29,15 @@ using Server = scl::DiscoveryServer; -scl::NetworkConfig Server::Run(const scl::Party& me) { +scl::NetworkConfig Server::Run(const scl::Party& me) const { // one of the parties is us, which we do not connect to. - auto ssock = scl::details::CreateServerSocket(mPort, mNumberOfParties - 1); + int backlog = static_cast(mNumberOfParties - 1); + auto ssock = scl::details::CreateServerSocket(mPort, backlog); std::vector> channels; std::vector hostnames; for (std::size_t i = 0; i < mNumberOfParties; ++i) { - if (i == me.id) { + if (i == static_cast(me.id)) { channels.emplace_back(nullptr); hostnames.emplace_back(me.hostname); } else { @@ -53,15 +55,15 @@ scl::NetworkConfig Server::Run(const scl::Party& me) { } Server::SendNetworkConfig Server::CollectIdsAndPorts::Run(Server::Ctx& ctx) { - auto my_id = ctx.me.id; + auto my_id = static_cast(ctx.me.id); std::vector parties(mHostnames.size()); parties[my_id] = ctx.me; for (std::size_t i = 0; i < mHostnames.size(); ++i) { if (my_id != i) { - unsigned id; + int id; ctx.network.Party(i)->Recv(id); - if (id >= parties.size()) { + if (static_cast(id) >= parties.size()) { throw std::logic_error("received invalid party ID"); } int port; @@ -74,7 +76,9 @@ Server::SendNetworkConfig Server::CollectIdsAndPorts::Run(Server::Ctx& ctx) { return Server::SendNetworkConfig(cfg); } -static inline void SendHostname(scl::Channel* channel, std::string hostname) { +namespace { + +void SendHostname(scl::Channel* channel, const std::string& hostname) { std::size_t len = hostname.size(); const unsigned char* ptr = reinterpret_cast(hostname.c_str()); @@ -83,8 +87,7 @@ static inline void SendHostname(scl::Channel* channel, std::string hostname) { channel->Send(ptr, len); } -static inline void SendConfig(scl::Channel* channel, - const scl::NetworkConfig& config) { +void SendConfig(scl::Channel* channel, const scl::NetworkConfig& config) { channel->Send(config.NetworkSize()); for (std::size_t i = 0; i < config.NetworkSize(); ++i) { auto party = config.Parties()[i]; @@ -94,12 +97,15 @@ static inline void SendConfig(scl::Channel* channel, } } +} // namespace scl::NetworkConfig Server::SendNetworkConfig::Finalize(Server::Ctx& ctx) { std::size_t network_size = mConfig.NetworkSize(); for (std::size_t i = 0; i < network_size; ++i) { - if (i == mConfig.Id()) continue; + if (i == static_cast(mConfig.Id())) { + continue; + } - auto channel = ctx.network.Party(i); + auto* channel = ctx.network.Party(i); SendConfig(channel, mConfig); } diff --git a/src/scl/net/mem_channel.cc b/src/scl/net/mem_channel.cc index 654b6bb..72f186c 100644 --- a/src/scl/net/mem_channel.cc +++ b/src/scl/net/mem_channel.cc @@ -22,11 +22,14 @@ #include +// used to silence narrowing conversion errors. x will have type std::size_t +#define DIFF_T(x) static_cast::difference_type>((x)) + void scl::InMemoryChannel::Send(const unsigned char* src, std::size_t n) { mOut->PushBack(std::vector(src, src + n)); } -int scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) { +std::size_t scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) { std::size_t rem = n; // if there's any leftovers from previous calls to recv, then we retrieve @@ -34,10 +37,10 @@ int scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) { const auto leftovers = mOverflow.size(); if (leftovers > 0) { const auto to_copy = leftovers > rem ? rem : leftovers; - auto data = mOverflow.data(); + auto* data = mOverflow.data(); std::memcpy(dst, data, to_copy); rem -= to_copy; - mOverflow = std::vector(mOverflow.begin() + to_copy, + mOverflow = std::vector(mOverflow.begin() + DIFF_T(to_copy), mOverflow.end()); } @@ -53,8 +56,8 @@ int scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) { const auto leftovers = data.size() - to_copy; const auto old_size = mOverflow.size(); mOverflow.reserve(old_size + leftovers); - mOverflow.insert(mOverflow.begin() + old_size, data.begin() + to_copy, - data.end()); + mOverflow.insert(mOverflow.begin() + DIFF_T(old_size), + data.begin() + DIFF_T(to_copy), data.end()); } } diff --git a/src/scl/net/network.cc b/src/scl/net/network.cc index 239b23c..c7e0272 100644 --- a/src/scl/net/network.cc +++ b/src/scl/net/network.cc @@ -37,9 +37,9 @@ void scl::details::SCL_AcceptConnections( // Act as server for all clients with an ID strictly greater than ours. auto my_id = config.Id(); auto n = config.NetworkSize() - my_id - 1; - if (n) { + if (n > 0) { auto port = config.GetParty(my_id).port; - int ssock = scl::details::CreateServerSocket(port, n); + int ssock = scl::details::CreateServerSocket(port, static_cast(n)); for (std::size_t i = config.Id() + 1; i < config.NetworkSize(); ++i) { auto ac = scl::details::AcceptConnection(ssock); std::shared_ptr channel = diff --git a/src/scl/net/tcp_channel.cc b/src/scl/net/tcp_channel.cc index 20594e2..e54135c 100644 --- a/src/scl/net/tcp_channel.cc +++ b/src/scl/net/tcp_channel.cc @@ -20,10 +20,15 @@ #include "scl/net/tcp_channel.h" +#include +#include + #include "scl/net/tcp_utils.h" void scl::TcpChannel::Close() { - if (!mAlive) return; + if (!mAlive) { + return; + } const auto err = scl::details::CloseSocket(mSocket); if (err < 0) { @@ -46,13 +51,13 @@ void scl::TcpChannel::Send(const unsigned char* src, std::size_t n) { } } -int scl::TcpChannel::Recv(unsigned char* dst, std::size_t n) { +std::size_t scl::TcpChannel::Recv(unsigned char* dst, std::size_t n) { std::size_t rem = n; std::size_t offset = 0; while (rem > 0) { auto recv = scl::details::ReadFromSocket(mSocket, dst + offset, rem); - if (!recv) { + if (recv == 0) { break; } if (recv < 0) { @@ -65,3 +70,17 @@ int scl::TcpChannel::Recv(unsigned char* dst, std::size_t n) { return n - rem; } + +bool scl::TcpChannel::HasData() { + struct pollfd fds { + mSocket, POLLIN, 0 + }; + + auto r = poll(&fds, 1, 0); + + if (r < 0) { + SCL_THROW_SYS_ERROR("hasData failed"); + } + + return r > 0 && fds.revents == POLLIN; +} diff --git a/src/scl/net/tcp_utils.cc b/src/scl/net/tcp_utils.cc index 7105fe2..bec9b62 100644 --- a/src/scl/net/tcp_utils.cc +++ b/src/scl/net/tcp_utils.cc @@ -76,19 +76,19 @@ scl::details::AcceptedConnection scl::details::AcceptConnection( ::accept(server_socket, ac.socket_info.get(), (socklen_t*)&addrsize); if (ac.socket < 0) { SCL_THROW_SYS_ERROR("could not accept connection"); - } else { - return ac; } + + return ac; } std::string scl::details::GetAddress( - scl::details::AcceptedConnection connection) { + const scl::details::AcceptedConnection& connection) { struct sockaddr_in* s = reinterpret_cast(connection.socket_info.get()); return inet_ntoa(s->sin_addr); } -int scl::details::ConnectAsClient(std::string hostname, int port) { +int scl::details::ConnectAsClient(const std::string& hostname, int port) { using namespace std::chrono_literals; int sock = ::socket(AF_INET, SOCK_STREAM, 0); @@ -101,26 +101,30 @@ int scl::details::ConnectAsClient(std::string hostname, int port) { addr.sin_port = ::htons(port); int err = ::inet_pton(AF_INET, hostname.c_str(), &(addr.sin_addr)); + if (err == 0) { throw std::runtime_error("invalid hostname"); - } else if (err < 0) { + } + + if (err < 0) { SCL_THROW_SYS_ERROR("invalid address family"); } - while (::connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) + while (::connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) { std::this_thread::sleep_for(300ms); + } return sock; } int scl::details::CloseSocket(int socket) { return ::close(socket); } -int scl::details::ReadFromSocket(int socket, unsigned char* dst, - std::size_t n) { +ssize_t scl::details::ReadFromSocket(int socket, unsigned char* dst, + std::size_t n) { return ::read(socket, dst, n); } -int scl::details::WriteToSocket(int socket, const unsigned char* src, - std::size_t n) { +ssize_t scl::details::WriteToSocket(int socket, const unsigned char* src, + std::size_t n) { return ::write(socket, src, n); } diff --git a/src/scl/prg.cc b/src/scl/prg.cc index 1a172d9..d08150f 100644 --- a/src/scl/prg.cc +++ b/src/scl/prg.cc @@ -30,19 +30,19 @@ using block_t = __m128i; using std::size_t; -#define DO_ENC_BLOCK(m, k) \ - do { \ - m = _mm_xor_si128(m, k[0]); \ - m = _mm_aesenc_si128(m, k[1]); \ - m = _mm_aesenc_si128(m, k[2]); \ - m = _mm_aesenc_si128(m, k[3]); \ - m = _mm_aesenc_si128(m, k[4]); \ - m = _mm_aesenc_si128(m, k[5]); \ - m = _mm_aesenc_si128(m, k[6]); \ - m = _mm_aesenc_si128(m, k[7]); \ - m = _mm_aesenc_si128(m, k[8]); \ - m = _mm_aesenc_si128(m, k[9]); \ - m = _mm_aesenclast_si128(m, k[10]); \ +#define DO_ENC_BLOCK(m, k) \ + do { \ + (m) = _mm_xor_si128(m, (k)[0]); \ + (m) = _mm_aesenc_si128(m, (k)[1]); \ + (m) = _mm_aesenc_si128(m, (k)[2]); \ + (m) = _mm_aesenc_si128(m, (k)[3]); \ + (m) = _mm_aesenc_si128(m, (k)[4]); \ + (m) = _mm_aesenc_si128(m, (k)[5]); \ + (m) = _mm_aesenc_si128(m, (k)[6]); \ + (m) = _mm_aesenc_si128(m, (k)[7]); \ + (m) = _mm_aesenc_si128(m, (k)[8]); \ + (m) = _mm_aesenc_si128(m, (k)[9]); \ + (m) = _mm_aesenclast_si128(m, (k)[10]); \ } while (0) #define AES_128_key_exp(k, rcon) \ @@ -97,18 +97,24 @@ static inline auto create_mask(const long counter) { } void scl::PRG::Next(byte_t* dest, size_t nbytes) { - if (!nbytes) return; + if (nbytes == 0) { + return; + } size_t nblocks = nbytes / BlockSize(); - if (nbytes % BlockSize()) nblocks++; + if ((nbytes % BlockSize()) != 0) { + nblocks++; + } block_t mask = create_mask(mCounter); byte_t* out = (byte_t*)malloc(nblocks * BlockSize()); byte_t* p = out; // LCOV_EXCL_START - if (!out) throw std::runtime_error("Could not allocate memory for PRG."); + if (out == nullptr) { + throw std::runtime_error("Could not allocate memory for PRG."); + } // LCOV_EXCL_STOP for (size_t i = 0; i < nblocks; i++) { diff --git a/test/scl/gf7.cc b/test/scl/gf7.cc index 176ffdf..7148f3d 100644 --- a/test/scl/gf7.cc +++ b/test/scl/gf7.cc @@ -38,10 +38,11 @@ void scl::details::FieldAdd(unsigned char& out, const unsigned char& op) { template <> void scl::details::FieldSubtract(unsigned char& out, const unsigned char& op) { - if (out < op) + if (out < op) { out = 7 + out - op; - else + } else { out = out - op; + } } template <> @@ -95,8 +96,8 @@ void scl::details::FieldFromBytes(unsigned char& dest, } template <> -std::string scl::details::FieldToString(const unsigned char& value) { +std::string scl::details::FieldToString(const unsigned char& in) { std::stringstream ss; - ss << (int)value; + ss << (int)in; return ss.str(); } diff --git a/test/scl/gf7.h b/test/scl/gf7.h index 0c68f94..ec9e9c9 100644 --- a/test/scl/gf7.h +++ b/test/scl/gf7.h @@ -18,8 +18,8 @@ * along with this program. If not, see . */ -#ifndef _SCUTIL_MATH_GF7_H -#define _SCUTIL_MATH_GF7_H +#ifndef TEST_SCL_GF7_H +#define TEST_SCL_GF7_H #include #include @@ -37,4 +37,4 @@ struct GF7 { } // namespace details } // namespace scl -#endif /* _SCUTIL_MATH_GF7_H */ +#endif /* TEST_SCL_GF7_H */ diff --git a/test/scl/math/test_ff.cc b/test/scl/math/test_ff.cc index 41b932a..2f22384 100644 --- a/test/scl/math/test_ff.cc +++ b/test/scl/math/test_ff.cc @@ -31,17 +31,21 @@ using GF7 = scl::FF; #ifdef SCL_ENABLE_EC_TESTS using Secp256k1_Field = scl::FF; +using Secp256k1_Order = scl::FF; #endif template T RandomNonZero(scl::PRG& prg) { auto a = T::Random(prg); for (std::size_t i = 0; i < 10; ++i) { - if (a == T::Zero()) a = T::Random(prg); + if (a == T::Zero()) { + a = T::Random(prg); + } break; } - if (a == T::Zero()) + if (a == T::Zero()) { throw std::logic_error("could not generate a non-zero random value"); + } return a; } @@ -59,7 +63,7 @@ GF7 RandomNonZero(scl::PRG& prg) { #define REPEAT for (std::size_t i = 0; i < 50; ++i) #ifdef SCL_ENABLE_EC_TESTS -#define ARG_LIST Mersenne61, Mersenne127, GF7, Secp256k1_Field +#define ARG_LIST Mersenne61, Mersenne127, GF7, Secp256k1_Field, Secp256k1_Order #else #define ARG_LIST Mersenne61, Mersenne127, GF7 #endif diff --git a/test/scl/math/test_mat.cc b/test/scl/math/test_mat.cc index f622896..4410098 100644 --- a/test/scl/math/test_mat.cc +++ b/test/scl/math/test_mat.cc @@ -26,7 +26,7 @@ using F = scl::Fp<61>; using Mat = scl::Mat; -inline void Populate(Mat& m, unsigned values[]) { +inline void Populate(Mat& m, const int* values) { for (std::size_t i = 0; i < m.Rows(); i++) { for (std::size_t j = 0; j < m.Cols(); j++) { m(i, j) = F(values[i * m.Cols() + j]); @@ -36,10 +36,10 @@ inline void Populate(Mat& m, unsigned values[]) { TEST_CASE("Matrix", "[math]") { Mat m0(2, 2); - unsigned v0[] = {1, 2, 5, 6}; + int v0[] = {1, 2, 5, 6}; Populate(m0, v0); Mat m1(2, 2); - unsigned v1[] = {4, 3, 2, 1}; + int v1[] = {4, 3, 2, 1}; Populate(m1, v1); REQUIRE(!m0.Equals(m1)); @@ -59,7 +59,7 @@ TEST_CASE("Matrix", "[math]") { SECTION("ToString") { Mat m(3, 2); - unsigned v[] = {1, 2, 44444, 5, 6, 7}; + int v[] = {1, 2, 44444, 5, 6, 7}; Populate(m, v); std::string expected = "\n" @@ -133,16 +133,16 @@ TEST_CASE("Matrix", "[math]") { REQUIRE(m2(1, 1) == F(21)); Mat m3(2, 10); - unsigned v3[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 0, - 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; + int v3[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 0, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; Populate(m3, v3); auto m5 = m0.Multiply(m3); REQUIRE(m5.Rows() == 2); REQUIRE(m5.Cols() == 10); Mat m4(2, 10); - unsigned v4[] = {23, 26, 29, 32, 35, 38, 41, 44, 47, 40, - 71, 82, 93, 104, 115, 126, 137, 148, 159, 120}; + int v4[] = {23, 26, 29, 32, 35, 38, 41, 44, 47, 40, + 71, 82, 93, 104, 115, 126, 137, 148, 159, 120}; Populate(m4, v4); REQUIRE(m5.Equals(m4)); @@ -167,7 +167,7 @@ TEST_CASE("Matrix", "[math]") { SECTION("Transpose") { Mat m3(2, 3); - unsigned v3[] = {1, 2, 3, 11, 12, 13}; + int v3[] = {1, 2, 3, 11, 12, 13}; Populate(m3, v3); auto m4 = m3.Transpose(); REQUIRE(m4.Rows() == m3.Cols()); @@ -244,10 +244,11 @@ TEST_CASE("Matrix", "[math]") { bool good = true; for (std::size_t i = 0; i < 10; ++i) { for (std::size_t j = 0; j < 10; ++j) { - if (i == j) + if (i == j) { good &= A(i, j) == F(1); - else + } else { good &= A(i, j) == F(0); + } } } REQUIRE(good); diff --git a/test/scl/math/test_secp256k1.cc b/test/scl/math/test_secp256k1.cc index 60cf847..d9053b5 100644 --- a/test/scl/math/test_secp256k1.cc +++ b/test/scl/math/test_secp256k1.cc @@ -25,6 +25,7 @@ #include "scl/math/curves/secp256k1.h" #include "scl/math/ec_ops.h" #include "scl/math/fp.h" +#include "scl/math/number.h" #include "scl/prg.h" using Curve = scl::EC; @@ -63,14 +64,17 @@ TEST_CASE("secp256k1_field", "[math]") { } SECTION("From affine") { - auto g = Curve::FromAffine( - Field::FromString("e47b4a1c2e13cf0e97c9adf5a645ce388e04317b7830401aabb4" - "2e188c9883fa"), // - Field::FromString("2aafa6e870684327ec92006e6c601a8b6e0fb9ff06ae120cb330" - "a2eee86009ff") // - ); + auto x = Field::FromString( + "e47b4a1c2e13cf0e97c9adf5a645ce388e04317b7830401aabb42e188c9883fa"); + auto y = Field::FromString( + "2aafa6e870684327ec92006e6c601a8b6e0fb9ff06ae120cb330a2eee86009ff"); + auto g = Curve::FromAffine(x, y); REQUIRE(!g.PointAtInfinity()); + auto as_affine = g.ToAffine(); + REQUIRE(as_affine[0] == x); + REQUIRE(as_affine[1] == y); + REQUIRE_THROWS_MATCHES( Curve::FromAffine(Field(0), Field(0)), std::invalid_argument, Catch::Matchers::Message("provided (x, y) not on curve")); @@ -142,6 +146,8 @@ TEST_CASE("secp256k1_field", "[math]") { } } +using Scalar = scl::FF; + static Curve RandomPoint(scl::PRG& prg) { auto r = scl::Number::Random(100, prg); return Curve::Generator() * r; @@ -203,6 +209,38 @@ TEST_CASE("secp256k1", "[math]") { REQUIRE((a + b).PointAtInfinity()); } + SECTION("Scalar mul order") { + auto a = RandomPoint(prg); + auto b = Scalar::FromString( + "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364140"); + REQUIRE(b + Scalar::One() == Scalar::Zero()); + + auto c = a * b; + REQUIRE(!c.PointAtInfinity()); + + REQUIRE((c + a).PointAtInfinity()); + } + + SECTION("Scalar mul distributive") { + auto a = RandomPoint(prg); + auto b = Scalar::Random(prg); + auto c = Scalar::Random(prg); + REQUIRE((b + c) * a == b * a + c * a); + } + + SECTION("Scalar mul assoc") { + auto a = Curve::Generator(); + + auto v = Scalar::FromString("03"); + auto u = Scalar::FromString("02"); + auto w = Scalar::FromString("06"); + + auto P = a * w; + auto Q = (a * v) * u; + + REQUIRE(P == Q); + } + SECTION("negation exceptional") { using CurveT = scl::details::Secp256k1; CurveT::ValueType point = {Field(1), Field(0), Field(1)}; diff --git a/test/scl/math/test_vec.cc b/test/scl/math/test_vec.cc index cf8496d..a126e39 100644 --- a/test/scl/math/test_vec.cc +++ b/test/scl/math/test_vec.cc @@ -134,6 +134,13 @@ TEST_CASE("Vector", "[math]") { REQUIRE(r[2] != v0[2]); } + SECTION("Range") { + auto v = Vec::Range(1, 4); + REQUIRE(v[0] == F(1)); + REQUIRE(v[1] == F(2)); + REQUIRE(v[2] == F(3)); + } + SECTION("iterators") { auto v2 = Vec{F(1), F(2), F(3)}; std::size_t i = 0; @@ -147,4 +154,12 @@ TEST_CASE("Vector", "[math]") { auto v3 = Vec(v2.begin(), v2.end()); REQUIRE(v3.Equals(v2)); } + + SECTION("subvector") { + auto v = Vec{F(1), F(2), F(3), F(4)}; + REQUIRE(v.SubVector(1, 2) == Vec{F(2)}); + REQUIRE(v.SubVector(1, 3) == Vec{F(2), F(3)}); + REQUIRE(v.SubVector(1, 1) == Vec{}); + REQUIRE(v.SubVector(2) == Vec{F(1), F(2)}); + } } diff --git a/test/scl/math/test_z2k.cc b/test/scl/math/test_z2k.cc index 88ff42d..954d810 100644 --- a/test/scl/math/test_z2k.cc +++ b/test/scl/math/test_z2k.cc @@ -82,7 +82,7 @@ TEMPLATE_TEST_CASE("Z2k", "[math]", Z2k1, Z2k2) { #define RANDOM_INVERTIBLE(var) \ TestType var; \ - while (var.Lsb() == 0) var = TestType::Random(prg) + while ((var).Lsb() == 0) (var) = TestType::Random(prg) SECTION("inverses") { RANDOM_INVERTIBLE(a); diff --git a/test/scl/net/test_config.cc b/test/scl/net/test_config.cc index 8e1dc08..c79a45c 100644 --- a/test/scl/net/test_config.cc +++ b/test/scl/net/test_config.cc @@ -26,7 +26,7 @@ TEST_CASE("Config", "[network]") { SECTION("From file") { - auto filename = SCL_TEST_DATA_DIR "3_parties.txt"; + const auto* filename = SCL_TEST_DATA_DIR "3_parties.txt"; auto cfg = scl::NetworkConfig::Load(0, filename); REQUIRE(cfg.NetworkSize() == 3); @@ -41,22 +41,22 @@ TEST_CASE("Config", "[network]") { } SECTION("Invalid file") { - auto invalid_empty = SCL_TEST_DATA_DIR "invalid_no_entries.txt"; + const auto* invalid_empty = SCL_TEST_DATA_DIR "invalid_no_entries.txt"; REQUIRE_THROWS_MATCHES(scl::NetworkConfig::Load(0, invalid_empty), std::invalid_argument, Catch::Matchers::Message("n cannot be zero")); - auto valid = SCL_TEST_DATA_DIR "3_parties.txt"; + const auto* valid = SCL_TEST_DATA_DIR "3_parties.txt"; REQUIRE_THROWS_MATCHES(scl::NetworkConfig::Load(4, valid), std::invalid_argument, Catch::Matchers::Message("invalid id")); - auto invalid_entry = SCL_TEST_DATA_DIR "invalid_entry.txt"; + const auto* invalid_entry = SCL_TEST_DATA_DIR "invalid_entry.txt"; REQUIRE_THROWS_MATCHES( scl::NetworkConfig::Load(0, invalid_entry), std::invalid_argument, Catch::Matchers::Message("invalid entry in config file")); - auto invalid_non_existing_file = ""; + const auto* invalid_non_existing_file = ""; REQUIRE_THROWS_MATCHES( scl::NetworkConfig::Load(0, invalid_non_existing_file), std::invalid_argument, Catch::Matchers::Message("could not open file")); diff --git a/test/scl/net/test_discover.cc b/test/scl/net/test_discover.cc index f1ef4d4..7d7d09c 100644 --- a/test/scl/net/test_discover.cc +++ b/test/scl/net/test_discover.cc @@ -28,16 +28,20 @@ #include "scl/net/mem_channel.h" #include "scl/net/network.h" -static inline bool VerifyParty(scl::Party& party, unsigned id, - std::string hostname, int port) { +namespace { + +bool VerifyParty(scl::Party& party, int id, const std::string& hostname, + int port) { return party.id == id && party.hostname == hostname && party.port == port; } -static inline bool PartyEquals(scl::Party& first, scl::Party& second) { +bool PartyEquals(scl::Party& first, scl::Party& second) { return first.id == second.id && first.hostname == second.hostname && first.port == second.port; } +} // namespace + TEST_CASE("Discovery Server", "[network]") { SECTION("CollectIdsAndPorts") { std::vector hostnames = {"1.2.3.4", "4.4.4.4", "127.0.0.1"}; @@ -127,15 +131,16 @@ TEST_CASE("Discovery Server", "[network]") { } } -static inline void SendHostname(scl::Channel* channel, std::string hostname) { +namespace { + +void SendHostname(scl::Channel* channel, const std::string& hostname) { std::size_t length = hostname.size(); channel->Send(length); channel->Send(reinterpret_cast(hostname.c_str()), length); } -static inline void SendConfig(scl::Channel* channel, - const scl::NetworkConfig& config) { +void SendConfig(scl::Channel* channel, const scl::NetworkConfig& config) { channel->Send(config.NetworkSize()); for (const auto& party : config.Parties()) { channel->Send(party.id); @@ -144,6 +149,8 @@ static inline void SendConfig(scl::Channel* channel, } } +} // namespace + TEST_CASE("Discovery Client", "[network]") { SECTION("SendIdAndPort") { auto channels = scl::InMemoryChannel::CreatePaired(); @@ -196,7 +203,9 @@ TEST_CASE("Discovery", "[network]") { Catch::Matchers::Message("number_of_parties exceeds max")); } - scl::NetworkConfig c0, c1, c2; + scl::NetworkConfig c0; + scl::NetworkConfig c1; + scl::NetworkConfig c2; std::thread server([&c0]() { Server srv(9999, 3); diff --git a/test/scl/net/test_mem_channel.cc b/test/scl/net/test_mem_channel.cc index 79859a7..2c01b89 100644 --- a/test/scl/net/test_mem_channel.cc +++ b/test/scl/net/test_mem_channel.cc @@ -26,14 +26,9 @@ #include "scl/math.h" #include "scl/net/mem_channel.h" #include "scl/prg.h" +#include "util.h" -static inline bool Eq(const unsigned char* p, const unsigned char* q, int n) { - while (n-- > 0 && *p++ == *q++) - ; - return n < 0; -} - -static inline void PrintBuf(const unsigned char* b, std::size_t n) { +void PrintBuf(const unsigned char* b, std::size_t n) { for (std::size_t i = 0; i < n; ++i) { std::cout << (int)b[i] << " "; } @@ -52,9 +47,12 @@ TEST_CASE("InMemoryChannel", "[network]") { SECTION("Send and receive") { unsigned char data_out[200] = {0}; + REQUIRE(!chl1->HasData()); chl0->Send(data_in, 200); + REQUIRE(!chl0->HasData()); + REQUIRE(chl1->HasData()); chl1->Recv(data_out, 200); - REQUIRE(Eq(data_in, data_out, 200)); + REQUIRE(scl_tests::BufferEquals(data_in, data_out, 200)); } chl0->Flush(); @@ -68,7 +66,7 @@ TEST_CASE("InMemoryChannel", "[network]") { chl0->Send(data_in + 100, 100); chl1->Recv(data_out, 200); - REQUIRE(Eq(data_in, data_out, 200)); + REQUIRE(scl_tests::BufferEquals(data_in, data_out, 200)); } chl0->Flush(); @@ -81,7 +79,7 @@ TEST_CASE("InMemoryChannel", "[network]") { chl1->Recv(data_out, 100); chl1->Recv(data_out + 100, 100); - REQUIRE(Eq(data_in, data_out, 200)); + REQUIRE(scl_tests::BufferEquals(data_in, data_out, 200)); } chl0->Flush(); @@ -105,9 +103,10 @@ TEST_CASE("InMemoryChannel", "[network]") { scl::Channel* c1 = chl1.get(); std::vector data = {1, 2, 3, 4, 11111111}; c0->Send(data); - std::vector recv(data.size()); + std::vector recv; c1->Recv(recv); REQUIRE(data == recv); + REQUIRE(recv.size() == data.size()); } using FF = scl::Fp<61>; @@ -153,6 +152,6 @@ TEST_CASE("InMemoryChannel", "[network]") { c->Recv(data_out + 10, 100); c->Recv(data_out + 110, 90); - REQUIRE(Eq(data_in, data_out, 200)); + REQUIRE(scl_tests::BufferEquals(data_in, data_out, 200)); } } diff --git a/test/scl/net/test_network.cc b/test/scl/net/test_network.cc index e5196bf..f23f6cb 100644 --- a/test/scl/net/test_network.cc +++ b/test/scl/net/test_network.cc @@ -71,7 +71,9 @@ TEST_CASE("Network", "[network]") { } SECTION("TCP") { - scl::Network network0, network1, network2; + scl::Network network0; + scl::Network network1; + scl::Network network2; std::thread t0([&]() { network0 = scl::Network::Create( diff --git a/test/scl/net/test_tcp_channel.cc b/test/scl/net/test_tcp_channel.cc index 82cecce..0efa63a 100644 --- a/test/scl/net/test_tcp_channel.cc +++ b/test/scl/net/test_tcp_channel.cc @@ -30,11 +30,14 @@ TEST_CASE("TcpChannel", "[network]") { SECTION("Connect and Close") { auto port = scl_tests::GetPort(); - std::shared_ptr client, server; + std::shared_ptr client; + std::shared_ptr server; + std::thread clt([&]() { int socket = scl::details::ConnectAsClient("0.0.0.0", port); client = std::make_shared(socket); }); + std::thread srv([&]() { int ssock = scl::details::CreateServerSocket(port, 1); auto ac = scl::details::AcceptConnection(ssock); @@ -58,7 +61,9 @@ TEST_CASE("TcpChannel", "[network]") { SECTION("Send Receive") { auto port = scl_tests::GetPort(); - std::shared_ptr client, server; + std::shared_ptr client; + std::shared_ptr server; + std::thread clt([&]() { int socket = scl::details::ConnectAsClient("0.0.0.0", port); client = std::make_shared(socket); @@ -79,8 +84,12 @@ TEST_CASE("TcpChannel", "[network]") { unsigned char recv[200] = {0}; prg.Next(send, 200); + REQUIRE(!server->HasData()); + client->Send(send, 100); client->Send(send + 100, 100); + + REQUIRE(server->HasData()); server->Recv(recv, 20); server->Recv(recv + 20, 180); @@ -90,7 +99,9 @@ TEST_CASE("TcpChannel", "[network]") { SECTION("Recv from closed") { auto port = scl_tests::GetPort(); - std::shared_ptr client, server; + std::shared_ptr client; + std::shared_ptr server; + std::thread clt([&]() { int socket = scl::details::ConnectAsClient("0.0.0.0", port); client = std::make_shared(socket); diff --git a/test/scl/net/test_threaded_sender.cc b/test/scl/net/test_threaded_sender.cc index 8c241af..c23b432 100644 --- a/test/scl/net/test_threaded_sender.cc +++ b/test/scl/net/test_threaded_sender.cc @@ -32,7 +32,8 @@ TEST_CASE("ThreadedSender", "[network]") { SECTION("Connect and send") { auto port = scl_tests::GetPort(); - std::shared_ptr client, server; + std::shared_ptr client; + std::shared_ptr server; std::thread clt([&]() { int socket = scl::details::ConnectAsClient("0.0.0.0", port); @@ -54,9 +55,23 @@ TEST_CASE("ThreadedSender", "[network]") { unsigned char recv[200] = {0}; prg.Next(send, 200); + REQUIRE(!server->HasData()); + client->Send(send, 100); client->Send(send + 100, 100); + // because the sender returns immediately, there might not be data + // available, so we will try a couple of times before failing. + { + using namespace std::chrono_literals; + auto c = 0; + while (c < 10 && !server->HasData()) { + std::this_thread::sleep_for(100ms); + c++; + } + } + REQUIRE(server->HasData()); + server->Recv(recv, 20); server->Recv(recv + 20, 180); diff --git a/test/scl/net/util.cc b/test/scl/net/util.cc index baf186b..6b1bbb7 100644 --- a/test/scl/net/util.cc +++ b/test/scl/net/util.cc @@ -26,7 +26,8 @@ int scl_tests::GetPort() { return test_port++; } bool scl_tests::BufferEquals(const unsigned char *a, const unsigned char *b, int n) { - while (n-- > 0 && *a++ == *b++) + while (n-- > 0 && *a++ == *b++) { ; + } return n < 0; } diff --git a/test/scl/net/util.h b/test/scl/net/util.h index 5ca9c9d..6fe5a34 100644 --- a/test/scl/net/util.h +++ b/test/scl/net/util.h @@ -18,10 +18,9 @@ * along with this program. If not, see . */ -#ifndef _TEST_SCL_NET_UTIL_H -#define _TEST_SCL_NET_UTIL_H +#ifndef TEST_SCL_NET_UTIL_H +#define TEST_SCL_NET_UTIL_H -#include namespace scl_tests { /** @@ -48,4 +47,4 @@ bool BufferEquals(const unsigned char* a, const unsigned char* b, int n); } // namespace scl_tests -#endif /* _TEST_SCL_NET_UTIL_H */ +#endif // TEST_SCL_NET_UTIL_H diff --git a/test/scl/p/test_simple.cc b/test/scl/p/test_simple.cc index c0aaa67..663e3f4 100644 --- a/test/scl/p/test_simple.cc +++ b/test/scl/p/test_simple.cc @@ -46,7 +46,7 @@ class BeaverMulFinalize public: BeaverMulFinalize(Triple t) : mTriple(t){}; - FF Finalize(Context& ctx) { + FF Finalize(Context& ctx) const { scl::Vec ed0(2); scl::Vec ed1(2); ctx.network.Party(0)->Recv(ed0); @@ -56,10 +56,11 @@ class BeaverMulFinalize auto d = ed0[1] + ed1[1]; if (ctx.id == 0) { + // constant addition return e * d - e * mTriple.b - d * mTriple.a + mTriple.c; - } else { - return -e * mTriple.b - d * mTriple.a + mTriple.c; } + + return -e * mTriple.b - d * mTriple.a + mTriple.c; }; private: diff --git a/test/scl/ss/test_shamir.cc b/test/scl/ss/test_shamir.cc index edee349..9a44801 100644 --- a/test/scl/ss/test_shamir.cc +++ b/test/scl/ss/test_shamir.cc @@ -19,6 +19,7 @@ */ #include +#include #include "../gf7.h" #include "scl/math.h" @@ -29,97 +30,87 @@ TEST_CASE("Shamir", "[ss]") { using FF = scl::Fp<61>; using Vec = scl::Vec; - scl::PRG prg; + const std::size_t t = 2; - SECTION("Share") { - auto secret = FF(123); - Vec alphas = {FF(2), FF(5), FF(3)}; - auto share_poly = scl::details::CreateShamirSharePolynomial(secret, 2, prg); - auto shares = scl::CreateShamirShares(share_poly, alphas); - - auto reconstructed = - scl::ReconstructShamirPassive(shares, alphas, FF(0), 2); - REQUIRE(reconstructed == secret); + SECTION("Reconstruct") { + scl::PRG prg; + scl::details::ShamirSSFactory factory( + t, prg, scl::details::SecurityLevel::PASSIVE); + auto intr = factory.GetInterpolator(); - auto some_share = scl::ReconstructShamirPassive(shares, alphas, FF(3), 2); - REQUIRE(some_share == shares[2]); - } - - SECTION("Passive") { auto secret = FF(123); - auto shares = scl::CreateShamirShares(secret, 4, 3, prg); - auto reconstructed = scl::ReconstructShamirPassive(shares, 3); - REQUIRE(reconstructed == secret); + auto shares = factory.Share(secret); + auto s = intr.Reconstruct(shares); + REQUIRE(s == secret); REQUIRE_THROWS_MATCHES( - scl::ReconstructShamirPassive(shares, 4), std::invalid_argument, + intr.Reconstruct(shares.SubVector(1)), std::invalid_argument, Catch::Matchers::Message("not enough shares to reconstruct")); - - Vec alphas = {FF(1), FF(2), FF(3)}; - REQUIRE_THROWS_MATCHES( - scl::ReconstructShamirPassive(shares, alphas, FF(0), 3), - std::invalid_argument, - Catch::Matchers::Message("not enough alphas to reconstruct")); } SECTION("Detection") { - auto secret = FF(123); - auto shares = scl::CreateShamirShares(secret, 7, 3, prg); - auto reconstructed = scl::ReconstructShamir(shares, 3); - REQUIRE(reconstructed == secret); + scl::PRG prg; + scl::details::ShamirSSFactory factory( + t, prg, scl::details::SecurityLevel::DETECT); + auto intr = scl::details::Reconstructor::Create( + t, scl::details::SecurityLevel::DETECT); - auto shares0 = shares; - shares0[2] = FF(4); - REQUIRE_THROWS_MATCHES( - scl::ReconstructShamir(shares0, 3), std::logic_error, - Catch::Matchers::Message("error detected during reconstruction")); + auto secret = FF(555); + auto shares = factory.Share(secret); + REQUIRE(shares.Size() == 2 * t + 1); + REQUIRE(intr.Reconstruct(shares) == secret); - auto shares1 = shares; - shares1[6] = FF(3); REQUIRE_THROWS_MATCHES( - scl::ReconstructShamir(shares1, 3), std::logic_error, - Catch::Matchers::Message("error detected during reconstruction")); + intr.Reconstruct(shares.SubVector(2)), std::invalid_argument, + Catch::Matchers::Message("not enough shares to reconstruct")); - REQUIRE_THROWS_MATCHES( - scl::ReconstructShamir(shares0, 4), std::invalid_argument, - Catch::Matchers::Message( - "not enough shares to reconstruct with error detection")); - REQUIRE_THROWS_MATCHES( - scl::ReconstructShamir(shares0, Vec{}, FF(0), 2), std::invalid_argument, - Catch::Matchers::Message( - "not enough alphas to reconstruct with error detection")); + auto ss = intr.ReconstructShare(shares, 2); + REQUIRE(ss == shares[2]); + REQUIRE(intr.Reconstruct(shares, 3) == intr.ReconstructShare(shares, 2)); } - SECTION("Correction") { + SECTION("Robust") { + scl::PRG prg; + scl::details::ShamirSSFactory factory( + t, prg, scl::details::SecurityLevel::CORRECT); + // no errors auto secret = FF(123); - auto shares = scl::CreateShamirShares(secret, 7, 2, prg); - auto reconstructed = scl::ReconstructShamirRobust(shares, 2); - CHECK(reconstructed == secret); + auto shares = factory.Share(secret); + REQUIRE(shares.Size() == 3 * t + 1); + auto reconstructed = scl::details::ReconstructShamirRobust(shares, t); + REQUIRE(reconstructed == secret); + + // can also reconstruct with an interpolator + auto intr = scl::details::Reconstructor::Create( + t, scl::details::SecurityLevel::CORRECT); + REQUIRE(intr.Reconstruct(shares) == secret); // one error shares[0] = FF(63212); - auto reconstructed_1 = scl::ReconstructShamirRobust(shares, 2); - CHECK(reconstructed_1 == secret); + auto reconstructed_1 = scl::details::ReconstructShamirRobust(shares, t); + REQUIRE(reconstructed_1 == secret); // two errors shares[2] = FF(63212211); - auto reconstructed_2 = scl::ReconstructShamirRobust(shares, 2); - CHECK(reconstructed_2 == secret); + auto reconstructed_2 = scl::details::ReconstructShamirRobust(shares, t); + REQUIRE(reconstructed_2 == secret); // three errors -- that's one too many shares[1] = FF(123); REQUIRE_THROWS_MATCHES( - scl::ReconstructShamirRobust(shares, 2), std::logic_error, + scl::details::ReconstructShamirRobust(shares, t), std::logic_error, Catch::Matchers::Message("could not correct shares")); REQUIRE_THROWS_MATCHES( - scl::ReconstructShamirRobust(shares, 3), std::invalid_argument, + scl::details::ReconstructShamirRobust(shares, t + 1), + std::invalid_argument, Catch::Matchers::Message( "not enough shares to reconstruct with error correction")); REQUIRE_THROWS_MATCHES( - scl::ReconstructShamirRobust(shares, Vec{}, 2), std::invalid_argument, + scl::details::ReconstructShamirRobust(shares, Vec{}, t), + std::invalid_argument, Catch::Matchers::Message( "not enough alphas to reconstruct with error correction")); } @@ -135,9 +126,9 @@ TEST_CASE("BerlekampWelch", "[ss][math]") { Vec as = {FF(0), FF(1), FF(2), FF(3), FF(4), FF(5), FF(6)}; Vec corrected = {FF(1), FF(6), FF(3), FF(6), FF(1), FF(2), FF(2)}; - auto pe = scl::ReconstructShamirRobust(bs, as, 2); - auto p = pe[0]; - auto e = pe[1]; + auto pe = scl::details::ReconstructShamirRobust(bs, as, 2); + auto p = std::get<0>(pe); + auto e = std::get<1>(pe); // errors REQUIRE(e.Evaluate(FF(1)) == FF{}); diff --git a/test/scl/test_hash.cc b/test/scl/test_hash.cc index 31d1801..30eb57f 100644 --- a/test/scl/test_hash.cc +++ b/test/scl/test_hash.cc @@ -67,42 +67,54 @@ TEST_CASE("Hash", "[misc]") { SECTION("SHA3-256 0xA3 x 200") { unsigned char byte = 0xA3; unsigned char buf[200]; - for (std::size_t i = 0; i < 200; ++i) buf[i] = byte; + for (std::size_t i = 0; i < 200; ++i) { + buf[i] = byte; + } scl::Hash<256> hash0; auto digest = hash0.Update(buf, 200).Finalize(); REQUIRE(digest == SHA3_256_0xa3_200_times); scl::Hash<256> hash1; - for (std::size_t i = 0; i < 200; ++i) hash1.Update(&byte, 1); + for (std::size_t i = 0; i < 200; ++i) { + hash1.Update(&byte, 1); + } REQUIRE(hash1.Finalize() == SHA3_256_0xa3_200_times); } SECTION("SHA3-384 0xA3 x 200") { unsigned char byte = 0xA3; unsigned char buf[200]; - for (std::size_t i = 0; i < 200; ++i) buf[i] = byte; + for (std::size_t i = 0; i < 200; ++i) { + buf[i] = byte; + } scl::Hash<384> hash0; auto digest = hash0.Update(buf, 200).Finalize(); REQUIRE(digest == SHA3_384_0xa3_200_times); scl::Hash<384> hash1; - for (std::size_t i = 0; i < 200; ++i) hash1.Update(&byte, 1); + for (std::size_t i = 0; i < 200; ++i) { + hash1.Update(&byte, 1); + } REQUIRE(hash1.Finalize() == SHA3_384_0xa3_200_times); } SECTION("SHA3-512 0xA3 x 200") { unsigned char byte = 0xA3; unsigned char buf[200]; - for (std::size_t i = 0; i < 200; ++i) buf[i] = byte; + for (std::size_t i = 0; i < 200; ++i) { + buf[i] = byte; + } scl::Hash<512> hash0; auto digest = hash0.Update(buf, 200).Finalize(); REQUIRE(digest == SHA3_512_0xa3_200_times); scl::Hash<512> hash1; - for (std::size_t i = 0; i < 200; ++i) hash1.Update(&byte, 1); + for (std::size_t i = 0; i < 200; ++i) { + hash1.Update(&byte, 1); + } REQUIRE(hash1.Finalize() == SHA3_512_0xa3_200_times); } diff --git a/test/scl/test_prg.cc b/test/scl/test_prg.cc index 5d8e843..e156824 100644 --- a/test/scl/test_prg.cc +++ b/test/scl/test_prg.cc @@ -24,10 +24,13 @@ inline bool BufferCmp(const unsigned char* b0, const unsigned char* b1, unsigned len) { - auto p0 = b0; - auto p1 = b1; - while (len-- > 0) - if (*p0++ != *p1++) return false; + const auto* p0 = b0; + const auto* p1 = b1; + while (len-- > 0) { + if (*p0++ != *p1++) { + return false; + } + } return true; } @@ -40,7 +43,7 @@ inline bool BufferLooksRandom(const unsigned char* p, unsigned len) { bool all_in_interval = true; for (std::size_t i = 0; i < 256; i++) { - auto p = 100 * ((float)buckets[i] / len); + auto p = 100 * ((float)buckets[i] / (float)len); all_in_interval &= p >= 0.2 || p <= 6.0; } return all_in_interval; @@ -84,7 +87,9 @@ TEST_CASE("PRG", "[misc]") { std::vector buffer(100); prg.Next(buffer, 50); bool last_is_zero = true; - for (std::size_t i = 50; i < 100; i++) last_is_zero &= buffer[i] == 0; + for (std::size_t i = 50; i < 100; i++) { + last_is_zero &= buffer[i] == 0; + } REQUIRE(last_is_zero); REQUIRE_THROWS_MATCHES(