diff --git a/.github/workflows/Build.yml b/.github/workflows/Build.yml
index b24c786..a478e9d 100644
--- a/.github/workflows/Build.yml
+++ b/.github/workflows/Build.yml
@@ -1,9 +1,14 @@
name: Build
-on: [push]
+on:
+ push:
+ branches:
+ - '*'
+ - '!master'
jobs:
build:
+ name: Release
runs-on: ubuntu-latest
steps:
diff --git a/.github/workflows/Checks.yml b/.github/workflows/Checks.yml
index 6a08abd..05adccc 100644
--- a/.github/workflows/Checks.yml
+++ b/.github/workflows/Checks.yml
@@ -1,30 +1,49 @@
name: Checks
-on: [push]
+on:
+ push:
+ branches:
+ - '*'
+ - '!master'
jobs:
- build:
+ documentation:
+ name: Documentation
runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Setup
+ run: sudo apt-get install -y doxygen
+
+ - name: Documentation
+ shell: bash
+ run: ./scripts/build_documentation.sh
+ headers:
+ name: Header files
+ runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v2
- - name: Setup
- run: sudo apt-get install -y doxygen clang-format-12
+ - name: Copyright
+ run: ./scripts/check_copyright_headers.py
- - name: Documentation
- shell: bash
- run: ./scripts/build_documentation.sh
+ - name: Header Guards
+ run: ./scripts/check_header_guards.py
- - name: Copyright
- run: ./scripts/check_copyright_headers.py
+ style:
+ name: Code style
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
- - name: Header Guards
- run: ./scripts/check_header_guards.py
+ - name: Setup
+ run: sudo apt-get install -y clang-format-12
- - name: Style
- shell: bash
- run: |
- find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n --style=Google {} \; &> checks.txt
- cat checks.txt
- test ! -s checks.txt
+ - name: Check
+ shell: bash
+ run: |
+ find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n --style=Google {} \; &> checks.txt
+ cat checks.txt
+ test ! -s checks.txt
diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml
index 4b0332e..c52ab45 100644
--- a/.github/workflows/Test.yml
+++ b/.github/workflows/Test.yml
@@ -1,38 +1,39 @@
name: Test
-on: [push]
+on:
+ push:
+ branches:
+ - '*'
+ - '!master'
env:
BUILD_TYPE: Debug
jobs:
build:
+ name: Coverage and Linting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- - name: Setup catch2
+ - name: Setup
run: |
- sudo apt-get install -y lcov
+ sudo apt-get install -y lcov bear
curl -L https://github.com/catchorg/Catch2/archive/v2.13.0.tar.gz -o c.tar.gz
tar xvf c.tar.gz
cd Catch2-2.13.0/
- cmake -Bbuild -H. -DBUILD_TESTING=OFF
- sudo cmake --build build/ --target install
+ cmake -B catch -DBUILD_TESTING=OFF
+ cmake --build catch
+ sudo cmake --install catch
- - name: Create build directory
- run: cmake -E make_directory ${{runner.workspace}}/build
-
- - name: Configure CMake
- shell: bash
- working-directory: ${{runner.workspace}}/build
- run: cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=$BUILD_TYPE
+ - name: CMake
+ run: cmake -B ${{runner.workspace}}/build -DCMAKE_BUILD_TYPE=$BUILD_TYPE .
- name: Build
working-directory: ${{runner.workspace}}/build
shell: bash
- run: cmake --build . --config $BUILD_TYPE
+ run: bear make -s -j4
- name: Test
working-directory: ${{runner.workspace}}/build
@@ -40,12 +41,17 @@ jobs:
run: ctest -C $BUILD_TYPE
- name: Coverage
- working-directory: ${{runner.workspace}}/build
shell: bash
run: |
- make coverage
- lcov --summary coverage.info >> summary.txt
+ cmake --build ${{runner.workspace}}/build --target coverage
+ lcov --summary ${{runner.workspace}}/build/coverage.info >> ${{runner.workspace}}/summary.txt
+ ./scripts/check_coverage.py ${{runner.workspace}}/summary.txt
- - name: Check
+ - name: Lint
shell: bash
- run: ./scripts/check_coverage.py ${{runner.workspace}}/build/summary.txt
+ run: |
+ find include/ src/ test/ -type f \( -iname "*.h" -o -iname "*.cc" \) \
+ -exec clang-tidy -p ${{runner.workspace}}/build/compile_commands.json --quiet {} \; 1>> lint.txt 2>/dev/null
+ cat lint.txt
+ test ! -s lint.txt
+
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 810c4f6..25be133 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -16,7 +16,7 @@
cmake_minimum_required( VERSION 3.14 )
-project( scl VERSION 2.1.0 DESCRIPTION "Secure Computation Library" )
+project( scl VERSION 3.0.0 DESCRIPTION "Secure Computation Library" )
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
@@ -120,7 +120,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug")
add_compile_definitions(SCL_ENABLE_EC_TESTS)
endif()
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -fsanitize=address")
find_package(Catch2 REQUIRED)
include(CTest)
include(Catch)
@@ -150,3 +150,5 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug")
EXCLUDE "/usr/include/*" "test/*" "/usr/lib/*" "/usr/local/*")
endif()
+
+message(STATUS "CXX_FLAGS=" ${CMAKE_CXX_FLAGS})
diff --git a/RELEASE.txt b/RELEASE.txt
index afc0177..f56e41d 100644
--- a/RELEASE.txt
+++ b/RELEASE.txt
@@ -1,3 +1,17 @@
+3.0: More features, build changes
+- Add method for returning a point as a pair of affine coordinates
+- Add method to check if a channel has data available
+- Allow sending and receiving STL vectors without specifying the size
+- Extend Vec with a SubVector, operator== and operator!= methods
+- Begin Shamir code refactor and move all of it into details namespace
+- bugs:
+ - fix scalar multiplication for secp256k1_order
+ - fix compilation error on g++12
+- build:
+ - build tests with -fsanitize=address
+ - disable actions for master branch
+ - add clang-tidy action
+
2.1: More Finite Fields
- Provide a FF implementation for computations modulo the order of Secp256k1
- Extend EC with support for scalar multiplications with scalars from a finite
diff --git a/examples/03_secret_sharing.cc b/examples/03_secret_sharing.cc
index 55a6d55..c5f9dcc 100644
--- a/examples/03_secret_sharing.cc
+++ b/examples/03_secret_sharing.cc
@@ -46,22 +46,28 @@ int main() {
* correction. Lets see error detection at work first
*/
+ scl::details::ShamirSSFactory factory(
+ 1, prg, scl::details::SecurityLevel::CORRECT);
/* We create 4 shamir shares with a threshold of 1.
*/
- auto shamir_shares = scl::CreateShamirShares(secret, 4, 1, prg);
+ auto shamir_shares = factory.Share(secret);
std::cout << shamir_shares << "\n";
/* Of course, these can be reconstructed. The second parameter is the
* threshold. This performs reconstruction with error detection.
*/
- auto shamir_reconstructed = scl::ReconstructShamir(shamir_shares, 1);
+ auto recon = factory.GetInterpolator();
+ auto shamir_reconstructed =
+ recon.Reconstruct(shamir_shares, scl::details::SecurityLevel::DETECT);
std::cout << shamir_reconstructed << "\n";
/* If we introduce an error, then reconstruction fails
*/
shamir_shares[2] = Fp(123);
try {
- std::cout << scl::ReconstructShamir(shamir_shares, 1) << "\n";
+ std::cout << recon.Reconstruct(shamir_shares,
+ scl::details::SecurityLevel::DETECT)
+ << "\n";
} catch (std::logic_error& e) {
std::cout << e.what() << "\n";
}
@@ -69,7 +75,7 @@ int main() {
/* On the other hand, we can use the robust reconstruction since the threshold
* is low enough. I.e., because 4 >= 3*1 + 1.
*/
- auto r = scl::ReconstructShamirRobust(shamir_shares, 1);
+ auto r = recon.Reconstruct(shamir_shares);
std::cout << r << "\n";
/* With a bit of extra work, we can even learn which share had the error.
@@ -79,7 +85,7 @@ int main() {
* default these are just the field elements 1 through 4.
*/
Vec alphas = {Fp(1), Fp(2), Fp(3), Fp(4)};
- auto pe = scl::ReconstructShamirRobust(shamir_shares, alphas, 1);
+ auto pe = scl::details::ReconstructShamirRobust(shamir_shares, alphas, 1);
/* pe is a pair of polynomials. The first is the original polynomial used for
* generating the shares and the second is a polynomial whose roots tell which
@@ -87,18 +93,18 @@ int main() {
*
* The secret is embedded in the constant term.
*/
- std::cout << pe[0].Evaluate(Fp(0)) << "\n";
+ std::cout << std::get<0>(pe).Evaluate(Fp(0)) << "\n";
/* This will be 0, indicating that the share corresponding to party 3 had an
* error.
*/
- std::cout << pe[1].Evaluate(Fp(3)) << "\n";
+ std::cout << std::get<1>(pe).Evaluate(Fp(3)) << "\n";
/* Lastly, if there's too many errors, then correction is not possible
*/
shamir_shares[1] = Fp(22);
try {
- scl::ReconstructShamirRobust(shamir_shares, 1);
+ recon.Reconstruct(shamir_shares);
} catch (std::logic_error& e) {
std::cout << e.what() << "\n";
}
diff --git a/include/scl/hash.h b/include/scl/hash.h
index fa1416f..9b25dee 100644
--- a/include/scl/hash.h
+++ b/include/scl/hash.h
@@ -130,13 +130,17 @@ Hash &Hash::Update(const unsigned char *bytes, std::size_t nbytes) {
const unsigned char *p = bytes;
if (nbytes < old_tail) {
- while (nbytes--) mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8);
+ while (nbytes-- > 0) {
+ mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8);
+ }
return *this;
}
- if (old_tail) {
+ if (old_tail != 0) {
nbytes -= old_tail;
- while (old_tail--) mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8);
+ while (old_tail-- != 0) {
+ mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8);
+ }
mState[mWordIndex] ^= mSaved;
mByteIndex = 0;
@@ -167,7 +171,9 @@ Hash &Hash::Update(const unsigned char *bytes, std::size_t nbytes) {
p += sizeof(uint64_t);
}
- while (tail--) mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8);
+ while (tail-- > 0) {
+ mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8);
+ }
return *this;
}
@@ -194,7 +200,9 @@ auto Hash::Finalize() -> DigestType {
// truncate
DigestType digest = {0};
- for (std::size_t i = 0; i < digest.size(); ++i) digest[i] = mStateBytes[i];
+ for (std::size_t i = 0; i < digest.size(); ++i) {
+ digest[i] = mStateBytes[i];
+ }
return digest;
}
@@ -208,7 +216,9 @@ template
std::string DigestToString(const D &digest) {
std::stringstream ss;
ss << std::setw(2) << std::setfill('0') << std::hex;
- for (const auto &c : digest) ss << (int)c;
+ for (const auto &c : digest) {
+ ss << (int)c;
+ }
return ss.str();
}
diff --git a/include/scl/math/ec.h b/include/scl/math/ec.h
index c33c196..caed8d4 100644
--- a/include/scl/math/ec.h
+++ b/include/scl/math/ec.h
@@ -271,6 +271,14 @@ class EC {
return details::CurveIsPointAtInfinity(mValue);
};
+ /**
+ * @brief Return this point as a pair of affine coordinates.
+ * @return this point as a pair of affine coordinates.
+ */
+ std::array ToAffine() const {
+ return details::CurveToAffine(mValue);
+ };
+
/**
* @brief Output this point as a string.
*/
diff --git a/include/scl/math/ec_ops.h b/include/scl/math/ec_ops.h
index 1d1f433..8a601f4 100644
--- a/include/scl/math/ec_ops.h
+++ b/include/scl/math/ec_ops.h
@@ -59,6 +59,15 @@ template
void CurveSetAffine(typename C::ValueType& out, const FF& x,
const FF& y);
+/**
+ * @brief Convert a point to a pair of affine coordinates.
+ * @param point the point to convert.
+ * @return a set of affine coordinates.
+ */
+template
+std::array, 2> CurveToAffine(
+ const typename C::ValueType& point);
+
/**
* @brief Add two elliptic curve points in-place.
* @param out the first point and output
@@ -135,11 +144,11 @@ void CurveToBytes(unsigned char* dest, const typename C::ValueType& in,
/**
* @brief Convert an elliptic curve point to a string
- * @param in the point
+ * @param point the point
* @return an STL string representation of \p in.
*/
template
-std::string CurveToString(const typename C::ValueType& in);
+std::string CurveToString(const typename C::ValueType& point);
} // namespace details
} // namespace scl
diff --git a/include/scl/math/ff_ops.h b/include/scl/math/ff_ops.h
index 7e64750..8c293fd 100644
--- a/include/scl/math/ff_ops.h
+++ b/include/scl/math/ff_ops.h
@@ -37,7 +37,7 @@ namespace details {
* @param value the integer to convert
*/
template
-void FieldConvertIn(typename F::ValueType& out, const int value);
+void FieldConvertIn(typename F::ValueType& out, int value);
/**
* @brief Add two field elements in-place.
diff --git a/include/scl/math/la.h b/include/scl/math/la.h
index 4651b72..21ad7a7 100644
--- a/include/scl/math/la.h
+++ b/include/scl/math/la.h
@@ -53,7 +53,9 @@ void SwapRows(Mat& A, std::size_t k, std::size_t h) {
*/
template
void MultiplyRow(Mat& A, std::size_t row, const T& m) {
- for (std::size_t j = 0; j < A.Cols(); ++j) A(row, j) *= m;
+ for (std::size_t j = 0; j < A.Cols(); ++j) {
+ A(row, j) *= m;
+ }
}
/**
@@ -65,7 +67,9 @@ void MultiplyRow(Mat& A, std::size_t row, const T& m) {
*/
template
void AddRows(Mat& A, std::size_t dst, std::size_t op, const T& m) {
- for (std::size_t j = 0; j < A.Cols(); ++j) A(dst, j) += A(op, j) * m;
+ for (std::size_t j = 0; j < A.Cols(); ++j) {
+ A(dst, j) += A(op, j) * m;
+ }
}
/**
@@ -83,7 +87,9 @@ void RowReduceInPlace(Mat& A) {
while (r < n && c < m) {
// find pivot in current column
auto pivot = r;
- while (pivot < n && A(pivot, c) == zero) pivot++;
+ while (pivot < n && A(pivot, c) == zero) {
+ pivot++;
+ }
if (pivot == n) {
// this column was all 0, so go to next one
@@ -97,10 +103,14 @@ void RowReduceInPlace(Mat& A) {
// finally, for each row that is not r, subtract a multiple of row r.
for (std::size_t k = 0; k < n; ++k) {
- if (k == r) continue;
+ if (k == r) {
+ continue;
+ }
// skip row if leading coefficient of that row is 0.
auto t = A(k, c);
- if (t != zero) AddRows(A, k, r, -t);
+ if (t != zero) {
+ AddRows(A, k, r, -t);
+ }
}
r++;
c++;
@@ -122,7 +132,9 @@ int GetPivotInColumn(const Mat& A, int col) {
while (i-- > 0) {
if (A(i, col) != zero) {
for (int k = 0; k < col - 1; ++k) {
- if (A(i, k) != zero) return -1;
+ if (A(i, k) != zero) {
+ return -1;
+ }
}
return i;
}
@@ -152,7 +164,9 @@ std::size_t FindFirstNonZeroRow(const Mat& A) {
break;
}
}
- if (non_zero) break;
+ if (non_zero) {
+ break;
+ }
}
return nzr;
}
@@ -185,7 +199,9 @@ Vec ExtractSolution(const Mat& A) {
x[c] = T{1};
} else {
T sum;
- for (std::size_t j = p + 1; j < n; ++j) sum += A(i, j) * x[j];
+ for (std::size_t j = p + 1; j < n; ++j) {
+ sum += A(i, j) * x[j];
+ }
x[c] = A(i, m - 1) - sum;
i--;
}
@@ -218,9 +234,13 @@ bool HasSolution(const Mat& A, bool unique_only) {
// the last column (the augmentation). I.e., when row(A', i) == 0, but
// row(A, i) != 0.
if (unique_only) {
- if (all_zero) return false;
+ if (all_zero) {
+ return false;
+ }
} else {
- if (all_zero && A(i, m - 1) != zero) return false;
+ if (all_zero && A(i, m - 1) != zero) {
+ return false;
+ }
}
}
return true;
diff --git a/include/scl/math/mat.h b/include/scl/math/mat.h
index a505426..c027f7d 100644
--- a/include/scl/math/mat.h
+++ b/include/scl/math/mat.h
@@ -85,7 +85,9 @@ class Mat {
static Mat Vandermonde(std::size_t n, std::size_t m) {
std::vector xs;
xs.reserve(n);
- for (std::size_t i = 0; i < n; ++i) xs.emplace_back(T(i + 1));
+ for (std::size_t i = 0; i < n; ++i) {
+ xs.emplace_back(T(i + 1));
+ }
return Mat::Vandermonde(n, m, xs);
}
@@ -119,7 +121,9 @@ class Mat {
*/
static Mat FromVector(std::size_t n, std::size_t m,
const std::vector& vec) {
- if (vec.size() != n * m) throw std::invalid_argument("invalid dimensions");
+ if (vec.size() != n * m) {
+ throw std::invalid_argument("invalid dimensions");
+ }
return Mat(n, m, vec);
};
@@ -128,7 +132,9 @@ class Mat {
*/
static Mat Identity(std::size_t n) {
Mat I(n);
- for (std::size_t i = 0; i < n; ++i) I(i, i) = T(1);
+ for (std::size_t i = 0; i < n; ++i) {
+ I(i, i) = T(1);
+ }
return I;
} // LCOV_EXCL_LINE
@@ -143,7 +149,9 @@ class Mat {
* @param m the number of columns
*/
explicit Mat(std::size_t n, std::size_t m) {
- if (!(n && m)) throw std::invalid_argument("n or m cannot be 0");
+ if (n == 0 || m == 0) {
+ throw std::invalid_argument("n or m cannot be 0");
+ }
std::vector v(n * m);
mRows = n;
mCols = m;
@@ -296,7 +304,9 @@ class Mat {
* @return this scaled by \p scalar.
*/
Mat& ScalarMultiplyInPlace(const T& scalar) {
- for (auto& v : mValues) v *= scalar;
+ for (auto& v : mValues) {
+ v *= scalar;
+ }
return *this;
};
@@ -317,8 +327,9 @@ class Mat {
* @param cols the new column count
*/
Mat& Resize(std::size_t rows, std::size_t cols) {
- if (rows * cols != Rows() * Cols())
+ if (rows * cols != Rows() * Cols()) {
throw std::invalid_argument("cannot resize matrix");
+ }
mRows = rows;
mCols = cols;
return *this;
@@ -345,11 +356,14 @@ class Mat {
* @brief Test if this matrix is equal to another.
*/
bool Equals(const Mat& other) const {
- if (Rows() != other.Rows() || Cols() != other.Cols()) return false;
+ if (Rows() != other.Rows() || Cols() != other.Cols()) {
+ return false;
+ }
bool equal = true;
- for (std::size_t i = 0; i < mValues.size(); i++)
+ for (std::size_t i = 0; i < mValues.size(); i++) {
equal &= mValues[i] == other.mValues[i];
+ }
return equal;
};
@@ -378,8 +392,9 @@ class Mat {
: mRows(r), mCols(c), mValues(v){};
void EnsureCompatible(const Mat& other) {
- if (mRows != other.mRows || mCols != other.mCols)
+ if (mRows != other.mRows || mCols != other.mCols) {
throw std::invalid_argument("incompatible matrices");
+ }
};
std::size_t mRows;
@@ -391,7 +406,7 @@ class Mat {
template
Mat Mat::Read(std::size_t n, std::size_t m, const unsigned char* src) {
- auto ptr = src;
+ const auto* ptr = src;
auto total = n * m;
// write all elements now that we know we'll not exceed the maximum read size.
@@ -406,10 +421,9 @@ Mat Mat::Read(std::size_t n, std::size_t m, const unsigned char* src) {
template
void Mat::Write(unsigned char* dest) const {
- auto ptr = dest;
for (const auto& v : mValues) {
- v.Write(ptr);
- ptr += T::ByteSize();
+ v.Write(dest);
+ dest += T::ByteSize();
}
}
@@ -421,10 +435,11 @@ Mat Mat::Random(std::size_t n, std::size_t m, PRG& prg) {
std::size_t buffer_size = nelements * T::ByteSize();
auto buffer = std::make_unique(buffer_size);
- auto ptr = buffer.get();
+ auto* ptr = buffer.get();
prg.Next(buffer.get(), buffer_size);
- for (std::size_t i = 0; i < nelements; i++)
+ for (std::size_t i = 0; i < nelements; i++) {
elements.emplace_back(T::Read(ptr + i * T::ByteSize()));
+ }
return Mat(n, m, elements);
}
@@ -482,8 +497,9 @@ Mat Mat::HyperInvertible(std::size_t n, std::size_t m) {
template
Mat Mat::Multiply(const Mat& other) const {
- if (Cols() != other.Rows())
+ if (Cols() != other.Rows()) {
throw std::invalid_argument("invalid matrix dimensions for multiply");
+ }
const auto n = Rows();
const auto p = Cols();
const auto m = other.Cols();
@@ -512,15 +528,18 @@ Mat Mat::Transpose() const {
template
bool Mat::IsIdentity() const {
- if (!IsSquare()) return false;
+ if (!IsSquare()) {
+ return false;
+ }
bool is_ident = true;
for (std::size_t i = 0; i < Rows(); ++i) {
for (std::size_t j = 0; j < Cols(); ++j) {
- if (i == j)
+ if (i == j) {
is_ident &= operator()(i, j) == T{1};
- else
+ } else {
is_ident &= operator()(i, j) == T{0};
+ }
}
}
return is_ident;
@@ -536,13 +555,15 @@ std::string Mat::ToString() const {
const auto n = Rows();
const auto m = Cols();
- if (!(n && m)) return "[ EMPTY_MATRIX ]";
+ if (!(n && m)) {
+ return "[ EMPTY_MATRIX ]";
+ }
// convert all elements to strings and find the widest string in each column
// since that will be used to properly align the final output.
std::vector elements;
elements.reserve(n * m);
- std::vector fills;
+ std::vector fills;
fills.reserve(m);
for (std::size_t j = 0; j < m; j++) {
auto first = operator()(0, j).ToString();
@@ -551,7 +572,9 @@ std::string Mat::ToString() const {
for (std::size_t i = 1; i < n; i++) {
auto next = operator()(i, j).ToString();
auto next_fill = next.size();
- if (next_fill > fill) fill = next_fill;
+ if (next_fill > fill) {
+ fill = next_fill;
+ }
elements.push_back(next);
}
fills.push_back(fill + 1);
@@ -566,7 +589,9 @@ std::string Mat::ToString() const {
<< " ";
}
ss << "]";
- if (i < n - 1) ss << "\n";
+ if (i < n - 1) {
+ ss << "\n";
+ }
}
return ss.str();
}
diff --git a/include/scl/math/number.h b/include/scl/math/number.h
index 45cc0da..4b536ed 100644
--- a/include/scl/math/number.h
+++ b/include/scl/math/number.h
@@ -76,7 +76,7 @@ class Number {
* @brief Move constructor for a Number.
* @param number the Number that is moved
*/
- Number(Number&& number);
+ Number(Number&& number) noexcept;
/**
* @brief Copy assignment from a Number.
@@ -94,7 +94,7 @@ class Number {
* @param number the Number that is moved
* @return this
*/
- Number& operator=(Number&& number) {
+ Number& operator=(Number&& number) noexcept {
swap(*this, number);
return *this;
};
@@ -257,42 +257,47 @@ class Number {
*/
int Compare(const Number& number) const;
-#define SCL_CMP_FUN_NN(op) \
- friend bool operator op(const Number& lhs, const Number& rhs) { \
- return lhs.Compare(rhs) op 0; \
- }
-
/**
* @brief Equality of two numbers.
*/
- SCL_CMP_FUN_NN(==);
+ friend bool operator==(const Number& lhs, const Number& rhs) {
+ return lhs.Compare(rhs) == 0;
+ };
/**
* @brief In-equality of two numbers.
*/
- SCL_CMP_FUN_NN(!=);
+ friend bool operator!=(const Number& lhs, const Number& rhs) {
+ return lhs.Compare(rhs) != 0;
+ };
/**
* @brief Strictly less-than of two numbers.
*/
- SCL_CMP_FUN_NN(<);
+ friend bool operator<(const Number& lhs, const Number& rhs) {
+ return lhs.Compare(rhs) < 0;
+ };
/**
* @brief Less-than-or-equal of two numbers.
*/
- SCL_CMP_FUN_NN(<=);
+ friend bool operator<=(const Number& lhs, const Number& rhs) {
+ return lhs.Compare(rhs) <= 0;
+ };
/**
* @brief Strictly greater-than of two numbers.
*/
- SCL_CMP_FUN_NN(>);
+ friend bool operator>(const Number& lhs, const Number& rhs) {
+ return lhs.Compare(rhs) > 0;
+ };
/**
* @brief Greater-than-or-equal of two numbers.
*/
- SCL_CMP_FUN_NN(>=);
-
-#undef SCL_CMP_FUN_NN
+ friend bool operator>=(const Number& lhs, const Number& rhs) {
+ return lhs.Compare(rhs) >= 0;
+ };
/**
* @brief Get the size of this Number in bits.
diff --git a/include/scl/math/str.h b/include/scl/math/str.h
index 4d7f25e..f337729 100644
--- a/include/scl/math/str.h
+++ b/include/scl/math/str.h
@@ -50,7 +50,9 @@ namespace details {
template
T FromHexString(const std::string& s) {
auto n = s.size();
- if (n % 2) throw std::invalid_argument("odd-length hex string");
+ if (n % 2) {
+ throw std::invalid_argument("odd-length hex string");
+ }
T t = 0;
for (std::size_t i = 0; i < n; i += 2) {
char c0 = s[i];
diff --git a/include/scl/math/vec.h b/include/scl/math/vec.h
index 6b33fcd..ab5e89a 100644
--- a/include/scl/math/vec.h
+++ b/include/scl/math/vec.h
@@ -24,6 +24,7 @@
#include
#include
#include
+#include
#include
#include
@@ -32,6 +33,25 @@
namespace scl {
+namespace details {
+
+/**
+ * @brief Computes an unchecked inner product between two iterators.
+ * @param xb start of the first iterator
+ * @param xe end of the first iterator
+ * @param yb start of the second iterator
+ */
+template
+T UncheckedInnerProd(It xb, It xe, It yb) {
+ T v;
+ while (xb != xe) {
+ v += *xb++ * *yb++;
+ }
+ return v;
+}
+
+} // namespace details
+
/**
* @brief Vector.
*
@@ -98,6 +118,14 @@ class Vec {
*/
static Vec Random(std::size_t n, PRG& prg);
+ /**
+ * @brief Create a vector with values in a range.
+ * @param start the start value, inclusive
+ * @param end the end value, exclusive
+ * @return a vector with values [start, start + 1, ..., end - 1].
+ */
+ static Vec Range(std::size_t start, std::size_t end);
+
/**
* @brief Default constructor that creates an empty Vec.
*/
@@ -212,10 +240,7 @@ class Vec {
*/
T Dot(const Vec& other) const {
EnsureCompatible(other);
- T result;
- for (std::size_t i = 0; i < Size(); i++)
- result += mValues[i] * other.mValues[i];
- return result;
+ return details::UncheckedInnerProd(begin(), end(), other.begin());
};
/**
@@ -224,7 +249,9 @@ class Vec {
*/
T Sum() const {
T sum;
- for (const auto& v : mValues) sum += v;
+ for (const auto& v : mValues) {
+ sum += v;
+ }
return sum;
};
@@ -236,7 +263,9 @@ class Vec {
Vec ScalarMultiply(const T& scalar) const {
std::vector r;
r.reserve(Size());
- for (const auto& v : mValues) r.emplace_back(scalar * v);
+ for (const auto& v : mValues) {
+ r.emplace_back(scalar * v);
+ }
return Vec(r);
};
@@ -246,7 +275,9 @@ class Vec {
* @return a scaled version of this vector.
*/
Vec& ScalarMultiplyInPlace(const T& scalar) {
- for (auto& v : mValues) v *= scalar;
+ for (auto& v : mValues) {
+ v *= scalar;
+ }
return *this;
};
@@ -257,6 +288,20 @@ class Vec {
*/
bool Equals(const Vec& other) const;
+ /**
+ * @brief Operator == overload for Vec.
+ */
+ friend bool operator==(const Vec& left, const Vec& right) {
+ return left.Equals(right);
+ };
+
+ /**
+ * @brief Operator != overload for Vec.
+ */
+ friend bool operator!=(const Vec& left, const Vec& right) {
+ return !(left == right);
+ };
+
/**
* @brief Convert this vector into a 1-by-N row matrix.
*/
@@ -277,6 +322,29 @@ class Vec {
*/
const std::vector& ToStlVector() const { return mValues; };
+ /**
+ * @brief Extract a sub-vector
+ * @param start the start index, inclusive
+ * @param end the end index, exclusive
+ * @return a sub-vector.
+ */
+ Vec SubVector(std::size_t start, std::size_t end) {
+ if (start > end) {
+ throw std::logic_error("invalid range");
+ }
+ return Vec(begin() + start, begin() + end);
+ }
+
+ /**
+ * @brief Extract a sub-vector.
+ *
+ * This method is equivalent to Vec#SubVector(0, end).
+ *
+ * @param end the end index, exclusive
+ * @return a sub-vector.
+ */
+ Vec SubVector(std::size_t end) { return SubVector(0, end); };
+
/**
* @brief Return a string representation of this vector.
*/
@@ -362,8 +430,9 @@ class Vec {
private:
void EnsureCompatible(const Vec& other) const {
- if (Size() != other.Size())
+ if (Size() != other.Size()) {
throw std::invalid_argument("Vec sizes mismatch");
+ }
};
std::vector mValues;
@@ -396,6 +465,23 @@ Vec Vec::PartialRandom(std::size_t n, Pred predicate, PRG& prg) {
return Vec(v);
}
+template
+Vec Vec::Range(std::size_t start, std::size_t end) {
+ if (start > end) {
+ throw std::invalid_argument("invalid range");
+ }
+ if (start == end) {
+ return Vec{};
+ }
+
+ std::vector v;
+ v.reserve(end - start);
+ for (std::size_t i = start; i < end; ++i) {
+ v.emplace_back(T{(int)i});
+ }
+ return Vec(v);
+}
+
template
Vec Vec::Random(std::size_t n, PRG& prg) {
return Vec::PartialRandom(
@@ -459,24 +545,27 @@ bool Vec::Equals(const Vec& other) const {
template
std::string Vec::ToString() const {
+ std::string str;
if (Size()) {
std::stringstream ss;
ss << "[";
std::size_t i = 0;
- for (; i < Size() - 1; i++) ss << mValues[i] << ", ";
+ for (; i < Size() - 1; i++) {
+ ss << mValues[i] << ", ";
+ }
ss << mValues[i] << "]";
- return ss.str();
+ str = ss.str();
} else {
- return "[ EMPTY_VECTOR ]";
+ str = "[ EMPTY_VECTOR ]";
}
+ return str;
}
template
void Vec::Write(unsigned char* dest) const {
- unsigned char* p = dest;
for (const auto& v : mValues) {
- v.Write(p);
- p += T::ByteSize();
+ v.Write(dest);
+ dest += T::ByteSize();
}
}
diff --git a/include/scl/math/z2k_ops.h b/include/scl/math/z2k_ops.h
index 5d7e618..c767bad 100644
--- a/include/scl/math/z2k_ops.h
+++ b/include/scl/math/z2k_ops.h
@@ -82,8 +82,9 @@ unsigned LsbZ2k(T& v) {
*/
template = true>
void InvertZ2k(T& v) {
- if (!LsbZ2k(v))
+ if (!LsbZ2k(v)) {
throw std::invalid_argument("value not invertible modulo 2^K");
+ }
std::size_t bits = 5;
T z = ((v * 3) ^ 2);
@@ -95,7 +96,7 @@ void InvertZ2k(T& v) {
v = z;
}
-#define SCL_MASK(T, K) ((static_cast(1) << K) - 1)
+#define SCL_MASK(T, K) ((static_cast(1) << (K)) - 1)
/**
* @brief Compute equality modulo a power of 2.
diff --git a/include/scl/net/channel.h b/include/scl/net/channel.h
index 75c9d4d..1a4becb 100644
--- a/include/scl/net/channel.h
+++ b/include/scl/net/channel.h
@@ -51,8 +51,8 @@ namespace scl {
#define MAX_MAT_READ_SIZE 1 << 25
#endif
-#define SCL_CC(x) reinterpret_cast(x)
-#define SCL_C(x) reinterpret_cast(x)
+#define SCL_CC(x) reinterpret_cast((x))
+#define SCL_C(x) reinterpret_cast((x))
/**
* @brief Abstract channel for communicating between two peers.
@@ -85,7 +85,14 @@ class Channel {
* @param n how much data to receive
* @return how many bytes were received.
*/
- virtual int Recv(unsigned char* dst, std::size_t n) = 0;
+ virtual std::size_t Recv(unsigned char* dst, std::size_t n) = 0;
+
+ /**
+ * @brief Check if there is something to receive on this channel.
+ * @return true if this channel has data and false otherwise.
+ * @note the default implementation always returns true.
+ */
+ virtual bool HasData() { return true; };
/**
* @brief Send a trivially copyable item.
@@ -108,6 +115,7 @@ class Channel {
template , bool> = true>
void Send(const std::vector& src) {
+ Send(src.size());
Send(SCL_CC(src.data()), sizeof(T) * src.size());
}
@@ -168,15 +176,16 @@ class Channel {
/**
* @brief Receive a vector of trivially copyable items.
- *
- * dst.size() determines how many bytes to receive.
- *
* @param dst where to store the received items
+ * @note any existing content in \p dst is overwritten.
*/
template , bool> = true>
void Recv(std::vector& dst) {
- Recv(SCL_C(dst.data()), sizeof(T) * dst.size());
+ std::size_t size;
+ Recv(size);
+ dst.resize(size);
+ Recv(SCL_C(dst.data()), sizeof(T) * size);
}
/**
@@ -188,8 +197,9 @@ class Channel {
template
void Recv(Vec& vec) {
auto vec_size = RecvSize();
- if (vec_size > MAX_VEC_READ_SIZE)
+ if (vec_size > MAX_VEC_READ_SIZE) {
throw std::logic_error("received vector exceeds size limit");
+ }
auto n = vec_size * T::ByteSize();
auto buf = std::make_unique(n);
Recv(SCL_C(buf.get()), n);
@@ -206,8 +216,9 @@ class Channel {
void Recv(Mat& mat) {
auto rows = RecvSize();
auto cols = RecvSize();
- if (rows * cols > MAX_MAT_READ_SIZE)
+ if (rows * cols > MAX_MAT_READ_SIZE) {
throw std::logic_error("received matrix exceeds size limit");
+ }
auto n = rows * cols * T::ByteSize();
auto buf = std::make_unique(n);
Recv(SCL_C(buf.get()), n);
diff --git a/include/scl/net/config.h b/include/scl/net/config.h
index fd36d54..8831779 100644
--- a/include/scl/net/config.h
+++ b/include/scl/net/config.h
@@ -42,7 +42,7 @@ struct Party {
/**
* @brief The id of this party.
*/
- unsigned id;
+ int id;
/**
* @brief The hostname.
@@ -68,7 +68,7 @@ class NetworkConfig {
* @param id the identity of this party
* @param filename the filename
*/
- static NetworkConfig Load(unsigned id, std::string filename);
+ static NetworkConfig Load(int id, const std::string& filename);
/**
* @brief Create a network config where all parties are running locally.
@@ -82,14 +82,14 @@ class NetworkConfig {
* @param size the size of the network
* @param port_base the base port
*/
- static NetworkConfig Localhost(unsigned id, std::size_t size, int port_base);
+ static NetworkConfig Localhost(int id, int size, int port_base);
/**
* @brief Create a network config where all parties are running locally.
* @param id the identity of this party
* @param size the size of the network
*/
- static NetworkConfig Localhost(unsigned id, std::size_t size) {
+ static NetworkConfig Localhost(int id, int size) {
return NetworkConfig::Localhost(id, size, DEFAULT_PORT_OFFSET);
};
@@ -98,7 +98,7 @@ class NetworkConfig {
* @param id the id of the local party
* @param parties a list of parties
*/
- NetworkConfig(unsigned id, std::vector parties)
+ NetworkConfig(int id, const std::vector& parties)
: mId(id), mParties(parties) {
Validate();
};
@@ -112,7 +112,7 @@ class NetworkConfig {
/**
* @brief Gets the identity of this party.
*/
- unsigned Id() const { return mId; };
+ int Id() const { return mId; };
/**
* @brief Gets the size of the network.
@@ -137,7 +137,7 @@ class NetworkConfig {
private:
void Validate();
- unsigned mId;
+ int mId;
std::vector mParties;
};
diff --git a/include/scl/net/discovery/client.h b/include/scl/net/discovery/client.h
index ed81866..50e5bb1 100644
--- a/include/scl/net/discovery/client.h
+++ b/include/scl/net/discovery/client.h
@@ -41,14 +41,14 @@ class DiscoveryClient {
* @param discovery_port the port of the discovery server
* @param discovery_hostname the hostname of the discovery server
*/
- DiscoveryClient(std::string discovery_hostname, int discovery_port)
+ DiscoveryClient(const std::string& discovery_hostname, int discovery_port)
: mHostname(discovery_hostname), mPort(discovery_port){};
/**
* @brief Create a new client in a discovery protocol.
* @param discovery_hostname the hostname of the discovery server
*/
- DiscoveryClient(std::string discovery_hostname)
+ DiscoveryClient(const std::string& discovery_hostname)
: DiscoveryClient(discovery_hostname, DEFAULT_DISCOVERY_PORT){};
/**
@@ -57,7 +57,7 @@ class DiscoveryClient {
* @param port the port of this party
* @return A network configuration.
*/
- NetworkConfig Run(unsigned id, int port);
+ NetworkConfig Run(int id, int port) const;
class SendIdAndPort;
class ReceiveNetworkConfig;
@@ -77,15 +77,16 @@ class DiscoveryClient::SendIdAndPort
/**
* @brief Constructor.
*/
- SendIdAndPort(unsigned id, int port) : mId(id), mPort(port){};
+ SendIdAndPort(int id, int port) : mId(id), mPort(port){};
/**
* @brief Run this protocol step.
*/
- DiscoveryClient::ReceiveNetworkConfig Run(std::shared_ptr ctx);
+ DiscoveryClient::ReceiveNetworkConfig Run(
+ const std::shared_ptr& ctx) const;
private:
- unsigned mId;
+ int mId;
int mPort;
};
@@ -99,15 +100,15 @@ class DiscoveryClient::ReceiveNetworkConfig
/**
* @brief Constructor.
*/
- ReceiveNetworkConfig(unsigned id) : mId(id){};
+ ReceiveNetworkConfig(int id) : mId(id){};
/**
* @brief Finalize the discovery protocol.
*/
- NetworkConfig Finalize(std::shared_ptr ctx);
+ NetworkConfig Finalize(const std::shared_ptr& ctx) const;
private:
- unsigned mId;
+ int mId;
};
} // namespace scl
diff --git a/include/scl/net/discovery/server.h b/include/scl/net/discovery/server.h
index 371c6d9..b02e22c 100644
--- a/include/scl/net/discovery/server.h
+++ b/include/scl/net/discovery/server.h
@@ -53,8 +53,9 @@ class DiscoveryServer {
*/
DiscoveryServer(int discovery_port, std::size_t number_of_parties)
: mPort(discovery_port), mNumberOfParties(number_of_parties) {
- if (number_of_parties > MAX_DISCOVER_PARTIES)
+ if (number_of_parties > MAX_DISCOVER_PARTIES) {
throw std::invalid_argument("number_of_parties exceeds max");
+ }
};
/**
@@ -69,7 +70,7 @@ class DiscoveryServer {
* @param me ID, port and hostname information for this party
* @return A network configuration.
*/
- NetworkConfig Run(const Party& me);
+ NetworkConfig Run(const Party& me) const;
class CollectIdsAndPorts;
class SendNetworkConfig;
@@ -107,7 +108,7 @@ class DiscoveryServer::CollectIdsAndPorts
/**
* @brief Constructor.
*/
- CollectIdsAndPorts(std::vector hostnames)
+ CollectIdsAndPorts(const std::vector& hostnames)
: mHostnames(hostnames){};
/**
diff --git a/include/scl/net/mem_channel.h b/include/scl/net/mem_channel.h
index a054b08..5e552b0 100644
--- a/include/scl/net/mem_channel.h
+++ b/include/scl/net/mem_channel.h
@@ -46,8 +46,9 @@ class InMemoryChannel final : public Channel {
static std::array, 2> CreatePaired() {
auto buf0 = std::make_shared();
auto buf1 = std::make_shared();
- return {std::make_shared(buf0, buf1),
- std::make_shared(buf1, buf0)};
+ auto chl0 = std::make_shared(buf0, buf1);
+ auto chl1 = std::make_shared(buf1, buf0);
+ return {chl0, chl1};
};
/**
@@ -65,18 +66,21 @@ class InMemoryChannel final : public Channel {
*/
InMemoryChannel(std::shared_ptr in_buffer,
std::shared_ptr out_buffer)
- : mIn(in_buffer), mOut(out_buffer){};
+ : mIn(std::move(in_buffer)), mOut(std::move(out_buffer)){};
/**
* @brief Flush the incomming buffer;
*/
void Flush() {
- while (mIn->Size()) mIn->PopFront();
+ while (mIn->Size() > 0) {
+ mIn->PopFront();
+ }
mOverflow.clear();
};
void Send(const unsigned char* src, std::size_t n) override;
- int Recv(unsigned char* dst, std::size_t n) override;
+ std::size_t Recv(unsigned char* dst, std::size_t n) override;
+ bool HasData() override { return mIn->Size() > 0 || !mOverflow.empty(); };
void Close() override{};
private:
diff --git a/include/scl/net/network.h b/include/scl/net/network.h
index 85b0ed2..b1ed6f3 100644
--- a/include/scl/net/network.h
+++ b/include/scl/net/network.h
@@ -21,6 +21,7 @@
#ifndef SCL_NET_NETWORK_H
#define SCL_NET_NETWORK_H
+#include
#include
#include
#include
@@ -75,7 +76,9 @@ class Network {
* @brief Closes all channels in the network.
*/
void Close() {
- for (auto c : mChannels) c->Close();
+ for (auto& c : mChannels) {
+ c->Close();
+ }
};
private:
@@ -153,7 +156,7 @@ Network Network::Create(const scl::NetworkConfig& config) {
std::thread server(scl::details::SCL_AcceptConnections, std::ref(channels),
config);
- for (std::size_t i = 0; i < config.Id(); ++i) {
+ for (std::size_t i = 0; i < static_cast(config.Id()); ++i) {
const auto party = config.GetParty(i);
auto socket = scl::details::ConnectAsClient(party.hostname, party.port);
std::shared_ptr channel = std::make_shared(socket);
diff --git a/include/scl/net/tcp_channel.h b/include/scl/net/tcp_channel.h
index 1bcbb85..f5af2fe 100644
--- a/include/scl/net/tcp_channel.h
+++ b/include/scl/net/tcp_channel.h
@@ -52,7 +52,8 @@ class TcpChannel final : public Channel {
bool Alive() const { return mAlive; };
void Send(const unsigned char* src, std::size_t n) override;
- int Recv(unsigned char* dst, std::size_t n) override;
+ std::size_t Recv(unsigned char* dst, std::size_t n) override;
+ bool HasData() override;
void Close() override;
private:
diff --git a/include/scl/net/tcp_utils.h b/include/scl/net/tcp_utils.h
index f10684a..fd75e70 100644
--- a/include/scl/net/tcp_utils.h
+++ b/include/scl/net/tcp_utils.h
@@ -64,7 +64,7 @@ AcceptedConnection AcceptConnection(int server_socket);
/**
* @brief Extra the hostname of an accepted connection.
*/
-std::string GetAddress(AcceptedConnection connection);
+std::string GetAddress(const AcceptedConnection& connection);
/**
* @brief Connect in client mode.
@@ -72,7 +72,7 @@ std::string GetAddress(AcceptedConnection connection);
* @param port the port of the server
* @return a socket.
*/
-int ConnectAsClient(std::string hostname, int port);
+int ConnectAsClient(const std::string& hostname, int port);
/**
* @brief Close a socket.
@@ -82,12 +82,12 @@ int CloseSocket(int socket);
/**
* @brief Read from a socket.
*/
-int ReadFromSocket(int socket, unsigned char* dst, std::size_t n);
+ssize_t ReadFromSocket(int socket, unsigned char* dst, std::size_t n);
/**
* @brief Write to a socket.
*/
-int WriteToSocket(int socket, const unsigned char* src, std::size_t n);
+ssize_t WriteToSocket(int socket, const unsigned char* src, std::size_t n);
} // namespace details
} // namespace scl
diff --git a/include/scl/net/threaded_sender.h b/include/scl/net/threaded_sender.h
index 55b8c84..f06e1b5 100644
--- a/include/scl/net/threaded_sender.h
+++ b/include/scl/net/threaded_sender.h
@@ -52,10 +52,12 @@ class ThreadedSenderChannel final : public Channel {
mSendBuffer.PushBack({src, src + n});
};
- int Recv(unsigned char* dst, std::size_t n) override {
+ std::size_t Recv(unsigned char* dst, std::size_t n) override {
return mChannel.Recv(dst, n);
};
+ bool HasData() override { return mChannel.HasData(); };
+
private:
TcpChannel mChannel;
details::SharedDeque> mSendBuffer;
diff --git a/include/scl/prg.h b/include/scl/prg.h
index e633512..5699aa8 100644
--- a/include/scl/prg.h
+++ b/include/scl/prg.h
@@ -120,8 +120,9 @@ class PRG {
* @param nbytes how many bytes to generate.
*/
void Next(std::vector &dest, std::size_t nbytes) {
- if (dest.size() < nbytes)
+ if (dest.size() < nbytes) {
throw std::invalid_argument("requested more randomness than dest.size()");
+ }
Next(dest.data(), nbytes);
};
diff --git a/include/scl/ss/additive.h b/include/scl/ss/additive.h
index 7976551..9d935af 100644
--- a/include/scl/ss/additive.h
+++ b/include/scl/ss/additive.h
@@ -38,7 +38,9 @@ namespace scl {
*/
template
Vec CreateAdditiveShares(const T& secret, std::size_t n, PRG& prg) {
- if (!n) throw std::invalid_argument("cannot create shares for 0 people");
+ if (!n) {
+ throw std::invalid_argument("cannot create shares for 0 people");
+ }
Vec shares = Vec::PartialRandom(
n, [](std::size_t i) { return i > 0; }, prg);
diff --git a/include/scl/ss/poly.h b/include/scl/ss/poly.h
index 61d2f05..04340a9 100644
--- a/include/scl/ss/poly.h
+++ b/include/scl/ss/poly.h
@@ -61,7 +61,9 @@ class Polynomial {
auto it = mCoefficients.rbegin();
auto end = mCoefficients.rend();
auto y = *it++;
- while (it != end) y = *it++ + y * x;
+ while (it != end) {
+ y = *it++ + y * x;
+ }
return y;
};
@@ -78,23 +80,23 @@ class Polynomial {
/**
* @brief Add two polynomials.
*/
- Polynomial Add(const Polynomial& p) const;
+ Polynomial Add(const Polynomial& q) const;
/**
* @brief Subtraction two polynomials.
*/
- Polynomial Subtract(const Polynomial& p) const;
+ Polynomial Subtract(const Polynomial& q) const;
/**
* @brief Multiply two polynomials.
*/
- Polynomial Multiply(const Polynomial& p) const;
+ Polynomial Multiply(const Polynomial& q) const;
/**
* @brief Divide two polynomials.
* @return A pair \f$(q, r)\f$ such that \f$\mathtt{this} = p * q + r\f$.
*/
- std::array Divide(const Polynomial& p) const;
+ std::array Divide(const Polynomial& q) const;
/**
* @brief Returns true if this is the 0 polynomial.
@@ -155,14 +157,18 @@ Polynomial Polynomial::Create(const Vec& coefficients) {
auto cutoff = coefficients.Size();
T zero;
for (; it != end; ++it) {
- if (*it != zero) break;
+ if (*it != zero) {
+ break;
+ }
--cutoff;
}
const auto c = Vec(coefficients.begin(), coefficients.begin() + cutoff);
- if (!c.Size())
+
+ if (!c.Size()) {
return Polynomial{};
- else
- return Polynomial{c};
+ }
+
+ return Polynomial{c};
}
/**
@@ -175,7 +181,9 @@ template
Vec PadCoefficients(const Polynomial& p, std::size_t n) {
Vec c(n);
for (std::size_t i = 0; i < n; ++i) {
- if (i <= p.Degree()) c[i] = p[i];
+ if (i <= p.Degree()) {
+ c[i] = p[i];
+ }
}
return c;
} // LCOV_EXCL_LINE
@@ -226,28 +234,33 @@ Polynomial DivideLeadingTerms(const Polynomial& p,
template
std::array, 2> Polynomial::Divide(
- const Polynomial& d) const {
- if (d.IsZero()) throw std::invalid_argument("division by 0");
+ const Polynomial& q) const {
+ if (q.IsZero()) {
+ throw std::invalid_argument("division by 0");
+ }
// https://en.wikipedia.org/wiki/Polynomial_long_division#Pseudocode
- Polynomial q;
+ Polynomial p;
Polynomial r = *this;
- while (!r.IsZero() && r.Degree() >= d.Degree()) {
- const auto t = DivideLeadingTerms(r, d);
- q = q.Add(t);
- r = r.Subtract(t.Multiply(d));
+ while (!r.IsZero() && r.Degree() >= q.Degree()) {
+ const auto t = DivideLeadingTerms(r, q);
+ p = p.Add(t);
+ r = r.Subtract(t.Multiply(q));
}
- return {q, r};
+ return {p, r};
}
template
-std::string Polynomial::ToString(const char* pn, const char* vn) const {
+std::string Polynomial::ToString(const char* polynomial_name,
+ const char* variable_name) const {
std::stringstream ss;
- ss << pn << "(" << vn << ") = " << mCoefficients[0];
+ ss << polynomial_name << "(" << variable_name << ") = " << mCoefficients[0];
for (std::size_t i = 1; i < mCoefficients.Size(); i++) {
- ss << " + " << mCoefficients[i] << vn;
- if (i > 1) ss << "^" << i;
+ ss << " + " << mCoefficients[i] << variable_name;
+ if (i > 1) {
+ ss << "^" << i;
+ }
}
return ss.str();
}
diff --git a/include/scl/ss/shamir.h b/include/scl/ss/shamir.h
index 367c7b5..b38a4c2 100644
--- a/include/scl/ss/shamir.h
+++ b/include/scl/ss/shamir.h
@@ -22,8 +22,11 @@
#define SCL_SS_SHAMIR_H
#include
-#include
+#include
#include
+#include
+#include
+#include
#include "scl/math/la.h"
#include "scl/math/vec.h"
@@ -34,200 +37,42 @@ namespace scl {
namespace details {
/**
- * @brief Create a random polynomial suitable for create Shamir secret-shares.
- * @param secret the secret to embed in the constant term of the polynomial
- * @param t the threshold that this polynomial should support
- * @param prg a PRG used to generate random coefficients
- * @return a Polynomial that can be used to generate degree t shares.
- */
-template
-details::Polynomial CreateShamirSharePolynomial(const T& secret,
- std::size_t t, PRG& prg) {
- if (!t) throw std::invalid_argument("threshold cannot be 0");
- auto coeff = Vec::PartialRandom(
- t + 1, [](std::size_t i) { return i > 0; }, prg);
- coeff[0] = secret;
- return details::Polynomial::Create(coeff);
-}
-
-/**
- * @brief Interpolate a polynomial given a list of evaluations (f(x), x).
- * @param ys the evaluation results
- * @param xs the evaluation points
- * @param k the number of points to use
- * @param x the point at which to interpolate
- * @param offset an offset into the \p ys and \p xs
- * @return f(\p x) where f was interpolated from \p k of the provided points.
- */
-template
-T InterpolateAt(const Vec& ys, const Vec& xs, std::size_t k, const T& x,
- std::size_t offset) {
- T z;
- for (std::size_t j = 0; j < k; ++j) {
- T ell(1);
- auto xj = xs[offset + j];
- for (std::size_t m = 0; m < k; ++m) {
- if (m == j) continue;
- auto xm = xs[offset + m];
- ell *= (x - xm) / (xj - xm);
- }
- z += ys[offset + j] * ell;
- }
- return z;
-}
-} // namespace details
-
-/**
- * @brief Create a collection of shamir shares given a polynomial and points.
- * @param sharing_polynomial the polynomial used to generate the shares
- * @param alphas the evaluation points
- * @tparam T a finite field
- * @return A vector of shamir secret-shares.
- */
-template
-Vec CreateShamirShares(const details::Polynomial& sharing_polynomial,
- const Vec& alphas) {
- auto n = alphas.Size();
- Vec shares(n);
- for (std::size_t i = 0; i < n; i++) {
- shares[i] = sharing_polynomial.Evaluate(alphas[i]);
- }
- return shares;
-} // LCOV_EXCL_LINE
-
-/**
- * @brief Return the list of canonical evaluation points
- * @param n the number of points
- * @tparam T a finite field
- * @return the list [1, 2, ..., n] of finite field elements.
- */
-template
-Vec CanonicalAlphas(std::size_t n) {
- Vec alphas(n);
- for (std::size_t i = 0; i < n; i++) alphas[i] = T{(int)(i + 1)};
- return alphas;
-} // LCOV_EXCL_LINE
-
-/**
- * @brief Create a shamir secret-sharing.
- * @param secret the secret
- * @param n the number of shares to create
- * @param t the privacy threshold
- * @param prg a PRG for randomness
- * @return a vector of Shamir secret-shares.
- */
-template
-Vec CreateShamirShares(const T& secret, std::size_t n, std::size_t t,
- PRG& prg) {
- return CreateShamirShares(
- details::CreateShamirSharePolynomial(secret, t, prg),
- CanonicalAlphas(n));
-}
-
-/**
- * @brief Reconstruct a shamir shared secret with passive security.
- * @param shares the shars
- * @param alphas the alphas
- * @param pos the position of the secret
- * @param t the threshold
- * @return the reconstructed value.
- * @throws std::invalid_argument if less than \f$t+1\f$ shares were given
- * @throws std::invalid_argument if less than \f$t+1\f$ alphas were given
- */
-template
-T ReconstructShamirPassive(const Vec& shares, const Vec& alphas,
- const T& pos, std::size_t t) {
- if (t + 1 > shares.Size())
- throw std::invalid_argument("not enough shares to reconstruct");
- if (t + 1 > alphas.Size())
- throw std::invalid_argument("not enough alphas to reconstruct");
- return InterpolateAt(shares, alphas, t + 1, pos, 0);
-}
-
-/**
- * @brief Reconstruct a Shamir shared secret with passive security.
- *
- * This method makes no guarantee regarding the reconstructed value. In
- * particular, this method should only be used in a passive security model.
+ * @brief Robust reconstruction.
*
- * The shares is assumed to have been generated using alphas as output by \ref
- * scl::CanonicalAlphas, and the secret is assumed to have been placed on the
- * constant term of the sharing polynomial.
- *
- * @param shares the shares
- * @param t the threshold
- * @return the reconstructed secret.
- * @see ReconstructShamirPassive
- */
-template
-T ReconstructShamirPassive(const Vec& shares, std::size_t t) {
- return ReconstructShamirPassive(shares, CanonicalAlphas(shares.Size()),
- T(0), t);
-}
-
-/**
- * @brief Reconstruct a Shamir shared secret with error detection.
+ * This function performs a robust Shamir secret-share reconstruction. Given
+ * at least \f$3t + 1\f$ pairs \f$(f(x),x)\f$, where each \f$f(x)\f$ is a share
+ * and \f$x\f$ a corresponding evaluation alpha, finds the degree \f$t\f$
+ * polynomial \f$f\f$ that passes through all supplied shares.
*
- * This method will attempt to reconstruct a degree \f$t\f$ shared secret from
- * at least \f$2t+1\f$ provided points. In case the provided shares do not all
- * lie on a degree \f$t\f$ polynomial, an exception is thrown.
+ * The return value is a pair of polynomials \f$(f,e)\f$ where \f$f\f$ is the
+ * reconstructed polynomial and \f$e\f$ a polynomial whose roots indicate which
+ * (if any) of the input shares were invalid. I.e., \f$e(i)=0\f$ if
+ * share[i] had an error.
*
- * @param shares the shares
- * @param alphas the alphas
- * @param pos the position of the secret
- * @param t the threshold
- * @return the correct reconstructed value.
- * @throws std::invalid_argument if less than \f$2t+1\f$ shares were given
- * @throws std::invalid_argument if less than \f$2t+1\f$ alphas were given
- * @throws std::logic_error if one of the shares contained an error
- */
-template
-T ReconstructShamir(const Vec& shares, const Vec& alphas, const T& pos,
- std::size_t t) {
- if (2 * t + 1 > shares.Size())
- throw std::invalid_argument(
- "not enough shares to reconstruct with error detection");
- if (2 * t + 1 > alphas.Size())
- throw std::invalid_argument(
- "not enough alphas to reconstruct with error detection");
-
- for (std::size_t k = t + 1; k < 2 * t + 1; ++k) {
- auto s = InterpolateAt(shares, alphas, t + 1, alphas[k], 0);
- if (s != shares[k])
- throw std::logic_error("error detected during reconstruction");
- }
- return InterpolateAt(shares, alphas, t + 1, pos, 0);
-}
-
-/**
- * @brief Reconstruct a Shamir shared secret with error detection.
- * @see ReconstructShamir
- */
-template
-T ReconstructShamir(const Vec& shares, std::size_t t) {
- return ReconstructShamir(shares, CanonicalAlphas(shares.Size()), T(0), t);
-}
-
-/**
- * @brief Reconstruct a Shamir shared secret with error correction.
+ * @param shares the shares to use for reconstruction
+ * @param alphas the evaluation alphas
+ * @param t the degree of the polynomial to reconstruct
+ * @return a pair of polynomials.
*/
template
-std::array, 2> ReconstructShamirRobust(
- const Vec& shares, const Vec& alphas, std::size_t t) {
+auto ReconstructShamirRobust(const Vec& shares, const Vec& alphas,
+ std::size_t t) {
std::size_t n = 3 * t + 1;
- if (n > shares.Size())
+ if (n > shares.Size()) {
throw std::invalid_argument(
"not enough shares to reconstruct with error correction");
- if (n > alphas.Size())
+ }
+ if (n > alphas.Size()) {
throw std::invalid_argument(
"not enough alphas to reconstruct with error correction");
+ }
Mat A(n);
Vec b(n);
Vec x(n);
int e;
for (std::size_t k = 0; k <= t; ++k) {
- e = t - k;
+ e = t - k; // NOLINT
for (std::size_t i = 0; i < n; ++i) {
b[i] = -shares[i];
@@ -243,7 +88,9 @@ std::array, 2> ReconstructShamirRobust(
}
}
- if (SolveLinearSystem(x, A, b)) break;
+ if (SolveLinearSystem(x, A, b)) {
+ break;
+ }
}
Vec cE{x.begin(), x.begin() + e + 1};
@@ -252,21 +99,303 @@ std::array, 2> ReconstructShamirRobust(
auto E = details::Polynomial::Create(cE);
auto Q = details::Polynomial::Create(Vec{x.begin() + e, x.end()});
auto qr = Q.Divide(E);
- if (qr[1].IsZero()) return {qr[0], E};
- throw std::logic_error("could not correct shares");
+ if (!qr[1].IsZero()) {
+ throw std::logic_error("could not correct shares");
+ }
+
+ return std::make_pair(qr[0], E);
}
/**
- * @brief Reconstruct a Shamir shared secret with error correction.
+ * @brief Robust reconstruction.
+ *
+ * This function is a short-hand for ReconstructShamirRobust(shares,
+ * Vec::Range(1, shares.Size() + 1), t).
+ *
+ * @param shares the shares
+ * @param t the degree of the polynomial to reconstruct
+ * @return \f$f(0)\f$ where \f$f\f$ is the reconstructed polynomial.
*/
template
T ReconstructShamirRobust(const Vec& shares, std::size_t t) {
auto p =
- ReconstructShamirRobust(shares, CanonicalAlphas(shares.Size()), t);
- return p[0].Evaluate(T(0));
+ ReconstructShamirRobust(shares, Vec::Range(1, shares.Size() + 1), t);
+ return std::get<0>(p).Evaluate(T(0));
+}
+
+/**
+ * @brief Defines some common security levels.
+ *
+ * These security levels are used to derive some common defaults. For example,
+ * SecurityLevel::DETECT is used to indicate that ShamirSSFactory, when
+ * instantiated with a threshold of \f$t\f$, should generate \f$2t + 1\f$ shares
+ * if no other value is specified.
+ */
+enum class SecurityLevel {
+ /**
+ * @brief \f$t + 1\f$.
+ */
+ PASSIVE,
+
+ /**
+ * @brief \f$2t + 1\f$. Enough shares to detect errors.
+ */
+ DETECT,
+
+ /**
+ * @brief \f$3t + 1\f$. Enough shares to correct errors.
+ */
+ CORRECT
+};
+
+/**
+ * @brief Class for reconstructing secrets from Shamir secret-shares.
+ *
+ * The main purpose of this class is to cache the lagrange coefficients used
+ * when performing interpolation. These coefficients are computed when they are
+ * first needed. The only exception is the coefficients needed to reconstruct
+ * the point at 0, which is usually where the secret is placed.
+ *
+ * A Reconstructor contains essentially two methods for reconstructing:
+ *
+ * - Reconstruct reconstructs a particular evaluation of the implied
+ * polynomial.
+ * - ReconstructShare reconstructs a particular share of a party
+ *
+ * Both methods do effectively the same, and the only difference is how the
+ * index passed to them is interpreted. For the first, the index is interpreted
+ * as-is. I.e., if we pass an index of 3, then it will recover the point
+ * \f$f(3)\f$, where \f$f\f$ is the polynomial implied by the shares we provide.
+ * For the latter method, passing an index of \f$i\f$ implies that we wish to
+ * interpolate the point \f$f(i + 1)\f$. This is because we assume parties are
+ * indexed from 0, but their evaluation indices are indexed from 1 (because the
+ * 0'th index is reserved for the secret).
+ */
+template
+class Reconstructor {
+ public:
+ /**
+ * @brief Create a new Reconstructor instance.
+ *
+ * This creator method creates a new instance and pre-computes the
+ * coefficients needed to recover a secret from a set of shamir shares. The
+ * SecurityLevel given is used to infer the default reconstruction method:
+ *
+ * - SecurityLevel::PASSIVE: use passive reconstruction. I.e., reconstruct
+ * the value we ask for, using the minimal amount of \f$t + 1\f$ shares
+ * needed.
+ *
+ * - SecurityLevel::DETECT: use the first \f$t+1\f$ shares to reconstruct
+ * the remaning \f$t\f$ shares. If all of these check out, then reconstruct
+ * the value we seek.
+ *
+ * - SecurityLevel::CORRECT: use an error correcting
+ * algorithm to recover the polynomial \f$f\f$ from \f$3t+1\f$ shares,
+ * assuming that at most \f$t\f$ shares are faulty.
+ *
+ *
+ * @param threshold the threshold \f$t\f$
+ * @param default_security_level the default security level
+ */
+ static Reconstructor Create(std::size_t threshold,
+ SecurityLevel default_security_level);
+
+ /**
+ * @brief Interpolate a set of shares according to a given security level
+ * @param shares the shares
+ * @param security_level the security level
+ * @param index the index to interpolate. Defaults to 0
+ */
+ T Reconstruct(const Vec& shares, SecurityLevel security_level,
+ int index = 0) const;
+
+ /**
+ * @brief Interpolate a set of shares with the default security level
+ * @param shares the shares
+ * @param index the index to interpolate. Defaults to 0
+ */
+ T Reconstruct(const Vec& shares, int index = 0) const {
+ return Reconstruct(shares, mSecurityLevel, index);
+ };
+
+ /**
+ * @brief Reconstruct the share of a party.
+ */
+ T ReconstructShare(const Vec& shares, SecurityLevel level,
+ int party_index = 0) const {
+ return Reconstruct(shares, level, party_index + 1);
+ };
+
+ /**
+ * @brief Reconstruct the share of a party.
+ */
+ T ReconstructShare(const Vec& shares, int party_index = 0) const {
+ return Reconstruct(shares, party_index + 1);
+ };
+
+ /**
+ * @brief Get the lagrange coefficients for interpolating a particular index.
+ *
+ * This method modifies the internal cache in case coefficients for the
+ * requested index does not exist.
+ */
+ const Vec& GetLagrangeCoefficients(int index) const {
+ if (mLagrangeCoeff.count(index) == 0) {
+ ComputeLagrangeCoefficients(index);
+ }
+ return mLagrangeCoeff[index];
+ };
+
+ private:
+ Reconstructor(std::size_t threshold, SecurityLevel security_level)
+ : mThreshold(threshold), mSecurityLevel(security_level){};
+
+ void ComputeLagrangeCoefficients(int index) const;
+
+ std::size_t mThreshold;
+ SecurityLevel mSecurityLevel;
+ mutable std::unordered_map> mLagrangeCoeff;
+};
+
+template
+Reconstructor Reconstructor::Create(
+ std::size_t threshold, SecurityLevel default_security_level) {
+ Reconstructor intr(threshold, default_security_level);
+ intr.ComputeLagrangeCoefficients(0);
+ return intr;
}
+template
+T Reconstructor::Reconstruct(const Vec& shares,
+ SecurityLevel security_level, int index) const {
+ if (security_level == SecurityLevel::CORRECT) {
+ // TODO(anders): Currently only supports index = 0
+ return details::ReconstructShamirRobust(shares, mThreshold);
+ }
+
+ const auto min_size = mThreshold + 1;
+ if (shares.Size() < min_size) {
+ throw std::invalid_argument("not enough shares to reconstruct");
+ }
+
+ const auto& coeff = GetLagrangeCoefficients(index);
+ auto x = details::UncheckedInnerProd(
+ shares.begin(), shares.begin() + min_size, coeff.begin());
+
+ if (security_level == SecurityLevel::DETECT) {
+ const auto n = 2 * mThreshold + 1;
+ if (shares.Size() < n) {
+ throw std::invalid_argument("not enough shares to detect errors");
+ }
+
+ for (std::size_t i = min_size; i < n; ++i) {
+ const auto& ccoeff = GetLagrangeCoefficients(i + 1);
+ const auto c = details::UncheckedInnerProd(
+ shares.begin(), shares.begin() + min_size, ccoeff.begin());
+ if (c != shares[i]) {
+ throw std::logic_error("error detected during reconstruction");
+ }
+ }
+ }
+ return x;
+}
+
+template
+void Reconstructor::ComputeLagrangeCoefficients(int index) const {
+ Vec coeff(mThreshold + 1);
+ const auto x = T(index);
+ for (std::size_t j = 0; j <= mThreshold; ++j) {
+ auto ell = T::One();
+ const auto xj = T(j + 1);
+ for (std::size_t m = 0; m <= mThreshold; ++m) {
+ if (j != m) {
+ const auto xm = T(m + 1);
+ ell *= (x - xm) / (xj - xm);
+ }
+ }
+ coeff[j] = ell;
+ }
+ mLagrangeCoeff[index] = coeff;
+}
+
+/**
+ * @brief A factory object for creating Shamir secret shares.
+ * @tparam T the finite field to use.
+ */
+template
+class ShamirSSFactory {
+ public:
+ /**
+ * @brief Create a new Shamir secret share factory
+ * @param t the privacy threshold to use
+ * @param prg the PRG to use when generating new shares
+ * @param security_level the SecurityLevel to use. Default to
+ * SecurityLevel::DETECT
+ */
+ ShamirSSFactory(std::size_t t, PRG& prg,
+ SecurityLevel security_level = SecurityLevel::DETECT)
+ : mThreshold(t), mPrg(prg), mDefaultSecurityLevel(security_level){};
+
+ /**
+ * @brief Create a new set of Shamir secret shares of a secret
+ * @param secret the secret
+ * @param number_of_shares the number of shares to generate. If empty, then
+ * the number of shares is determined by GetDefaultNumberOfShares
+ * @return a vector of shares.
+ */
+ Vec Share(const T& secret,
+ std::optional number_of_shares = {});
+
+ /**
+ * @brief Get the number of shares to create based on set security level
+ *
+ * The returned number is determined based on the \ref SecurityLevel
+ * provided during construction. Given a threshold \f$t\f$, then the number of
+ * shares is computed based on the following rules:
+ */
+ std::size_t GetDefaultNumberOfShares() const {
+ switch (mDefaultSecurityLevel) {
+ case SecurityLevel::PASSIVE:
+ return mThreshold + 1;
+ case SecurityLevel::DETECT:
+ return 2 * mThreshold + 1;
+ default: // SecurityLevel::ROBUST:
+ return 3 * mThreshold + 1;
+ }
+ };
+
+ /**
+ * @brief Get an scl::Interpolator suitable for this factory.
+ */
+ Reconstructor GetInterpolator() const {
+ return Reconstructor::Create(mThreshold, mDefaultSecurityLevel);
+ };
+
+ private:
+ std::size_t mThreshold;
+ PRG mPrg;
+ SecurityLevel mDefaultSecurityLevel;
+};
+
+template
+Vec ShamirSSFactory::Share(const T& secret,
+ std::optional number_of_shares) {
+ auto coeff = Vec::PartialRandom(
+ mThreshold + 1, [](std::size_t i) { return i > 0; }, mPrg);
+ coeff[0] = secret;
+ auto p = details::Polynomial::Create(coeff);
+
+ auto n = number_of_shares.value_or(GetDefaultNumberOfShares());
+ std::vector shares;
+ shares.reserve(n);
+ for (std::size_t i = 0; i < n; ++i) {
+ shares.emplace_back(p.Evaluate(T(i + 1)));
+ }
+ return Vec(shares);
+}
+
+} // namespace details
} // namespace scl
#endif // SCL_SS_SHAMIR_H
diff --git a/src/scl/hash.cc b/src/scl/hash.cc
index 5706436..6c34d9b 100644
--- a/src/scl/hash.cc
+++ b/src/scl/hash.cc
@@ -29,13 +29,16 @@ void scl::Keccakf(uint64_t state[25]) {
uint64_t bc[5];
for (std::size_t round = 0; round < 24; ++round) {
- for (std::size_t i = 0; i < 5; ++i)
+ for (std::size_t i = 0; i < 5; ++i) {
bc[i] = state[i] ^ state[i + 5] ^ state[i + 10] ^ state[i + 15] ^
state[i + 20];
+ }
for (std::size_t i = 0; i < 5; ++i) {
t = bc[(i + 4) % 5] ^ rotl64(bc[(i + 1) % 5], 1);
- for (std::size_t j = 0; j < 25; j += 5) state[j + i] ^= t;
+ for (std::size_t j = 0; j < 25; j += 5) {
+ state[j + i] ^= t;
+ }
}
t = state[1];
@@ -47,9 +50,12 @@ void scl::Keccakf(uint64_t state[25]) {
}
for (std::size_t j = 0; j < 25; j += 5) {
- for (std::size_t i = 0; i < 5; ++i) bc[i] = state[j + i];
- for (std::size_t i = 0; i < 5; ++i)
+ for (std::size_t i = 0; i < 5; ++i) {
+ bc[i] = state[j + i];
+ }
+ for (std::size_t i = 0; i < 5; ++i) {
state[j + i] ^= (~bc[(i + 1) % 5]) & bc[(i + 2) % 5];
+ }
}
state[0] ^= keccakf_rndc[round];
diff --git a/src/scl/math/mersenne127.cc b/src/scl/math/mersenne127.cc
index 3b9962e..d024a5a 100644
--- a/src/scl/math/mersenne127.cc
+++ b/src/scl/math/mersenne127.cc
@@ -57,8 +57,10 @@ struct u256 {
// https://cp-algorithms.com/algebra/montgomery_multiplication.html
u256 MultiplyFull(const u128 x, const u128 y) {
- u64 a = x >> 64, b = x;
- u64 c = y >> 64, d = y;
+ u64 a = x >> 64;
+ u64 b = x;
+ u64 c = y >> 64;
+ u64 d = y;
// (a*2^64 + b) * (c*2^64 + d) =
// (a*c) * 2^128 + (a*d + b*c)*2^64 + (b*d)
u128 ac = (u128)a * c;
@@ -66,9 +68,9 @@ u256 MultiplyFull(const u128 x, const u128 y) {
u128 bc = (u128)b * c;
u128 bd = (u128)b * d;
- u128 carry = (u128)(u64)ad + (u128)(u64)bc + (bd >> 64u);
- u128 high = ac + (ad >> 64u) + (bc >> 64u) + (carry >> 64u);
- u128 low = (ad << 64u) + (bc << 64u) + bd;
+ u128 carry = (u128)(u64)ad + (u128)(u64)bc + (bd >> 64U);
+ u128 high = ac + (ad >> 64U) + (bc >> 64U) + (carry >> 64U);
+ u128 low = (ad << 64U) + (bc << 64U) + bd;
return {high, low};
}
@@ -115,8 +117,8 @@ std::string scl::details::FieldToString(const u128& in) {
}
template <>
-void scl::details::FieldFromString(u128& dest,
- const std::string& str) {
- dest = FromHexString(str);
- dest = dest % p;
+void scl::details::FieldFromString(u128& out,
+ const std::string& src) {
+ out = FromHexString(src);
+ out = out % p;
}
diff --git a/src/scl/math/mersenne61.cc b/src/scl/math/mersenne61.cc
index 2a72759..2866028 100644
--- a/src/scl/math/mersenne61.cc
+++ b/src/scl/math/mersenne61.cc
@@ -90,8 +90,8 @@ std::string scl::details::FieldToString(const u64& in) {
}
template <>
-void scl::details::FieldFromString(u64& dest,
- const std::string& str) {
- dest = FromHexString(str);
- dest = dest % p;
+void scl::details::FieldFromString(u64& out,
+ const std::string& src) {
+ out = FromHexString(src);
+ out = out % p;
}
diff --git a/src/scl/math/number.cc b/src/scl/math/number.cc
index 8a46cf6..0134768 100644
--- a/src/scl/math/number.cc
+++ b/src/scl/math/number.cc
@@ -31,7 +31,7 @@ scl::Number::Number(const Number& number) : Number() {
mpz_set(mValue, number.mValue);
}
-scl::Number::Number(Number&& number) : Number() {
+scl::Number::Number(Number&& number) noexcept : Number() {
mpz_set(mValue, number.mValue);
}
@@ -47,7 +47,7 @@ scl::Number scl::Number::Random(std::size_t bits, PRG& prg) {
scl::Number r;
mpz_import(r.mValue, len - 1, 1, 1, 0, 0, data.get() + 1);
- if (data[0] & 1) {
+ if ((data[0] & 1) != 0) {
mpz_neg(r.mValue, r.mValue);
}
return r;
@@ -92,23 +92,23 @@ scl::Number scl::Number::operator/(const Number& number) const {
} // LCOV_EXCL_LINE
scl::Number scl::Number::operator<<(int shift) const {
+ scl::Number shifted;
if (shift < 0) {
- return operator>>(-shift);
+ shifted = operator>>(-shift);
} else {
- scl::Number shifted;
mpz_mul_2exp(shifted.mValue, mValue, shift);
- return shifted;
}
+ return shifted;
}
scl::Number scl::Number::operator>>(int shift) const {
+ scl::Number shifted;
if (shift < 0) {
- return operator<<(-shift);
+ shifted = operator<<(-shift);
} else {
- scl::Number shifted;
mpz_tdiv_q_2exp(shifted.mValue, mValue, shift);
- return shifted;
}
+ return shifted;
}
scl::Number scl::Number::operator^(const Number& number) const {
diff --git a/src/scl/math/ops_gmp_ff.h b/src/scl/math/ops_gmp_ff.h
index 4d947ee..704d14a 100644
--- a/src/scl/math/ops_gmp_ff.h
+++ b/src/scl/math/ops_gmp_ff.h
@@ -36,11 +36,11 @@ namespace details {
#define BITS_PER_LIMB static_cast(mp_bits_per_limb)
#define BYTES_PER_LIMB sizeof(mp_limb_t)
-#define SCL_COPY(out, in, size) \
- do { \
- for (std::size_t i = 0; i < size; ++i) { \
- *((out) + i) = *((in) + i); \
- } \
+#define SCL_COPY(out, in, size) \
+ do { \
+ for (std::size_t i = 0; i < (size); ++i) { \
+ *((out) + i) = *((in) + i); \
+ } \
} while (0)
/**
@@ -198,8 +198,8 @@ inline bool TestBit(const mp_limb_t* v, std::size_t pos) {
template
void ModExp(mp_limb_t* out, const mp_limb_t* x, const mp_limb_t* e,
const mp_limb_t* mod, const mp_limb_t* np) {
- auto n = mpn_scan1(e, N * BITS_PER_LIMB);
- for (int i = n - 1; i >= 0; --i) {
+ auto n = mpn_sizeinbase(e, N, 2);
+ for (std::size_t i = n; i-- > 0;) {
ModSqr(out, out, mod, np);
if (TestBit(e, i)) {
ModMul(out, x, mod, np);
@@ -262,25 +262,29 @@ std::string ToString(const mp_limb_t* val, const mp_limb_t* mod,
}
}
auto s = ss.str();
+
// trim leading 0s
auto n = FindFirstNonZero(s);
if (n > 0) {
s = s.substr(n, s.length() - 1);
}
+
if (s.length()) {
return s;
- } else {
- return "0";
}
+
+ return "0";
}
template
void FromString(mp_limb_t* out, const mp_limb_t* mod, const std::string& str) {
if (str.length()) {
- auto n = str.length();
- if (n > 64) {
+ auto n_ = str.length();
+ if (n_ > 64) {
throw std::invalid_argument("hex string too large to parse");
}
+ // to silence conversion errors. Safe to do because n_ is pretty small.
+ int n = static_cast(n_);
std::string s = str;
if (n % 2) {
@@ -288,10 +292,10 @@ void FromString(mp_limb_t* out, const mp_limb_t* mod, const std::string& str) {
n++;
}
- const auto m = 2 * BYTES_PER_LIMB;
+ const auto m = static_cast(2 * BYTES_PER_LIMB);
int c = (n - 1) / m;
auto beg = s.begin();
- for (std::size_t i = 0; i < n && c >= 0; i += m) {
+ for (int i = 0; i < n && c >= 0; i += m) {
auto end = std::min(n, i + m);
out[c--] = FromHexString(std::string(beg + i, beg + end));
}
@@ -299,7 +303,6 @@ void FromString(mp_limb_t* out, const mp_limb_t* mod, const std::string& str) {
}
}
-#undef SCL_COPY
#undef BITS_PER_LIMB
#undef BYTES_PER_LIMB
diff --git a/src/scl/math/ops_small_fp.h b/src/scl/math/ops_small_fp.h
index 3eeb39b..0c5b2f2 100644
--- a/src/scl/math/ops_small_fp.h
+++ b/src/scl/math/ops_small_fp.h
@@ -32,7 +32,9 @@ namespace details {
template
void ModAdd(T& t, const T& v, const T& m) {
t = t + v;
- if (t >= m) t = t - m;
+ if (t >= m) {
+ t = t - m;
+ }
}
/**
@@ -40,10 +42,11 @@ void ModAdd(T& t, const T& v, const T& m) {
*/
template
void ModSub(T& t, const T& v, const T& m) {
- if (v > t)
+ if (v > t) {
t = t + m - v;
- else
+ } else {
t = t - v;
+ }
}
/**
@@ -51,7 +54,9 @@ void ModSub(T& t, const T& v, const T& m) {
*/
template
void ModNeg(T& t, const T& m) {
- if (t) t = m - t;
+ if (t) {
+ t = m - t;
+ }
}
/**
@@ -62,8 +67,8 @@ void ModInv(T& t, const T& v, const T& m) {
#define SCL_PARALLEL_ASSIGN(v1, v2, q) \
do { \
const auto __temp = v2; \
- v2 = v1 - q * __temp; \
- v1 = __temp; \
+ (v2) = (v1) - (q)*__temp; \
+ (v1) = __temp; \
} while (0)
if (v == 0) {
diff --git a/src/scl/math/secp256k1_curve.cc b/src/scl/math/secp256k1_curve.cc
index c5d35ca..afa5b9f 100644
--- a/src/scl/math/secp256k1_curve.cc
+++ b/src/scl/math/secp256k1_curve.cc
@@ -18,6 +18,7 @@
* along with this program. If not, see .
*/
+#include
#include
#include "./secp256k1_extras.h"
@@ -34,9 +35,9 @@ using Point = Curve::ValueType;
#define POINT_AT_INFINITY Point{{Field{0}, Field{1}, Field{0}}}
// clang-format on
-#define _X(point) (point)[0]
-#define _Y(point) (point)[1]
-#define _Z(point) (point)[2]
+#define GET_X(point) (point)[0]
+#define GET_Y(point) (point)[1]
+#define GET_Z(point) (point)[2]
static const Field kCurveB(7);
@@ -66,42 +67,37 @@ void scl::details::CurveSetAffine(Point& out, const Field& x,
}
template <>
-bool scl::details::CurveEqual(const Point& in1, const Point& in2) {
- const auto Z1 = _Z(in1);
- const auto Z2 = _Z(in2);
- // (X1, Y1, Z1) eqv (X2, Y2, Z2) <==> (X1 * Z2, Y1 * Z2) == (X2 * Z1, Y2 * Z2)
- return _X(in1) * Z2 == _X(in2) * Z1 && _Y(in1) * Z2 == _Y(in2) * Z1;
+std::array scl::details::CurveToAffine(const Point& point) {
+ const auto Z = GET_Z(point);
+ return {GET_X(point) / Z, GET_Y(point) / Z};
}
template <>
-bool scl::details::CurveIsPointAtInfinity(const Point& out) {
- return CurveEqual(out, POINT_AT_INFINITY);
+bool scl::details::CurveEqual(const Point& in1, const Point& in2) {
+ const auto Z1 = GET_Z(in1);
+ const auto Z2 = GET_Z(in2);
+ // (X1, Y1, Z1) eqv (X2, Y2, Z2) <==> (X1 * Z2, Y1 * Z2) == (X2 * Z1, Y2 * Z2)
+ return GET_X(in1) * Z2 == GET_X(in2) * Z1 &&
+ GET_Y(in1) * Z2 == GET_Y(in2) * Z1;
}
-struct AffinePoint {
- Field x;
- Field y;
-};
-
-namespace {
-
-AffinePoint ToAffine(const Point& point) {
- const auto Z = _Z(point);
- return {_X(point) / Z, _Y(point) / Z};
+template <>
+bool scl::details::CurveIsPointAtInfinity(const Point& point) {
+ return CurveEqual(point, POINT_AT_INFINITY);
}
-} // namespace
-
template <>
std::string scl::details::CurveToString(const Point& point) {
+ std::string str;
if (CurveIsPointAtInfinity(point)) {
- return "EC{POINT_AT_INFINITY}";
+ str = "EC{POINT_AT_INFINITY}";
} else {
- auto ap = ToAffine(point);
+ auto ap = CurveToAffine(point);
std::stringstream ss;
- ss << "EC{" << ap.x << ", " << ap.y << "}";
- return ss.str();
+ ss << "EC{" << ap[0] << ", " << ap[1] << "}";
+ str = ss.str();
}
+ return str;
}
template <>
@@ -119,12 +115,12 @@ void scl::details::CurveSetGenerator(Point& out) {
template <>
void scl::details::CurveDouble(Point& out) {
if (!CurveIsPointAtInfinity(out)) {
- if (_Y(out) == Field::Zero()) {
+ if (GET_Y(out) == Field::Zero()) {
CurveSetPointAtInfinity(out);
} else if (!CurveIsPointAtInfinity(out)) {
- const auto X = _X(out);
- const auto Y = _Y(out);
- const auto Z = _Z(out);
+ const auto X = GET_X(out);
+ const auto Y = GET_Y(out);
+ const auto Z = GET_Z(out);
const auto W = Field(3) * X * X;
const auto S = Y * Z;
@@ -141,16 +137,16 @@ void scl::details::CurveDouble(Point& out) {
}
template <>
-void scl::details::CurveAdd(Point& out, const Point& op) {
+void scl::details::CurveAdd(Point& out, const Point& in) {
if (CurveIsPointAtInfinity(out)) {
- out = op;
- } else if (!CurveIsPointAtInfinity(op)) {
- const auto X1 = _X(out);
- const auto Y1 = _Y(out);
- const auto Z1 = _Z(out);
- const auto X2 = _X(op);
- const auto Y2 = _Y(op);
- const auto Z2 = _Z(op);
+ out = in;
+ } else if (!CurveIsPointAtInfinity(in)) {
+ const auto X1 = GET_X(out);
+ const auto Y1 = GET_Y(out);
+ const auto Z1 = GET_Z(out);
+ const auto X2 = GET_X(in);
+ const auto Y2 = GET_Y(in);
+ const auto Z2 = GET_Z(in);
const auto U1 = Y2 * Z1;
const auto U2 = Y1 * Z2;
@@ -180,16 +176,16 @@ void scl::details::CurveAdd(Point& out, const Point& op) {
template <>
void scl::details::CurveNegate(Point& out) {
- if (_Y(out) == Field::Zero()) {
+ if (GET_Y(out) == Field::Zero()) {
CurveSetPointAtInfinity(out);
} else {
- _Y(out).Negate();
+ GET_Y(out).Negate();
}
}
template <>
-void scl::details::CurveSubtract(Point& out, const Point& op) {
- Point copy(op);
+void scl::details::CurveSubtract(Point& out, const Point& in) {
+ Point copy(in);
CurveNegate(copy);
CurveAdd(out, copy);
}
@@ -201,7 +197,8 @@ void scl::details::CurveScalarMultiply(Point& out,
const auto n = scalar.BitSize();
Point res;
CurveSetPointAtInfinity(res);
- for (int i = n - 1; i >= 0; --i) {
+ // equivalent to for (int i = n - 1; i >= 0; i--)
+ for (auto i = n; i-- > 0;) {
CurveDouble(res);
if (scalar.TestBit(i)) {
CurveAdd(res, out);
@@ -215,12 +212,13 @@ template <>
void scl::details::CurveScalarMultiply(Point& out,
const FF& scalar) {
if (!CurveIsPointAtInfinity(out)) {
- const auto n = scl::SCL_FF_Extras::HigestSetBit(scalar);
+ auto x = scl::SCL_FF_Extras::FromMonty(scalar);
+ const auto n = scl::SCL_FF_Extras::HigestSetBit(x);
Point res;
CurveSetPointAtInfinity(res);
- for (int i = n - 1; i >= 0; --i) {
+ for (auto i = n; i-- > 0;) {
CurveDouble(res);
- if (scl::SCL_FF_Extras::TestBit(scalar, i)) {
+ if (scl::SCL_FF_Extras::TestBit(x, i)) {
CurveAdd(res, out);
}
}
@@ -232,9 +230,9 @@ void scl::details::CurveScalarMultiply(Point& out,
#define POINT_AT_INFINITY_FLAG 0x02
#define SELECT_SMALLER_FLAG 0x01
-#define IS_COMPRESSED(flags) (flags & COMPRESSED_FLAG)
-#define IS_POINT_AT_INFINITY(flags) (flags & POINT_AT_INFINITY_FLAG)
-#define SELECT_SMALLER(flags) (flags & SELECT_SMALLER_FLAG)
+#define IS_COMPRESSED(flags) ((flags)&COMPRESSED_FLAG)
+#define IS_POINT_AT_INFINITY(flags) ((flags)&POINT_AT_INFINITY_FLAG)
+#define SELECT_SMALLER(flags) ((flags)&SELECT_SMALLER_FLAG)
namespace {
@@ -272,9 +270,9 @@ void scl::details::CurveFromBytes(Point& out, const unsigned char* src) {
auto smaller = IsSmaller(y, yn);
auto select_smaller = SELECT_SMALLER(flags);
if (smaller) {
- out[1] = select_smaller ? y : yn;
+ out[1] = select_smaller == 0 ? yn : y;
} else {
- out[1] = select_smaller ? yn : y;
+ out[1] = select_smaller == 0 ? y : yn;
}
} else {
out[0] = Field::Read(src + 1);
@@ -284,41 +282,43 @@ void scl::details::CurveFromBytes(Point& out, const unsigned char* src) {
}
}
-#define MARK_COMPRESSED(buf) (*buf |= COMPRESSED_FLAG)
-#define MARK_POINT_AT_INFINITY(buf) (*buf |= POINT_AT_INFINITY_FLAG)
-#define MARK_SELECT_SMALLER(buf) (*buf |= SELECT_SMALLER_FLAG)
+#define MARK_COMPRESSED(buf) (*(buf) |= COMPRESSED_FLAG)
+#define MARK_POINT_AT_INFINITY(buf) (*(buf) |= POINT_AT_INFINITY_FLAG)
+#define MARK_SELECT_SMALLER(buf) (*(buf) |= SELECT_SMALLER_FLAG)
template <>
-void scl::details::CurveToBytes(unsigned char* dest, const Point& point,
+void scl::details::CurveToBytes(unsigned char* dest, const Point& in,
bool compress) {
// Make sure flag byte is zeroed.
*dest = 0;
// if point is compressed, mark it as such.
- if (compress) MARK_COMPRESSED(dest);
+ if (compress) {
+ MARK_COMPRESSED(dest);
+ }
- if (CurveIsPointAtInfinity(point)) {
+ if (CurveIsPointAtInfinity(in)) {
MARK_POINT_AT_INFINITY(dest);
// zero rest of the buffer to ensure we can always safely send the right
// amount of bytes.
std::memset(dest + 1, 0, compress ? 32 : 64);
} else {
- const auto ap = ToAffine(point);
+ const auto ap = CurveToAffine(in);
// if compression is used, we indicate a bit indicating which of {y, -y} is
// the smaller, and the only write the x coordinate. Otherwise we write both
// x and y.
if (compress) {
// include a flag which indicates which of {y, -y} is the smaller.
- const auto y = ap.y;
+ const auto y = ap[1];
const auto yn = y.Negated();
if (IsSmaller(y, yn)) {
MARK_SELECT_SMALLER(dest);
}
- ap.x.Write(dest + 1);
+ ap[0].Write(dest + 1);
} else {
- ap.x.Write(dest + 1);
- ap.y.Write(dest + 1 + Field::ByteSize());
+ ap[0].Write(dest + 1);
+ ap[1].Write(dest + 1 + Field::ByteSize());
}
}
}
diff --git a/src/scl/math/secp256k1_extras.h b/src/scl/math/secp256k1_extras.h
index 40fbffa..c6d5844 100644
--- a/src/scl/math/secp256k1_extras.h
+++ b/src/scl/math/secp256k1_extras.h
@@ -42,6 +42,10 @@ struct SCL_FF_Extras {
template <>
struct SCL_FF_Extras {
+ // Convert an element out of montgomery representation
+ static scl::FF FromMonty(
+ const scl::FF& element);
+
// Get position of the highest set bit
static std::size_t HigestSetBit(
const scl::FF& element);
diff --git a/src/scl/math/secp256k1_field.cc b/src/scl/math/secp256k1_field.cc
index 20d4506..c78ddcd 100644
--- a/src/scl/math/secp256k1_field.cc
+++ b/src/scl/math/secp256k1_field.cc
@@ -18,6 +18,7 @@
* along with this program. If not, see .
*/
+#include
#include
#include "./ops_gmp_ff.h"
@@ -94,24 +95,24 @@ void scl::details::FieldInvert(Elem& out) {
}
template <>
-bool scl::details::FieldEqual(const Elem& first, const Elem& second) {
- return CompareValues(PTR(first), PTR(second)) == 0;
+bool scl::details::FieldEqual(const Elem& in1, const Elem& in2) {
+ return CompareValues(PTR(in1), PTR(in2)) == 0;
}
template <>
-void scl::details::FieldFromBytes(Elem& out, const unsigned char* src) {
- ValueFromBytes(PTR(out), src);
+void scl::details::FieldFromBytes(Elem& dest, const unsigned char* src) {
+ ValueFromBytes(PTR(dest), src);
}
template <>
-std::string scl::details::FieldToString(const Elem& element) {
- return ToString(PTR(element), kPrime, kMontyN);
+std::string scl::details::FieldToString(const Elem& in) {
+ return ToString(PTR(in), kPrime, kMontyN);
}
template <>
-void scl::details::FieldFromString(Elem& out, const std::string& str) {
+void scl::details::FieldFromString(Elem& out, const std::string& src) {
out = {0};
- FromString(PTR(out), kPrime, str);
+ FromString(PTR(out), kPrime, src);
}
bool scl::SCL_FF_Extras::IsSmaller(
diff --git a/src/scl/math/secp256k1_order.cc b/src/scl/math/secp256k1_order.cc
index b616b0f..8a78a93 100644
--- a/src/scl/math/secp256k1_order.cc
+++ b/src/scl/math/secp256k1_order.cc
@@ -20,6 +20,7 @@
#include
+#include
#include
#include "./ops_gmp_ff.h"
@@ -94,24 +95,24 @@ void scl::details::FieldInvert(Elem& out) {
}
template <>
-bool scl::details::FieldEqual(const Elem& first, const Elem& second) {
- return CompareValues(PTR(first), PTR(second)) == 0;
+bool scl::details::FieldEqual(const Elem& in1, const Elem& in2) {
+ return CompareValues(PTR(in1), PTR(in2)) == 0;
}
template <>
-void scl::details::FieldFromBytes(Elem& out, const unsigned char* src) {
- ValueFromBytes(PTR(out), src);
+void scl::details::FieldFromBytes(Elem& dest, const unsigned char* src) {
+ ValueFromBytes(PTR(dest), src);
}
template <>
-std::string scl::details::FieldToString(const Elem& element) {
- return ToString(PTR(element), kPrime, kMontyN);
+std::string scl::details::FieldToString(const Elem& in) {
+ return ToString(PTR(in), kPrime, kMontyN);
}
template <>
-void scl::details::FieldFromString(Elem& out, const std::string& str) {
+void scl::details::FieldFromString(Elem& out, const std::string& src) {
out = {0};
- FromString(PTR(out), kPrime, str);
+ FromString(PTR(out), kPrime, src);
}
std::size_t scl::SCL_FF_Extras::HigestSetBit(
@@ -127,4 +128,16 @@ bool scl::SCL_FF_Extras::TestBit(const scl::FF& element,
return ((element.mValue[limb] >> limb_pos) & 1) == 1;
}
+scl::FF scl::SCL_FF_Extras::FromMonty(
+ const scl::FF& element) {
+ mp_limb_t padded[2 * NUM_LIMBS] = {0};
+ SCL_COPY(padded, PTR(element.mValue), NUM_LIMBS);
+ details::Redc(padded, kPrime, kMontyN);
+
+ scl::FF r;
+ SCL_COPY(PTR(r.mValue), padded, NUM_LIMBS);
+
+ return r;
+}
+
#undef ONE
diff --git a/src/scl/math/str.cc b/src/scl/math/str.cc
index 8e6c227..32bdb74 100644
--- a/src/scl/math/str.cc
+++ b/src/scl/math/str.cc
@@ -24,8 +24,9 @@
template <>
std::string scl::details::ToHexString(const __uint128_t& v) {
+ std::string str;
if (v == 0) {
- return "0";
+ str = "0";
} else {
std::stringstream ss;
auto top = static_cast(v >> 64);
@@ -35,6 +36,7 @@ std::string scl::details::ToHexString(const __uint128_t& v) {
ss << top;
}
ss << bot;
- return ss.str();
+ str = ss.str();
}
+ return str;
}
diff --git a/src/scl/net/config.cc b/src/scl/net/config.cc
index bddc0bd..356752e 100644
--- a/src/scl/net/config.cc
+++ b/src/scl/net/config.cc
@@ -26,15 +26,22 @@
#include
#include
-static inline void ValidateIdAndSize(unsigned id, std::size_t n) {
- if (!n) {
+namespace {
+
+void ValidateIdAndSize(unsigned id, std::size_t n) {
+ if (n == 0) {
throw std::invalid_argument("n cannot be zero");
- } else if (n <= id) {
+ }
+
+ if (n <= id) {
throw std::invalid_argument("invalid id");
}
}
-scl::NetworkConfig scl::NetworkConfig::Load(unsigned id, std::string filename) {
+} // namespace
+
+scl::NetworkConfig scl::NetworkConfig::Load(int id,
+ const std::string& filename) {
std::ifstream file(filename);
if (!file.is_open()) {
@@ -45,14 +52,17 @@ scl::NetworkConfig scl::NetworkConfig::Load(unsigned id, std::string filename) {
std::vector info;
while (std::getline(file, line)) {
- auto a = line.find(',');
- auto b = line.rfind(',');
+ auto a_ = line.find(',');
+ auto b_ = line.rfind(',');
- if (a == std::string::npos || a == b) {
+ if (a_ == std::string::npos || a_ == b_) {
throw std::invalid_argument("invalid entry in config file");
}
- auto id = (unsigned)std::stoul(std::string(line.begin(), line.begin() + a));
+ auto a = static_cast(a_);
+ auto b = static_cast(b_);
+
+ auto id = std::stoi(std::string(line.begin(), line.begin() + a));
auto hostname = std::string(line.begin() + a + 1, line.begin() + b);
auto port = std::stoi(std::string(line.begin() + b + 1, line.end()));
info.emplace_back(Party{id, hostname, port});
@@ -63,14 +73,14 @@ scl::NetworkConfig scl::NetworkConfig::Load(unsigned id, std::string filename) {
return NetworkConfig(id, info);
}
-scl::NetworkConfig scl::NetworkConfig::Localhost(unsigned id, std::size_t n,
+scl::NetworkConfig scl::NetworkConfig::Localhost(int id, int size,
int port_base) {
- ValidateIdAndSize(id, n);
+ ValidateIdAndSize(id, size);
std::vector info;
- for (std::size_t i = 0; i < n; ++i) {
+ for (int i = 0; i < size; ++i) {
int port = port_base + i;
- info.emplace_back(Party{(unsigned)i, "127.0.0.1", port});
+ info.emplace_back(Party{i, "127.0.0.1", port});
}
return NetworkConfig(id, info);
@@ -93,13 +103,13 @@ std::string scl::NetworkConfig::ToString() const {
void scl::NetworkConfig::Validate() {
auto n = NetworkSize();
- if (Id() >= n) {
+ if (static_cast(Id()) >= n) {
throw std::invalid_argument("my ID is invalid in config");
}
for (std::size_t i = 0; i < n; ++i) {
auto pi = mParties[i];
- if (pi.id >= n) {
+ if (static_cast(pi.id) >= n) {
throw std::invalid_argument("invalid ID in config");
}
for (std::size_t j = i + 1; j < n; ++j) {
diff --git a/src/scl/net/discovery/client.cc b/src/scl/net/discovery/client.cc
index 468f0f2..baf6b55 100644
--- a/src/scl/net/discovery/client.cc
+++ b/src/scl/net/discovery/client.cc
@@ -27,7 +27,7 @@
using Client = scl::DiscoveryClient;
-scl::NetworkConfig Client::Run(unsigned id, int port) {
+scl::NetworkConfig Client::Run(int id, int port) const {
auto socket = scl::details::ConnectAsClient(mHostname, mPort);
std::shared_ptr server =
std::make_shared(socket);
@@ -37,13 +37,15 @@ scl::NetworkConfig Client::Run(unsigned id, int port) {
}
Client::ReceiveNetworkConfig Client::SendIdAndPort::Run(
- std::shared_ptr ctx) {
+ const std::shared_ptr& ctx) const {
ctx->Send(mId);
ctx->Send(mPort);
return Client::ReceiveNetworkConfig{mId};
}
-static inline std::string ReceiveHostname(std::shared_ptr ctx) {
+namespace {
+
+std::string ReceiveHostname(const std::shared_ptr& ctx) {
std::size_t len;
ctx->Recv(len);
auto buf = std::make_unique(len);
@@ -51,8 +53,10 @@ static inline std::string ReceiveHostname(std::shared_ptr ctx) {
return std::string(buf.get(), buf.get() + len);
}
+} // namespace
+
scl::NetworkConfig Client::ReceiveNetworkConfig::Finalize(
- std::shared_ptr ctx) {
+ const std::shared_ptr& ctx) const {
std::size_t number_of_parties;
ctx->Recv(number_of_parties);
diff --git a/src/scl/net/discovery/server.cc b/src/scl/net/discovery/server.cc
index ec1c3ec..6d3ef6d 100644
--- a/src/scl/net/discovery/server.cc
+++ b/src/scl/net/discovery/server.cc
@@ -20,6 +20,7 @@
#include "scl/net/discovery/server.h"
+#include
#include
#include
@@ -28,14 +29,15 @@
using Server = scl::DiscoveryServer;
-scl::NetworkConfig Server::Run(const scl::Party& me) {
+scl::NetworkConfig Server::Run(const scl::Party& me) const {
// one of the parties is us, which we do not connect to.
- auto ssock = scl::details::CreateServerSocket(mPort, mNumberOfParties - 1);
+ int backlog = static_cast(mNumberOfParties - 1);
+ auto ssock = scl::details::CreateServerSocket(mPort, backlog);
std::vector> channels;
std::vector hostnames;
for (std::size_t i = 0; i < mNumberOfParties; ++i) {
- if (i == me.id) {
+ if (i == static_cast(me.id)) {
channels.emplace_back(nullptr);
hostnames.emplace_back(me.hostname);
} else {
@@ -53,15 +55,15 @@ scl::NetworkConfig Server::Run(const scl::Party& me) {
}
Server::SendNetworkConfig Server::CollectIdsAndPorts::Run(Server::Ctx& ctx) {
- auto my_id = ctx.me.id;
+ auto my_id = static_cast(ctx.me.id);
std::vector parties(mHostnames.size());
parties[my_id] = ctx.me;
for (std::size_t i = 0; i < mHostnames.size(); ++i) {
if (my_id != i) {
- unsigned id;
+ int id;
ctx.network.Party(i)->Recv(id);
- if (id >= parties.size()) {
+ if (static_cast(id) >= parties.size()) {
throw std::logic_error("received invalid party ID");
}
int port;
@@ -74,7 +76,9 @@ Server::SendNetworkConfig Server::CollectIdsAndPorts::Run(Server::Ctx& ctx) {
return Server::SendNetworkConfig(cfg);
}
-static inline void SendHostname(scl::Channel* channel, std::string hostname) {
+namespace {
+
+void SendHostname(scl::Channel* channel, const std::string& hostname) {
std::size_t len = hostname.size();
const unsigned char* ptr =
reinterpret_cast(hostname.c_str());
@@ -83,8 +87,7 @@ static inline void SendHostname(scl::Channel* channel, std::string hostname) {
channel->Send(ptr, len);
}
-static inline void SendConfig(scl::Channel* channel,
- const scl::NetworkConfig& config) {
+void SendConfig(scl::Channel* channel, const scl::NetworkConfig& config) {
channel->Send(config.NetworkSize());
for (std::size_t i = 0; i < config.NetworkSize(); ++i) {
auto party = config.Parties()[i];
@@ -94,12 +97,15 @@ static inline void SendConfig(scl::Channel* channel,
}
}
+} // namespace
scl::NetworkConfig Server::SendNetworkConfig::Finalize(Server::Ctx& ctx) {
std::size_t network_size = mConfig.NetworkSize();
for (std::size_t i = 0; i < network_size; ++i) {
- if (i == mConfig.Id()) continue;
+ if (i == static_cast(mConfig.Id())) {
+ continue;
+ }
- auto channel = ctx.network.Party(i);
+ auto* channel = ctx.network.Party(i);
SendConfig(channel, mConfig);
}
diff --git a/src/scl/net/mem_channel.cc b/src/scl/net/mem_channel.cc
index 654b6bb..72f186c 100644
--- a/src/scl/net/mem_channel.cc
+++ b/src/scl/net/mem_channel.cc
@@ -22,11 +22,14 @@
#include
+// used to silence narrowing conversion errors. x will have type std::size_t
+#define DIFF_T(x) static_cast::difference_type>((x))
+
void scl::InMemoryChannel::Send(const unsigned char* src, std::size_t n) {
mOut->PushBack(std::vector(src, src + n));
}
-int scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) {
+std::size_t scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) {
std::size_t rem = n;
// if there's any leftovers from previous calls to recv, then we retrieve
@@ -34,10 +37,10 @@ int scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) {
const auto leftovers = mOverflow.size();
if (leftovers > 0) {
const auto to_copy = leftovers > rem ? rem : leftovers;
- auto data = mOverflow.data();
+ auto* data = mOverflow.data();
std::memcpy(dst, data, to_copy);
rem -= to_copy;
- mOverflow = std::vector(mOverflow.begin() + to_copy,
+ mOverflow = std::vector(mOverflow.begin() + DIFF_T(to_copy),
mOverflow.end());
}
@@ -53,8 +56,8 @@ int scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) {
const auto leftovers = data.size() - to_copy;
const auto old_size = mOverflow.size();
mOverflow.reserve(old_size + leftovers);
- mOverflow.insert(mOverflow.begin() + old_size, data.begin() + to_copy,
- data.end());
+ mOverflow.insert(mOverflow.begin() + DIFF_T(old_size),
+ data.begin() + DIFF_T(to_copy), data.end());
}
}
diff --git a/src/scl/net/network.cc b/src/scl/net/network.cc
index 239b23c..c7e0272 100644
--- a/src/scl/net/network.cc
+++ b/src/scl/net/network.cc
@@ -37,9 +37,9 @@ void scl::details::SCL_AcceptConnections(
// Act as server for all clients with an ID strictly greater than ours.
auto my_id = config.Id();
auto n = config.NetworkSize() - my_id - 1;
- if (n) {
+ if (n > 0) {
auto port = config.GetParty(my_id).port;
- int ssock = scl::details::CreateServerSocket(port, n);
+ int ssock = scl::details::CreateServerSocket(port, static_cast(n));
for (std::size_t i = config.Id() + 1; i < config.NetworkSize(); ++i) {
auto ac = scl::details::AcceptConnection(ssock);
std::shared_ptr channel =
diff --git a/src/scl/net/tcp_channel.cc b/src/scl/net/tcp_channel.cc
index 20594e2..e54135c 100644
--- a/src/scl/net/tcp_channel.cc
+++ b/src/scl/net/tcp_channel.cc
@@ -20,10 +20,15 @@
#include "scl/net/tcp_channel.h"
+#include
+#include
+
#include "scl/net/tcp_utils.h"
void scl::TcpChannel::Close() {
- if (!mAlive) return;
+ if (!mAlive) {
+ return;
+ }
const auto err = scl::details::CloseSocket(mSocket);
if (err < 0) {
@@ -46,13 +51,13 @@ void scl::TcpChannel::Send(const unsigned char* src, std::size_t n) {
}
}
-int scl::TcpChannel::Recv(unsigned char* dst, std::size_t n) {
+std::size_t scl::TcpChannel::Recv(unsigned char* dst, std::size_t n) {
std::size_t rem = n;
std::size_t offset = 0;
while (rem > 0) {
auto recv = scl::details::ReadFromSocket(mSocket, dst + offset, rem);
- if (!recv) {
+ if (recv == 0) {
break;
}
if (recv < 0) {
@@ -65,3 +70,17 @@ int scl::TcpChannel::Recv(unsigned char* dst, std::size_t n) {
return n - rem;
}
+
+bool scl::TcpChannel::HasData() {
+ struct pollfd fds {
+ mSocket, POLLIN, 0
+ };
+
+ auto r = poll(&fds, 1, 0);
+
+ if (r < 0) {
+ SCL_THROW_SYS_ERROR("hasData failed");
+ }
+
+ return r > 0 && fds.revents == POLLIN;
+}
diff --git a/src/scl/net/tcp_utils.cc b/src/scl/net/tcp_utils.cc
index 7105fe2..bec9b62 100644
--- a/src/scl/net/tcp_utils.cc
+++ b/src/scl/net/tcp_utils.cc
@@ -76,19 +76,19 @@ scl::details::AcceptedConnection scl::details::AcceptConnection(
::accept(server_socket, ac.socket_info.get(), (socklen_t*)&addrsize);
if (ac.socket < 0) {
SCL_THROW_SYS_ERROR("could not accept connection");
- } else {
- return ac;
}
+
+ return ac;
}
std::string scl::details::GetAddress(
- scl::details::AcceptedConnection connection) {
+ const scl::details::AcceptedConnection& connection) {
struct sockaddr_in* s =
reinterpret_cast(connection.socket_info.get());
return inet_ntoa(s->sin_addr);
}
-int scl::details::ConnectAsClient(std::string hostname, int port) {
+int scl::details::ConnectAsClient(const std::string& hostname, int port) {
using namespace std::chrono_literals;
int sock = ::socket(AF_INET, SOCK_STREAM, 0);
@@ -101,26 +101,30 @@ int scl::details::ConnectAsClient(std::string hostname, int port) {
addr.sin_port = ::htons(port);
int err = ::inet_pton(AF_INET, hostname.c_str(), &(addr.sin_addr));
+
if (err == 0) {
throw std::runtime_error("invalid hostname");
- } else if (err < 0) {
+ }
+
+ if (err < 0) {
SCL_THROW_SYS_ERROR("invalid address family");
}
- while (::connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0)
+ while (::connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
std::this_thread::sleep_for(300ms);
+ }
return sock;
}
int scl::details::CloseSocket(int socket) { return ::close(socket); }
-int scl::details::ReadFromSocket(int socket, unsigned char* dst,
- std::size_t n) {
+ssize_t scl::details::ReadFromSocket(int socket, unsigned char* dst,
+ std::size_t n) {
return ::read(socket, dst, n);
}
-int scl::details::WriteToSocket(int socket, const unsigned char* src,
- std::size_t n) {
+ssize_t scl::details::WriteToSocket(int socket, const unsigned char* src,
+ std::size_t n) {
return ::write(socket, src, n);
}
diff --git a/src/scl/prg.cc b/src/scl/prg.cc
index 1a172d9..d08150f 100644
--- a/src/scl/prg.cc
+++ b/src/scl/prg.cc
@@ -30,19 +30,19 @@ using block_t = __m128i;
using std::size_t;
-#define DO_ENC_BLOCK(m, k) \
- do { \
- m = _mm_xor_si128(m, k[0]); \
- m = _mm_aesenc_si128(m, k[1]); \
- m = _mm_aesenc_si128(m, k[2]); \
- m = _mm_aesenc_si128(m, k[3]); \
- m = _mm_aesenc_si128(m, k[4]); \
- m = _mm_aesenc_si128(m, k[5]); \
- m = _mm_aesenc_si128(m, k[6]); \
- m = _mm_aesenc_si128(m, k[7]); \
- m = _mm_aesenc_si128(m, k[8]); \
- m = _mm_aesenc_si128(m, k[9]); \
- m = _mm_aesenclast_si128(m, k[10]); \
+#define DO_ENC_BLOCK(m, k) \
+ do { \
+ (m) = _mm_xor_si128(m, (k)[0]); \
+ (m) = _mm_aesenc_si128(m, (k)[1]); \
+ (m) = _mm_aesenc_si128(m, (k)[2]); \
+ (m) = _mm_aesenc_si128(m, (k)[3]); \
+ (m) = _mm_aesenc_si128(m, (k)[4]); \
+ (m) = _mm_aesenc_si128(m, (k)[5]); \
+ (m) = _mm_aesenc_si128(m, (k)[6]); \
+ (m) = _mm_aesenc_si128(m, (k)[7]); \
+ (m) = _mm_aesenc_si128(m, (k)[8]); \
+ (m) = _mm_aesenc_si128(m, (k)[9]); \
+ (m) = _mm_aesenclast_si128(m, (k)[10]); \
} while (0)
#define AES_128_key_exp(k, rcon) \
@@ -97,18 +97,24 @@ static inline auto create_mask(const long counter) {
}
void scl::PRG::Next(byte_t* dest, size_t nbytes) {
- if (!nbytes) return;
+ if (nbytes == 0) {
+ return;
+ }
size_t nblocks = nbytes / BlockSize();
- if (nbytes % BlockSize()) nblocks++;
+ if ((nbytes % BlockSize()) != 0) {
+ nblocks++;
+ }
block_t mask = create_mask(mCounter);
byte_t* out = (byte_t*)malloc(nblocks * BlockSize());
byte_t* p = out;
// LCOV_EXCL_START
- if (!out) throw std::runtime_error("Could not allocate memory for PRG.");
+ if (out == nullptr) {
+ throw std::runtime_error("Could not allocate memory for PRG.");
+ }
// LCOV_EXCL_STOP
for (size_t i = 0; i < nblocks; i++) {
diff --git a/test/scl/gf7.cc b/test/scl/gf7.cc
index 176ffdf..7148f3d 100644
--- a/test/scl/gf7.cc
+++ b/test/scl/gf7.cc
@@ -38,10 +38,11 @@ void scl::details::FieldAdd(unsigned char& out, const unsigned char& op) {
template <>
void scl::details::FieldSubtract(unsigned char& out,
const unsigned char& op) {
- if (out < op)
+ if (out < op) {
out = 7 + out - op;
- else
+ } else {
out = out - op;
+ }
}
template <>
@@ -95,8 +96,8 @@ void scl::details::FieldFromBytes(unsigned char& dest,
}
template <>
-std::string scl::details::FieldToString(const unsigned char& value) {
+std::string scl::details::FieldToString(const unsigned char& in) {
std::stringstream ss;
- ss << (int)value;
+ ss << (int)in;
return ss.str();
}
diff --git a/test/scl/gf7.h b/test/scl/gf7.h
index 0c68f94..ec9e9c9 100644
--- a/test/scl/gf7.h
+++ b/test/scl/gf7.h
@@ -18,8 +18,8 @@
* along with this program. If not, see .
*/
-#ifndef _SCUTIL_MATH_GF7_H
-#define _SCUTIL_MATH_GF7_H
+#ifndef TEST_SCL_GF7_H
+#define TEST_SCL_GF7_H
#include
#include
@@ -37,4 +37,4 @@ struct GF7 {
} // namespace details
} // namespace scl
-#endif /* _SCUTIL_MATH_GF7_H */
+#endif /* TEST_SCL_GF7_H */
diff --git a/test/scl/math/test_ff.cc b/test/scl/math/test_ff.cc
index 41b932a..2f22384 100644
--- a/test/scl/math/test_ff.cc
+++ b/test/scl/math/test_ff.cc
@@ -31,17 +31,21 @@ using GF7 = scl::FF;
#ifdef SCL_ENABLE_EC_TESTS
using Secp256k1_Field = scl::FF;
+using Secp256k1_Order = scl::FF;
#endif
template
T RandomNonZero(scl::PRG& prg) {
auto a = T::Random(prg);
for (std::size_t i = 0; i < 10; ++i) {
- if (a == T::Zero()) a = T::Random(prg);
+ if (a == T::Zero()) {
+ a = T::Random(prg);
+ }
break;
}
- if (a == T::Zero())
+ if (a == T::Zero()) {
throw std::logic_error("could not generate a non-zero random value");
+ }
return a;
}
@@ -59,7 +63,7 @@ GF7 RandomNonZero(scl::PRG& prg) {
#define REPEAT for (std::size_t i = 0; i < 50; ++i)
#ifdef SCL_ENABLE_EC_TESTS
-#define ARG_LIST Mersenne61, Mersenne127, GF7, Secp256k1_Field
+#define ARG_LIST Mersenne61, Mersenne127, GF7, Secp256k1_Field, Secp256k1_Order
#else
#define ARG_LIST Mersenne61, Mersenne127, GF7
#endif
diff --git a/test/scl/math/test_mat.cc b/test/scl/math/test_mat.cc
index f622896..4410098 100644
--- a/test/scl/math/test_mat.cc
+++ b/test/scl/math/test_mat.cc
@@ -26,7 +26,7 @@
using F = scl::Fp<61>;
using Mat = scl::Mat;
-inline void Populate(Mat& m, unsigned values[]) {
+inline void Populate(Mat& m, const int* values) {
for (std::size_t i = 0; i < m.Rows(); i++) {
for (std::size_t j = 0; j < m.Cols(); j++) {
m(i, j) = F(values[i * m.Cols() + j]);
@@ -36,10 +36,10 @@ inline void Populate(Mat& m, unsigned values[]) {
TEST_CASE("Matrix", "[math]") {
Mat m0(2, 2);
- unsigned v0[] = {1, 2, 5, 6};
+ int v0[] = {1, 2, 5, 6};
Populate(m0, v0);
Mat m1(2, 2);
- unsigned v1[] = {4, 3, 2, 1};
+ int v1[] = {4, 3, 2, 1};
Populate(m1, v1);
REQUIRE(!m0.Equals(m1));
@@ -59,7 +59,7 @@ TEST_CASE("Matrix", "[math]") {
SECTION("ToString") {
Mat m(3, 2);
- unsigned v[] = {1, 2, 44444, 5, 6, 7};
+ int v[] = {1, 2, 44444, 5, 6, 7};
Populate(m, v);
std::string expected =
"\n"
@@ -133,16 +133,16 @@ TEST_CASE("Matrix", "[math]") {
REQUIRE(m2(1, 1) == F(21));
Mat m3(2, 10);
- unsigned v3[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 0,
- 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
+ int v3[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 0,
+ 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
Populate(m3, v3);
auto m5 = m0.Multiply(m3);
REQUIRE(m5.Rows() == 2);
REQUIRE(m5.Cols() == 10);
Mat m4(2, 10);
- unsigned v4[] = {23, 26, 29, 32, 35, 38, 41, 44, 47, 40,
- 71, 82, 93, 104, 115, 126, 137, 148, 159, 120};
+ int v4[] = {23, 26, 29, 32, 35, 38, 41, 44, 47, 40,
+ 71, 82, 93, 104, 115, 126, 137, 148, 159, 120};
Populate(m4, v4);
REQUIRE(m5.Equals(m4));
@@ -167,7 +167,7 @@ TEST_CASE("Matrix", "[math]") {
SECTION("Transpose") {
Mat m3(2, 3);
- unsigned v3[] = {1, 2, 3, 11, 12, 13};
+ int v3[] = {1, 2, 3, 11, 12, 13};
Populate(m3, v3);
auto m4 = m3.Transpose();
REQUIRE(m4.Rows() == m3.Cols());
@@ -244,10 +244,11 @@ TEST_CASE("Matrix", "[math]") {
bool good = true;
for (std::size_t i = 0; i < 10; ++i) {
for (std::size_t j = 0; j < 10; ++j) {
- if (i == j)
+ if (i == j) {
good &= A(i, j) == F(1);
- else
+ } else {
good &= A(i, j) == F(0);
+ }
}
}
REQUIRE(good);
diff --git a/test/scl/math/test_secp256k1.cc b/test/scl/math/test_secp256k1.cc
index 60cf847..d9053b5 100644
--- a/test/scl/math/test_secp256k1.cc
+++ b/test/scl/math/test_secp256k1.cc
@@ -25,6 +25,7 @@
#include "scl/math/curves/secp256k1.h"
#include "scl/math/ec_ops.h"
#include "scl/math/fp.h"
+#include "scl/math/number.h"
#include "scl/prg.h"
using Curve = scl::EC;
@@ -63,14 +64,17 @@ TEST_CASE("secp256k1_field", "[math]") {
}
SECTION("From affine") {
- auto g = Curve::FromAffine(
- Field::FromString("e47b4a1c2e13cf0e97c9adf5a645ce388e04317b7830401aabb4"
- "2e188c9883fa"), //
- Field::FromString("2aafa6e870684327ec92006e6c601a8b6e0fb9ff06ae120cb330"
- "a2eee86009ff") //
- );
+ auto x = Field::FromString(
+ "e47b4a1c2e13cf0e97c9adf5a645ce388e04317b7830401aabb42e188c9883fa");
+ auto y = Field::FromString(
+ "2aafa6e870684327ec92006e6c601a8b6e0fb9ff06ae120cb330a2eee86009ff");
+ auto g = Curve::FromAffine(x, y);
REQUIRE(!g.PointAtInfinity());
+ auto as_affine = g.ToAffine();
+ REQUIRE(as_affine[0] == x);
+ REQUIRE(as_affine[1] == y);
+
REQUIRE_THROWS_MATCHES(
Curve::FromAffine(Field(0), Field(0)), std::invalid_argument,
Catch::Matchers::Message("provided (x, y) not on curve"));
@@ -142,6 +146,8 @@ TEST_CASE("secp256k1_field", "[math]") {
}
}
+using Scalar = scl::FF;
+
static Curve RandomPoint(scl::PRG& prg) {
auto r = scl::Number::Random(100, prg);
return Curve::Generator() * r;
@@ -203,6 +209,38 @@ TEST_CASE("secp256k1", "[math]") {
REQUIRE((a + b).PointAtInfinity());
}
+ SECTION("Scalar mul order") {
+ auto a = RandomPoint(prg);
+ auto b = Scalar::FromString(
+ "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364140");
+ REQUIRE(b + Scalar::One() == Scalar::Zero());
+
+ auto c = a * b;
+ REQUIRE(!c.PointAtInfinity());
+
+ REQUIRE((c + a).PointAtInfinity());
+ }
+
+ SECTION("Scalar mul distributive") {
+ auto a = RandomPoint(prg);
+ auto b = Scalar::Random(prg);
+ auto c = Scalar::Random(prg);
+ REQUIRE((b + c) * a == b * a + c * a);
+ }
+
+ SECTION("Scalar mul assoc") {
+ auto a = Curve::Generator();
+
+ auto v = Scalar::FromString("03");
+ auto u = Scalar::FromString("02");
+ auto w = Scalar::FromString("06");
+
+ auto P = a * w;
+ auto Q = (a * v) * u;
+
+ REQUIRE(P == Q);
+ }
+
SECTION("negation exceptional") {
using CurveT = scl::details::Secp256k1;
CurveT::ValueType point = {Field(1), Field(0), Field(1)};
diff --git a/test/scl/math/test_vec.cc b/test/scl/math/test_vec.cc
index cf8496d..a126e39 100644
--- a/test/scl/math/test_vec.cc
+++ b/test/scl/math/test_vec.cc
@@ -134,6 +134,13 @@ TEST_CASE("Vector", "[math]") {
REQUIRE(r[2] != v0[2]);
}
+ SECTION("Range") {
+ auto v = Vec::Range(1, 4);
+ REQUIRE(v[0] == F(1));
+ REQUIRE(v[1] == F(2));
+ REQUIRE(v[2] == F(3));
+ }
+
SECTION("iterators") {
auto v2 = Vec{F(1), F(2), F(3)};
std::size_t i = 0;
@@ -147,4 +154,12 @@ TEST_CASE("Vector", "[math]") {
auto v3 = Vec(v2.begin(), v2.end());
REQUIRE(v3.Equals(v2));
}
+
+ SECTION("subvector") {
+ auto v = Vec{F(1), F(2), F(3), F(4)};
+ REQUIRE(v.SubVector(1, 2) == Vec{F(2)});
+ REQUIRE(v.SubVector(1, 3) == Vec{F(2), F(3)});
+ REQUIRE(v.SubVector(1, 1) == Vec{});
+ REQUIRE(v.SubVector(2) == Vec{F(1), F(2)});
+ }
}
diff --git a/test/scl/math/test_z2k.cc b/test/scl/math/test_z2k.cc
index 88ff42d..954d810 100644
--- a/test/scl/math/test_z2k.cc
+++ b/test/scl/math/test_z2k.cc
@@ -82,7 +82,7 @@ TEMPLATE_TEST_CASE("Z2k", "[math]", Z2k1, Z2k2) {
#define RANDOM_INVERTIBLE(var) \
TestType var; \
- while (var.Lsb() == 0) var = TestType::Random(prg)
+ while ((var).Lsb() == 0) (var) = TestType::Random(prg)
SECTION("inverses") {
RANDOM_INVERTIBLE(a);
diff --git a/test/scl/net/test_config.cc b/test/scl/net/test_config.cc
index 8e1dc08..c79a45c 100644
--- a/test/scl/net/test_config.cc
+++ b/test/scl/net/test_config.cc
@@ -26,7 +26,7 @@
TEST_CASE("Config", "[network]") {
SECTION("From file") {
- auto filename = SCL_TEST_DATA_DIR "3_parties.txt";
+ const auto* filename = SCL_TEST_DATA_DIR "3_parties.txt";
auto cfg = scl::NetworkConfig::Load(0, filename);
REQUIRE(cfg.NetworkSize() == 3);
@@ -41,22 +41,22 @@ TEST_CASE("Config", "[network]") {
}
SECTION("Invalid file") {
- auto invalid_empty = SCL_TEST_DATA_DIR "invalid_no_entries.txt";
+ const auto* invalid_empty = SCL_TEST_DATA_DIR "invalid_no_entries.txt";
REQUIRE_THROWS_MATCHES(scl::NetworkConfig::Load(0, invalid_empty),
std::invalid_argument,
Catch::Matchers::Message("n cannot be zero"));
- auto valid = SCL_TEST_DATA_DIR "3_parties.txt";
+ const auto* valid = SCL_TEST_DATA_DIR "3_parties.txt";
REQUIRE_THROWS_MATCHES(scl::NetworkConfig::Load(4, valid),
std::invalid_argument,
Catch::Matchers::Message("invalid id"));
- auto invalid_entry = SCL_TEST_DATA_DIR "invalid_entry.txt";
+ const auto* invalid_entry = SCL_TEST_DATA_DIR "invalid_entry.txt";
REQUIRE_THROWS_MATCHES(
scl::NetworkConfig::Load(0, invalid_entry), std::invalid_argument,
Catch::Matchers::Message("invalid entry in config file"));
- auto invalid_non_existing_file = "";
+ const auto* invalid_non_existing_file = "";
REQUIRE_THROWS_MATCHES(
scl::NetworkConfig::Load(0, invalid_non_existing_file),
std::invalid_argument, Catch::Matchers::Message("could not open file"));
diff --git a/test/scl/net/test_discover.cc b/test/scl/net/test_discover.cc
index f1ef4d4..7d7d09c 100644
--- a/test/scl/net/test_discover.cc
+++ b/test/scl/net/test_discover.cc
@@ -28,16 +28,20 @@
#include "scl/net/mem_channel.h"
#include "scl/net/network.h"
-static inline bool VerifyParty(scl::Party& party, unsigned id,
- std::string hostname, int port) {
+namespace {
+
+bool VerifyParty(scl::Party& party, int id, const std::string& hostname,
+ int port) {
return party.id == id && party.hostname == hostname && party.port == port;
}
-static inline bool PartyEquals(scl::Party& first, scl::Party& second) {
+bool PartyEquals(scl::Party& first, scl::Party& second) {
return first.id == second.id && first.hostname == second.hostname &&
first.port == second.port;
}
+} // namespace
+
TEST_CASE("Discovery Server", "[network]") {
SECTION("CollectIdsAndPorts") {
std::vector hostnames = {"1.2.3.4", "4.4.4.4", "127.0.0.1"};
@@ -127,15 +131,16 @@ TEST_CASE("Discovery Server", "[network]") {
}
}
-static inline void SendHostname(scl::Channel* channel, std::string hostname) {
+namespace {
+
+void SendHostname(scl::Channel* channel, const std::string& hostname) {
std::size_t length = hostname.size();
channel->Send(length);
channel->Send(reinterpret_cast(hostname.c_str()),
length);
}
-static inline void SendConfig(scl::Channel* channel,
- const scl::NetworkConfig& config) {
+void SendConfig(scl::Channel* channel, const scl::NetworkConfig& config) {
channel->Send(config.NetworkSize());
for (const auto& party : config.Parties()) {
channel->Send(party.id);
@@ -144,6 +149,8 @@ static inline void SendConfig(scl::Channel* channel,
}
}
+} // namespace
+
TEST_CASE("Discovery Client", "[network]") {
SECTION("SendIdAndPort") {
auto channels = scl::InMemoryChannel::CreatePaired();
@@ -196,7 +203,9 @@ TEST_CASE("Discovery", "[network]") {
Catch::Matchers::Message("number_of_parties exceeds max"));
}
- scl::NetworkConfig c0, c1, c2;
+ scl::NetworkConfig c0;
+ scl::NetworkConfig c1;
+ scl::NetworkConfig c2;
std::thread server([&c0]() {
Server srv(9999, 3);
diff --git a/test/scl/net/test_mem_channel.cc b/test/scl/net/test_mem_channel.cc
index 79859a7..2c01b89 100644
--- a/test/scl/net/test_mem_channel.cc
+++ b/test/scl/net/test_mem_channel.cc
@@ -26,14 +26,9 @@
#include "scl/math.h"
#include "scl/net/mem_channel.h"
#include "scl/prg.h"
+#include "util.h"
-static inline bool Eq(const unsigned char* p, const unsigned char* q, int n) {
- while (n-- > 0 && *p++ == *q++)
- ;
- return n < 0;
-}
-
-static inline void PrintBuf(const unsigned char* b, std::size_t n) {
+void PrintBuf(const unsigned char* b, std::size_t n) {
for (std::size_t i = 0; i < n; ++i) {
std::cout << (int)b[i] << " ";
}
@@ -52,9 +47,12 @@ TEST_CASE("InMemoryChannel", "[network]") {
SECTION("Send and receive") {
unsigned char data_out[200] = {0};
+ REQUIRE(!chl1->HasData());
chl0->Send(data_in, 200);
+ REQUIRE(!chl0->HasData());
+ REQUIRE(chl1->HasData());
chl1->Recv(data_out, 200);
- REQUIRE(Eq(data_in, data_out, 200));
+ REQUIRE(scl_tests::BufferEquals(data_in, data_out, 200));
}
chl0->Flush();
@@ -68,7 +66,7 @@ TEST_CASE("InMemoryChannel", "[network]") {
chl0->Send(data_in + 100, 100);
chl1->Recv(data_out, 200);
- REQUIRE(Eq(data_in, data_out, 200));
+ REQUIRE(scl_tests::BufferEquals(data_in, data_out, 200));
}
chl0->Flush();
@@ -81,7 +79,7 @@ TEST_CASE("InMemoryChannel", "[network]") {
chl1->Recv(data_out, 100);
chl1->Recv(data_out + 100, 100);
- REQUIRE(Eq(data_in, data_out, 200));
+ REQUIRE(scl_tests::BufferEquals(data_in, data_out, 200));
}
chl0->Flush();
@@ -105,9 +103,10 @@ TEST_CASE("InMemoryChannel", "[network]") {
scl::Channel* c1 = chl1.get();
std::vector data = {1, 2, 3, 4, 11111111};
c0->Send(data);
- std::vector recv(data.size());
+ std::vector recv;
c1->Recv(recv);
REQUIRE(data == recv);
+ REQUIRE(recv.size() == data.size());
}
using FF = scl::Fp<61>;
@@ -153,6 +152,6 @@ TEST_CASE("InMemoryChannel", "[network]") {
c->Recv(data_out + 10, 100);
c->Recv(data_out + 110, 90);
- REQUIRE(Eq(data_in, data_out, 200));
+ REQUIRE(scl_tests::BufferEquals(data_in, data_out, 200));
}
}
diff --git a/test/scl/net/test_network.cc b/test/scl/net/test_network.cc
index e5196bf..f23f6cb 100644
--- a/test/scl/net/test_network.cc
+++ b/test/scl/net/test_network.cc
@@ -71,7 +71,9 @@ TEST_CASE("Network", "[network]") {
}
SECTION("TCP") {
- scl::Network network0, network1, network2;
+ scl::Network network0;
+ scl::Network network1;
+ scl::Network network2;
std::thread t0([&]() {
network0 = scl::Network::Create(
diff --git a/test/scl/net/test_tcp_channel.cc b/test/scl/net/test_tcp_channel.cc
index 82cecce..0efa63a 100644
--- a/test/scl/net/test_tcp_channel.cc
+++ b/test/scl/net/test_tcp_channel.cc
@@ -30,11 +30,14 @@ TEST_CASE("TcpChannel", "[network]") {
SECTION("Connect and Close") {
auto port = scl_tests::GetPort();
- std::shared_ptr client, server;
+ std::shared_ptr client;
+ std::shared_ptr server;
+
std::thread clt([&]() {
int socket = scl::details::ConnectAsClient("0.0.0.0", port);
client = std::make_shared(socket);
});
+
std::thread srv([&]() {
int ssock = scl::details::CreateServerSocket(port, 1);
auto ac = scl::details::AcceptConnection(ssock);
@@ -58,7 +61,9 @@ TEST_CASE("TcpChannel", "[network]") {
SECTION("Send Receive") {
auto port = scl_tests::GetPort();
- std::shared_ptr client, server;
+ std::shared_ptr client;
+ std::shared_ptr server;
+
std::thread clt([&]() {
int socket = scl::details::ConnectAsClient("0.0.0.0", port);
client = std::make_shared(socket);
@@ -79,8 +84,12 @@ TEST_CASE("TcpChannel", "[network]") {
unsigned char recv[200] = {0};
prg.Next(send, 200);
+ REQUIRE(!server->HasData());
+
client->Send(send, 100);
client->Send(send + 100, 100);
+
+ REQUIRE(server->HasData());
server->Recv(recv, 20);
server->Recv(recv + 20, 180);
@@ -90,7 +99,9 @@ TEST_CASE("TcpChannel", "[network]") {
SECTION("Recv from closed") {
auto port = scl_tests::GetPort();
- std::shared_ptr client, server;
+ std::shared_ptr client;
+ std::shared_ptr server;
+
std::thread clt([&]() {
int socket = scl::details::ConnectAsClient("0.0.0.0", port);
client = std::make_shared(socket);
diff --git a/test/scl/net/test_threaded_sender.cc b/test/scl/net/test_threaded_sender.cc
index 8c241af..c23b432 100644
--- a/test/scl/net/test_threaded_sender.cc
+++ b/test/scl/net/test_threaded_sender.cc
@@ -32,7 +32,8 @@ TEST_CASE("ThreadedSender", "[network]") {
SECTION("Connect and send") {
auto port = scl_tests::GetPort();
- std::shared_ptr client, server;
+ std::shared_ptr client;
+ std::shared_ptr server;
std::thread clt([&]() {
int socket = scl::details::ConnectAsClient("0.0.0.0", port);
@@ -54,9 +55,23 @@ TEST_CASE("ThreadedSender", "[network]") {
unsigned char recv[200] = {0};
prg.Next(send, 200);
+ REQUIRE(!server->HasData());
+
client->Send(send, 100);
client->Send(send + 100, 100);
+ // because the sender returns immediately, there might not be data
+ // available, so we will try a couple of times before failing.
+ {
+ using namespace std::chrono_literals;
+ auto c = 0;
+ while (c < 10 && !server->HasData()) {
+ std::this_thread::sleep_for(100ms);
+ c++;
+ }
+ }
+ REQUIRE(server->HasData());
+
server->Recv(recv, 20);
server->Recv(recv + 20, 180);
diff --git a/test/scl/net/util.cc b/test/scl/net/util.cc
index baf186b..6b1bbb7 100644
--- a/test/scl/net/util.cc
+++ b/test/scl/net/util.cc
@@ -26,7 +26,8 @@ int scl_tests::GetPort() { return test_port++; }
bool scl_tests::BufferEquals(const unsigned char *a, const unsigned char *b,
int n) {
- while (n-- > 0 && *a++ == *b++)
+ while (n-- > 0 && *a++ == *b++) {
;
+ }
return n < 0;
}
diff --git a/test/scl/net/util.h b/test/scl/net/util.h
index 5ca9c9d..6fe5a34 100644
--- a/test/scl/net/util.h
+++ b/test/scl/net/util.h
@@ -18,10 +18,9 @@
* along with this program. If not, see .
*/
-#ifndef _TEST_SCL_NET_UTIL_H
-#define _TEST_SCL_NET_UTIL_H
+#ifndef TEST_SCL_NET_UTIL_H
+#define TEST_SCL_NET_UTIL_H
-#include
namespace scl_tests {
/**
@@ -48,4 +47,4 @@ bool BufferEquals(const unsigned char* a, const unsigned char* b, int n);
} // namespace scl_tests
-#endif /* _TEST_SCL_NET_UTIL_H */
+#endif // TEST_SCL_NET_UTIL_H
diff --git a/test/scl/p/test_simple.cc b/test/scl/p/test_simple.cc
index c0aaa67..663e3f4 100644
--- a/test/scl/p/test_simple.cc
+++ b/test/scl/p/test_simple.cc
@@ -46,7 +46,7 @@ class BeaverMulFinalize
public:
BeaverMulFinalize(Triple t) : mTriple(t){};
- FF Finalize(Context& ctx) {
+ FF Finalize(Context& ctx) const {
scl::Vec ed0(2);
scl::Vec ed1(2);
ctx.network.Party(0)->Recv(ed0);
@@ -56,10 +56,11 @@ class BeaverMulFinalize
auto d = ed0[1] + ed1[1];
if (ctx.id == 0) {
+ // constant addition
return e * d - e * mTriple.b - d * mTriple.a + mTriple.c;
- } else {
- return -e * mTriple.b - d * mTriple.a + mTriple.c;
}
+
+ return -e * mTriple.b - d * mTriple.a + mTriple.c;
};
private:
diff --git a/test/scl/ss/test_shamir.cc b/test/scl/ss/test_shamir.cc
index edee349..9a44801 100644
--- a/test/scl/ss/test_shamir.cc
+++ b/test/scl/ss/test_shamir.cc
@@ -19,6 +19,7 @@
*/
#include
+#include
#include "../gf7.h"
#include "scl/math.h"
@@ -29,97 +30,87 @@ TEST_CASE("Shamir", "[ss]") {
using FF = scl::Fp<61>;
using Vec = scl::Vec;
- scl::PRG prg;
+ const std::size_t t = 2;
- SECTION("Share") {
- auto secret = FF(123);
- Vec alphas = {FF(2), FF(5), FF(3)};
- auto share_poly = scl::details::CreateShamirSharePolynomial(secret, 2, prg);
- auto shares = scl::CreateShamirShares(share_poly, alphas);
-
- auto reconstructed =
- scl::ReconstructShamirPassive(shares, alphas, FF(0), 2);
- REQUIRE(reconstructed == secret);
+ SECTION("Reconstruct") {
+ scl::PRG prg;
+ scl::details::ShamirSSFactory factory(
+ t, prg, scl::details::SecurityLevel::PASSIVE);
+ auto intr = factory.GetInterpolator();
- auto some_share = scl::ReconstructShamirPassive(shares, alphas, FF(3), 2);
- REQUIRE(some_share == shares[2]);
- }
-
- SECTION("Passive") {
auto secret = FF(123);
- auto shares = scl::CreateShamirShares(secret, 4, 3, prg);
- auto reconstructed = scl::ReconstructShamirPassive(shares, 3);
- REQUIRE(reconstructed == secret);
+ auto shares = factory.Share(secret);
+ auto s = intr.Reconstruct(shares);
+ REQUIRE(s == secret);
REQUIRE_THROWS_MATCHES(
- scl::ReconstructShamirPassive(shares, 4), std::invalid_argument,
+ intr.Reconstruct(shares.SubVector(1)), std::invalid_argument,
Catch::Matchers::Message("not enough shares to reconstruct"));
-
- Vec alphas = {FF(1), FF(2), FF(3)};
- REQUIRE_THROWS_MATCHES(
- scl::ReconstructShamirPassive(shares, alphas, FF(0), 3),
- std::invalid_argument,
- Catch::Matchers::Message("not enough alphas to reconstruct"));
}
SECTION("Detection") {
- auto secret = FF(123);
- auto shares = scl::CreateShamirShares(secret, 7, 3, prg);
- auto reconstructed = scl::ReconstructShamir(shares, 3);
- REQUIRE(reconstructed == secret);
+ scl::PRG prg;
+ scl::details::ShamirSSFactory factory(
+ t, prg, scl::details::SecurityLevel::DETECT);
+ auto intr = scl::details::Reconstructor::Create(
+ t, scl::details::SecurityLevel::DETECT);
- auto shares0 = shares;
- shares0[2] = FF(4);
- REQUIRE_THROWS_MATCHES(
- scl::ReconstructShamir(shares0, 3), std::logic_error,
- Catch::Matchers::Message("error detected during reconstruction"));
+ auto secret = FF(555);
+ auto shares = factory.Share(secret);
+ REQUIRE(shares.Size() == 2 * t + 1);
+ REQUIRE(intr.Reconstruct(shares) == secret);
- auto shares1 = shares;
- shares1[6] = FF(3);
REQUIRE_THROWS_MATCHES(
- scl::ReconstructShamir(shares1, 3), std::logic_error,
- Catch::Matchers::Message("error detected during reconstruction"));
+ intr.Reconstruct(shares.SubVector(2)), std::invalid_argument,
+ Catch::Matchers::Message("not enough shares to reconstruct"));
- REQUIRE_THROWS_MATCHES(
- scl::ReconstructShamir(shares0, 4), std::invalid_argument,
- Catch::Matchers::Message(
- "not enough shares to reconstruct with error detection"));
- REQUIRE_THROWS_MATCHES(
- scl::ReconstructShamir(shares0, Vec{}, FF(0), 2), std::invalid_argument,
- Catch::Matchers::Message(
- "not enough alphas to reconstruct with error detection"));
+ auto ss = intr.ReconstructShare(shares, 2);
+ REQUIRE(ss == shares[2]);
+ REQUIRE(intr.Reconstruct(shares, 3) == intr.ReconstructShare(shares, 2));
}
- SECTION("Correction") {
+ SECTION("Robust") {
+ scl::PRG prg;
+ scl::details::ShamirSSFactory factory(
+ t, prg, scl::details::SecurityLevel::CORRECT);
+
// no errors
auto secret = FF(123);
- auto shares = scl::CreateShamirShares(secret, 7, 2, prg);
- auto reconstructed = scl::ReconstructShamirRobust(shares, 2);
- CHECK(reconstructed == secret);
+ auto shares = factory.Share(secret);
+ REQUIRE(shares.Size() == 3 * t + 1);
+ auto reconstructed = scl::details::ReconstructShamirRobust(shares, t);
+ REQUIRE(reconstructed == secret);
+
+ // can also reconstruct with an interpolator
+ auto intr = scl::details::Reconstructor::Create(
+ t, scl::details::SecurityLevel::CORRECT);
+ REQUIRE(intr.Reconstruct(shares) == secret);
// one error
shares[0] = FF(63212);
- auto reconstructed_1 = scl::ReconstructShamirRobust(shares, 2);
- CHECK(reconstructed_1 == secret);
+ auto reconstructed_1 = scl::details::ReconstructShamirRobust(shares, t);
+ REQUIRE(reconstructed_1 == secret);
// two errors
shares[2] = FF(63212211);
- auto reconstructed_2 = scl::ReconstructShamirRobust(shares, 2);
- CHECK(reconstructed_2 == secret);
+ auto reconstructed_2 = scl::details::ReconstructShamirRobust(shares, t);
+ REQUIRE(reconstructed_2 == secret);
// three errors -- that's one too many
shares[1] = FF(123);
REQUIRE_THROWS_MATCHES(
- scl::ReconstructShamirRobust(shares, 2), std::logic_error,
+ scl::details::ReconstructShamirRobust(shares, t), std::logic_error,
Catch::Matchers::Message("could not correct shares"));
REQUIRE_THROWS_MATCHES(
- scl::ReconstructShamirRobust(shares, 3), std::invalid_argument,
+ scl::details::ReconstructShamirRobust(shares, t + 1),
+ std::invalid_argument,
Catch::Matchers::Message(
"not enough shares to reconstruct with error correction"));
REQUIRE_THROWS_MATCHES(
- scl::ReconstructShamirRobust(shares, Vec{}, 2), std::invalid_argument,
+ scl::details::ReconstructShamirRobust(shares, Vec{}, t),
+ std::invalid_argument,
Catch::Matchers::Message(
"not enough alphas to reconstruct with error correction"));
}
@@ -135,9 +126,9 @@ TEST_CASE("BerlekampWelch", "[ss][math]") {
Vec as = {FF(0), FF(1), FF(2), FF(3), FF(4), FF(5), FF(6)};
Vec corrected = {FF(1), FF(6), FF(3), FF(6), FF(1), FF(2), FF(2)};
- auto pe = scl::ReconstructShamirRobust(bs, as, 2);
- auto p = pe[0];
- auto e = pe[1];
+ auto pe = scl::details::ReconstructShamirRobust(bs, as, 2);
+ auto p = std::get<0>(pe);
+ auto e = std::get<1>(pe);
// errors
REQUIRE(e.Evaluate(FF(1)) == FF{});
diff --git a/test/scl/test_hash.cc b/test/scl/test_hash.cc
index 31d1801..30eb57f 100644
--- a/test/scl/test_hash.cc
+++ b/test/scl/test_hash.cc
@@ -67,42 +67,54 @@ TEST_CASE("Hash", "[misc]") {
SECTION("SHA3-256 0xA3 x 200") {
unsigned char byte = 0xA3;
unsigned char buf[200];
- for (std::size_t i = 0; i < 200; ++i) buf[i] = byte;
+ for (std::size_t i = 0; i < 200; ++i) {
+ buf[i] = byte;
+ }
scl::Hash<256> hash0;
auto digest = hash0.Update(buf, 200).Finalize();
REQUIRE(digest == SHA3_256_0xa3_200_times);
scl::Hash<256> hash1;
- for (std::size_t i = 0; i < 200; ++i) hash1.Update(&byte, 1);
+ for (std::size_t i = 0; i < 200; ++i) {
+ hash1.Update(&byte, 1);
+ }
REQUIRE(hash1.Finalize() == SHA3_256_0xa3_200_times);
}
SECTION("SHA3-384 0xA3 x 200") {
unsigned char byte = 0xA3;
unsigned char buf[200];
- for (std::size_t i = 0; i < 200; ++i) buf[i] = byte;
+ for (std::size_t i = 0; i < 200; ++i) {
+ buf[i] = byte;
+ }
scl::Hash<384> hash0;
auto digest = hash0.Update(buf, 200).Finalize();
REQUIRE(digest == SHA3_384_0xa3_200_times);
scl::Hash<384> hash1;
- for (std::size_t i = 0; i < 200; ++i) hash1.Update(&byte, 1);
+ for (std::size_t i = 0; i < 200; ++i) {
+ hash1.Update(&byte, 1);
+ }
REQUIRE(hash1.Finalize() == SHA3_384_0xa3_200_times);
}
SECTION("SHA3-512 0xA3 x 200") {
unsigned char byte = 0xA3;
unsigned char buf[200];
- for (std::size_t i = 0; i < 200; ++i) buf[i] = byte;
+ for (std::size_t i = 0; i < 200; ++i) {
+ buf[i] = byte;
+ }
scl::Hash<512> hash0;
auto digest = hash0.Update(buf, 200).Finalize();
REQUIRE(digest == SHA3_512_0xa3_200_times);
scl::Hash<512> hash1;
- for (std::size_t i = 0; i < 200; ++i) hash1.Update(&byte, 1);
+ for (std::size_t i = 0; i < 200; ++i) {
+ hash1.Update(&byte, 1);
+ }
REQUIRE(hash1.Finalize() == SHA3_512_0xa3_200_times);
}
diff --git a/test/scl/test_prg.cc b/test/scl/test_prg.cc
index 5d8e843..e156824 100644
--- a/test/scl/test_prg.cc
+++ b/test/scl/test_prg.cc
@@ -24,10 +24,13 @@
inline bool BufferCmp(const unsigned char* b0, const unsigned char* b1,
unsigned len) {
- auto p0 = b0;
- auto p1 = b1;
- while (len-- > 0)
- if (*p0++ != *p1++) return false;
+ const auto* p0 = b0;
+ const auto* p1 = b1;
+ while (len-- > 0) {
+ if (*p0++ != *p1++) {
+ return false;
+ }
+ }
return true;
}
@@ -40,7 +43,7 @@ inline bool BufferLooksRandom(const unsigned char* p, unsigned len) {
bool all_in_interval = true;
for (std::size_t i = 0; i < 256; i++) {
- auto p = 100 * ((float)buckets[i] / len);
+ auto p = 100 * ((float)buckets[i] / (float)len);
all_in_interval &= p >= 0.2 || p <= 6.0;
}
return all_in_interval;
@@ -84,7 +87,9 @@ TEST_CASE("PRG", "[misc]") {
std::vector buffer(100);
prg.Next(buffer, 50);
bool last_is_zero = true;
- for (std::size_t i = 50; i < 100; i++) last_is_zero &= buffer[i] == 0;
+ for (std::size_t i = 50; i < 100; i++) {
+ last_is_zero &= buffer[i] == 0;
+ }
REQUIRE(last_is_zero);
REQUIRE_THROWS_MATCHES(