diff --git a/.clang-format b/.clang-format
new file mode 100644
index 0000000..7520c79
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,30 @@
+Language: Cpp
+BasedOnStyle: Google
+DerivePointerAlignment: false
+AllowShortFunctionsOnASingleLine: Empty
+BinPackArguments: false
+BinPackParameters: false
+
+# The include rules below structures includes as
+#
+# - STL headers (anything without an extension, tbp)
+# - Other headers (anything that ends with .h)
+# - External SCL headers (anything of the form )
+# - Internal SCL headers (anything of the form "scl/...")
+#
+# The only exception is when a .cc file includes a header file with the same
+# name at the same path.
+
+IncludeCategories:
+ - Regex: '^'
+ Priority: 2
+ SortPriority: 0
+ - Regex: '^<.*'
+ Priority: 1
+ SortPriority: 0
+ - Regex: '^scl/.*'
+ Priority: 5
+ SortPriority: 0
diff --git a/.clang-tidy b/.clang-tidy
new file mode 100644
index 0000000..7428941
--- /dev/null
+++ b/.clang-tidy
@@ -0,0 +1,27 @@
+Checks: '-*,bugprone-*,performance-*,readability-*,google-global-names-in-headers,cert-dcl59-cpp,-bugprone-easily-swappable-parameters,-readability-identifier-length,-readability-magic-numbers,-readability-function-cognitive-complexity,-readability-function-size'
+
+# Enabled checks:
+# - bugprone
+# - performance
+# - readability
+# - google-global-names-in-headers
+# - cert-dcl59-cpp
+#
+# Specific disabled checks
+#
+# bugprone-easily-swappable-parameters:
+# Doesn't make sense to exclude functions taking multiple ints in SCL.
+#
+# readability-identifier-length:
+# Short identifiers make sense.
+#
+# readability-magic-numbers:
+# Too strict.
+#
+# readability-function-cognitive-complexity
+# Catch2.
+#
+# readability-function-size
+# Catch2.
+
+AnalyzeTemporaryDtors: false
diff --git a/.github/workflows/Checks.yml b/.github/workflows/Checks.yml
index 05adccc..93089d7 100644
--- a/.github/workflows/Checks.yml
+++ b/.github/workflows/Checks.yml
@@ -9,7 +9,7 @@ on:
jobs:
documentation:
name: Documentation
- runs-on: ubuntu-latest
+ runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
@@ -39,11 +39,11 @@ jobs:
- uses: actions/checkout@v2
- name: Setup
- run: sudo apt-get install -y clang-format-12
+ run: sudo apt-get install -y clang-format
- name: Check
shell: bash
run: |
- find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n --style=Google {} \; &> checks.txt
+ find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n {} \; &> checks.txt
cat checks.txt
test ! -s checks.txt
diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml
index c52ab45..28c491c 100644
--- a/.github/workflows/Test.yml
+++ b/.github/workflows/Test.yml
@@ -12,7 +12,7 @@ env:
jobs:
build:
name: Coverage and Linting
- runs-on: ubuntu-latest
+ runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 25be133..5ac7c25 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -16,7 +16,7 @@
cmake_minimum_required( VERSION 3.14 )
-project( scl VERSION 3.0.0 DESCRIPTION "Secure Computation Library" )
+project( scl VERSION 4.0.0 DESCRIPTION "Secure Computation Library" )
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
@@ -35,10 +35,12 @@ set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -Wall -Wextra -pedantic -Werror -std=gnu++17")
set(SCL_SOURCE_FILES
- src/scl/prg.cc
- src/scl/hash.cc
+ src/scl/util/str.cc
+
+ src/scl/primitives/prg.cc
+ src/scl/primitives/sha3.cc
+ src/scl/primitives/sha256.cc
- src/scl/math/str.cc
src/scl/math/mersenne61.cc
src/scl/math/mersenne127.cc
@@ -87,8 +89,9 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug")
set(SCL_TEST_SOURCE_FILES
test/scl/main.cc
- test/scl/test_hash.cc
- test/scl/test_prg.cc
+ test/scl/primitives/test_prg.cc
+ test/scl/primitives/test_sha3.cc
+ test/scl/primitives/test_sha256.cc
test/scl/gf7.cc
test/scl/math/test_mersenne61.cc
@@ -102,6 +105,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug")
test/scl/ss/test_additive.cc
test/scl/ss/test_poly.cc
test/scl/ss/test_shamir.cc
+ test/scl/ss/test_feldman.cc
test/scl/net/util.cc
test/scl/net/test_config.cc
diff --git a/README.md b/README.md
index 820fd48..14121e7 100644
--- a/README.md
+++ b/README.md
@@ -62,3 +62,15 @@ inspiration.
SCL uses Doxygen for documentation. Run `./scripts/build_documentation.sh` to
generate the documentation. This is placed in the `doc/` folder. Documentation
uses `doxygen`, so make sure that's installed.
+
+# Citing
+
+I'd greatly appreciate any work that uses SCL include the below bibtex entry
+
+```
+@misc{secure-computation-library,
+ author = {Anders Dalskov},
+ title = {{SCL (Secure Computation Library)---utility library for prototyping MPC applications}},
+ howpublished = {\url{https://github.com/anderspkd/secure-computation-library}},
+}
+```
diff --git a/RELEASE.txt b/RELEASE.txt
index f56e41d..80120a4 100644
--- a/RELEASE.txt
+++ b/RELEASE.txt
@@ -1,3 +1,14 @@
+4.0: Shamir, Feldman, SHA-256
+- Refactor Shamir to allow caching of Lagrange coefficients
+- Add support for Feldman Secret Sharing
+- Add support for SHA-256
+- Add bibtex blob for citing SCL
+- Refactor interface for hash functions
+- Refactor interface for Shamir
+- bugs:
+ - Fix negation of 0 in Secp256k1::Field and Secp256k1::Order
+ - Make serialization and deserialization of curve points behave more sanely
+
3.0: More features, build changes
- Add method for returning a point as a pair of affine coordinates
- Add method to check if a channel has data available
diff --git a/examples/01_primitives.cc b/examples/01_primitives.cc
index 441e330..a99b3cc 100644
--- a/examples/01_primitives.cc
+++ b/examples/01_primitives.cc
@@ -42,5 +42,5 @@ int main() {
/* The DigestToString can be used to print a hex representation of a digest.
*/
- std::cout << scl::DigestToString(digest) << "\n";
+ std::cout << scl::details::DigestToString(digest) << "\n";
}
diff --git a/examples/02_finite_fields.cc b/examples/02_finite_fields.cc
index 891acc7..7951e45 100644
--- a/examples/02_finite_fields.cc
+++ b/examples/02_finite_fields.cc
@@ -18,10 +18,10 @@
* along with this program. If not, see .
*/
-#include
-
#include
+#include
+
int main() {
/* This defines a "Finite Field" with space for at least 32 bits of
* computation. At the moment, SCL supports two primes: One that is 61 bits
@@ -67,7 +67,7 @@ int main() {
std::cout << a << " ?= " << b << ": " << (a == b) << "\n";
std::cout << a << " ?= " << a << ": " << (a == a) << "\n";
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
/* Using a PRG (see the PRG example), we can generate random field elements.
*/
diff --git a/examples/03_secret_sharing.cc b/examples/03_secret_sharing.cc
index c5f9dcc..d98b553 100644
--- a/examples/03_secret_sharing.cc
+++ b/examples/03_secret_sharing.cc
@@ -18,16 +18,17 @@
* along with this program. If not, see .
*/
-#include
-#include
-
#include
#include
+#include
+#include
+
int main() {
using Fp = scl::Fp<32>;
using Vec = scl::Vec;
- scl::PRG prg;
+
+ auto prg = scl::PRG::Create();
/* We can easily create an additive secret sharing of some secret value:
*/
@@ -46,8 +47,8 @@ int main() {
* correction. Lets see error detection at work first
*/
- scl::details::ShamirSSFactory factory(
- 1, prg, scl::details::SecurityLevel::CORRECT);
+ auto factory =
+ scl::ShamirSSFactory::Create(1, prg, scl::SecurityLevel::CORRECT);
/* We create 4 shamir shares with a threshold of 1.
*/
auto shamir_shares = factory.Share(secret);
@@ -56,17 +57,15 @@ int main() {
/* Of course, these can be reconstructed. The second parameter is the
* threshold. This performs reconstruction with error detection.
*/
- auto recon = factory.GetInterpolator();
auto shamir_reconstructed =
- recon.Reconstruct(shamir_shares, scl::details::SecurityLevel::DETECT);
+ factory.Recover(shamir_shares, scl::SecurityLevel::DETECT);
std::cout << shamir_reconstructed << "\n";
/* If we introduce an error, then reconstruction fails
*/
shamir_shares[2] = Fp(123);
try {
- std::cout << recon.Reconstruct(shamir_shares,
- scl::details::SecurityLevel::DETECT)
+ std::cout << factory.Recover(shamir_shares, scl::SecurityLevel::DETECT)
<< "\n";
} catch (std::logic_error& e) {
std::cout << e.what() << "\n";
@@ -75,7 +74,7 @@ int main() {
/* On the other hand, we can use the robust reconstruction since the threshold
* is low enough. I.e., because 4 >= 3*1 + 1.
*/
- auto r = recon.Reconstruct(shamir_shares);
+ auto r = factory.Recover(shamir_shares);
std::cout << r << "\n";
/* With a bit of extra work, we can even learn which share had the error.
@@ -104,7 +103,7 @@ int main() {
*/
shamir_shares[1] = Fp(22);
try {
- recon.Reconstruct(shamir_shares);
+ factory.Recover(shamir_shares);
} catch (std::logic_error& e) {
std::cout << e.what() << "\n";
}
diff --git a/examples/04_networking.cc b/examples/04_networking.cc
index c37acc6..0ba6a82 100644
--- a/examples/04_networking.cc
+++ b/examples/04_networking.cc
@@ -18,11 +18,10 @@
* along with this program. If not, see .
*/
-#include
-
#include
-#include "scl/net/tcp_channel.h"
+#include
+#include
scl::NetworkConfig RunServer(int n) {
scl::DiscoveryServer server(n);
@@ -67,7 +66,7 @@ int main(int argc, char** argv) {
auto network = scl::Network::Create(config);
- for (std::size_t i = 0; i < 3; ++i) {
+ for (std::size_t i = 0; i < (std::size_t)n; ++i) {
// similar to the TCP channel example, send our ID to everyone:
network.Party(i)->Send(config.Id());
unsigned received_id;
diff --git a/include/scl/math/ec.h b/include/scl/math/ec.h
index caed8d4..057888a 100644
--- a/include/scl/math/ec.h
+++ b/include/scl/math/ec.h
@@ -26,7 +26,7 @@
#include "scl/math/ec_ops.h"
#include "scl/math/ff.h"
#include "scl/math/number.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
namespace scl {
@@ -44,6 +44,11 @@ class EC {
*/
using Field = FF;
+ /**
+ * @brief A large sub-group of this curve.
+ */
+ using Order = FF;
+
/**
* @brief The size of a curve point in bytes.
* @param compressed
@@ -62,7 +67,9 @@ class EC {
/**
* @brief A string indicating which curve this is.
*/
- constexpr static const char* Name() { return Curve::kName; };
+ constexpr static const char* Name() {
+ return Curve::kName;
+ };
/**
* @brief Get the generator of this curve.
@@ -96,7 +103,9 @@ class EC {
/**
* @brief Create a new point equal to the point at infinity.
*/
- explicit constexpr EC() { details::CurveSetPointAtInfinity(mValue); };
+ explicit constexpr EC() {
+ details::CurveSetPointAtInfinity(mValue);
+ };
/**
* @brief Add another EC point to this.
@@ -173,7 +182,7 @@ class EC {
* @param scalar the scalar
* @return this.
*/
- EC& operator*=(const FF& scalar) {
+ EC& operator*=(const Order& scalar) {
details::CurveScalarMultiply(mValue, scalar);
return *this;
};
@@ -195,8 +204,7 @@ class EC {
* @param scalar the scalar
* @return the point multiplied with the scalar.
*/
- friend EC operator*(const EC& point,
- const FF& scalar) {
+ friend EC operator*(const EC& point, const Order& scalar) {
EC copy(point);
return copy *= scalar;
};
diff --git a/include/scl/math/ec_ops.h b/include/scl/math/ec_ops.h
index 8a601f4..8849f75 100644
--- a/include/scl/math/ec_ops.h
+++ b/include/scl/math/ec_ops.h
@@ -56,7 +56,8 @@ void CurveSetGenerator(typename C::ValueType& out);
* @param y the y coordinate
*/
template
-void CurveSetAffine(typename C::ValueType& out, const FF& x,
+void CurveSetAffine(typename C::ValueType& out,
+ const FF& x,
const FF& y);
/**
@@ -139,7 +140,8 @@ void CurveFromBytes(typename C::ValueType& out, const unsigned char* src);
* @param compress whether to compress the point
*/
template
-void CurveToBytes(unsigned char* dest, const typename C::ValueType& in,
+void CurveToBytes(unsigned char* dest,
+ const typename C::ValueType& in,
bool compress);
/**
diff --git a/include/scl/math/ff.h b/include/scl/math/ff.h
index 17739a7..a1e6d53 100644
--- a/include/scl/math/ff.h
+++ b/include/scl/math/ff.h
@@ -27,7 +27,7 @@
#include "scl/math/ff_ops.h"
#include "scl/math/ring.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
namespace scl {
@@ -47,17 +47,23 @@ class FF final : details::RingBase> {
/**
* @brief Size in bytes of a field element.
*/
- constexpr static std::size_t ByteSize() { return Field::kByteSize; };
+ constexpr static std::size_t ByteSize() {
+ return Field::kByteSize;
+ };
/**
* @brief Actual bit size of an element.
*/
- constexpr static std::size_t BitSize() { return Field::kBitSize; };
+ constexpr static std::size_t BitSize() {
+ return Field::kBitSize;
+ };
/**
* @brief A short string representation of this field.
*/
- constexpr static const char* Name() { return Field::kName; };
+ constexpr static const char* Name() {
+ return Field::kName;
+ };
/**
* @brief Read a field element from a buffer.
@@ -166,7 +172,9 @@ class FF final : details::RingBase> {
* @param other the other element
* @return this set to this * other.Inverse().
*/
- FF& operator/=(const FF& other) { return operator*=(other.Inverse()); };
+ FF& operator/=(const FF& other) {
+ return operator*=(other.Inverse());
+ };
/**
* @brief Negates this element.
diff --git a/include/scl/math/ff_ops.h b/include/scl/math/ff_ops.h
index 8c293fd..0234bc4 100644
--- a/include/scl/math/ff_ops.h
+++ b/include/scl/math/ff_ops.h
@@ -26,8 +26,6 @@
#include
#include
-#include "scl/math/str.h"
-
namespace scl {
namespace details {
@@ -89,19 +87,14 @@ bool FieldEqual(const typename F::ValueType& in1,
/**
* @brief Convert a field element to bytes.
+ * @param dest the field element to convert
+ * @param src where to store the converted element
*
* For types that are trivially copyable, this function has a default
* implementation based on std::memcpy.
- *
- * @param dest the field element to convert
- * @param src where to store the converted element
*/
-template ,
- int> = 0>
-void FieldToBytes(unsigned char* dest, const typename F::ValueType& src) {
- std::memcpy(dest, &src, sizeof(typename F::ValueType));
-}
+template
+void FieldToBytes(unsigned char* dest, const typename F::ValueType& src);
/**
* @brief Convert the content of a buffer to a field element.
diff --git a/include/scl/math/mat.h b/include/scl/math/mat.h
index c027f7d..4f76f4d 100644
--- a/include/scl/math/mat.h
+++ b/include/scl/math/mat.h
@@ -30,7 +30,7 @@
#include
#include
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
namespace scl {
@@ -73,7 +73,8 @@ class Mat {
* @param xs a \p n length vector containing the x values to use
* @return a Vandermonde matrix.
*/
- static Mat Vandermonde(std::size_t n, std::size_t m,
+ static Mat Vandermonde(std::size_t n,
+ std::size_t m,
const std::vector& xs);
/**
@@ -119,7 +120,8 @@ class Mat {
* @param vec the elements of the matrix
* @return a Matrix.
*/
- static Mat FromVector(std::size_t n, std::size_t m,
+ static Mat FromVector(std::size_t n,
+ std::size_t m,
const std::vector& vec) {
if (vec.size() != n * m) {
throw std::invalid_argument("invalid dimensions");
@@ -167,12 +169,16 @@ class Mat {
/**
* @brief The number of rows of this matrix.
*/
- std::size_t Rows() const { return mRows; };
+ std::size_t Rows() const {
+ return mRows;
+ };
/**
* @brief The number of columns of this matrix.
*/
- std::size_t Cols() const { return mCols; };
+ std::size_t Cols() const {
+ return mCols;
+ };
/**
* @brief Provides mutable access to a matrix element.
@@ -313,7 +319,9 @@ class Mat {
/**
* @brief Check if this matrix is square.
*/
- bool IsSquare() const { return Rows() == Cols(); };
+ bool IsSquare() const {
+ return Rows() == Cols();
+ };
/**
* @brief Transpose this matrix.
@@ -385,7 +393,9 @@ class Mat {
/**
* @brief The size of a matrix when serialized in bytes.
*/
- std::size_t ByteSize() const { return Rows() * Cols() * T::ByteSize(); }
+ std::size_t ByteSize() const {
+ return Rows() * Cols() * T::ByteSize();
+ }
private:
Mat(std::size_t r, std::size_t c, std::vector v)
@@ -445,7 +455,8 @@ Mat Mat::Random(std::size_t n, std::size_t m, PRG& prg) {
}
template
-Mat Mat::Vandermonde(std::size_t n, std::size_t m,
+Mat Mat::Vandermonde(std::size_t n,
+ std::size_t m,
const std::vector& xs) {
if (xs.size() != n) {
throw std::invalid_argument("|xs| != number of rows");
diff --git a/include/scl/math/number.h b/include/scl/math/number.h
index 4b536ed..17a4761 100644
--- a/include/scl/math/number.h
+++ b/include/scl/math/number.h
@@ -21,12 +21,12 @@
#ifndef SCL_MATH_NUMBER_H
#define SCL_MATH_NUMBER_H
-#include
-
#include
#include
-#include "scl/prg.h"
+#include
+
+#include "scl/primitives/prg.h"
namespace scl {
@@ -99,17 +99,15 @@ class Number {
return *this;
};
- // This is used for all op-assign operator overloads below.
-#define SCL_OP_IMPL(op, arg) \
- *this = *this op(arg); \
- return *this
-
/**
* @brief In-place addition of two numbers.
* @param number the other number
* @return this
*/
- Number& operator+=(const Number& number) { SCL_OP_IMPL(+, number); };
+ Number& operator+=(const Number& number) {
+ *this = *this + number;
+ return *this;
+ };
/**
* @brief Add two numbers.
@@ -123,7 +121,10 @@ class Number {
* @param number the other number
* @return this.
*/
- Number& operator-=(const Number& number) { SCL_OP_IMPL(-, number); };
+ Number& operator-=(const Number& number) {
+ *this = *this - number;
+ return *this;
+ };
/**
* @brief Subtract two Numbers.
@@ -143,7 +144,10 @@ class Number {
* @param number the other Number
* @return this.
*/
- Number& operator*=(const Number& number) { SCL_OP_IMPL(*, number); };
+ Number& operator*=(const Number& number) {
+ *this = *this * number;
+ return *this;
+ };
/**
* @brief Multiply two Numbers.
@@ -157,7 +161,10 @@ class Number {
* @param number the other number
* @return this.
*/
- Number& operator/=(const Number& number) { SCL_OP_IMPL(/, number); };
+ Number& operator/=(const Number& number) {
+ *this = *this / number;
+ return *this;
+ };
/**
* @brief Divide two Numbers.
@@ -171,7 +178,10 @@ class Number {
* @param shift the amount to left shift
* @return this.
*/
- Number& operator<<=(int shift) { SCL_OP_IMPL(<<, shift); };
+ Number& operator<<=(int shift) {
+ *this = *this << shift;
+ return *this;
+ };
/**
* @brief Perform a left shift of a Number.
@@ -185,7 +195,10 @@ class Number {
* @param shift the amount to right shift
* @return this.
*/
- Number& operator>>=(int shift) { SCL_OP_IMPL(>>, shift); };
+ Number& operator>>=(int shift) {
+ *this = *this >> shift;
+ return *this;
+ };
/**
* @brief Perform a right shift of a Number.
@@ -199,7 +212,10 @@ class Number {
* @param number the number to xor this with
* @return \p this
*/
- Number& operator^=(const Number& number) { SCL_OP_IMPL(^, number); };
+ Number& operator^=(const Number& number) {
+ *this = *this ^ number;
+ return *this;
+ };
/**
* @brief Exclusive or of two numbers.
@@ -213,7 +229,10 @@ class Number {
* @param number
* @return
*/
- Number& operator|=(const Number& number) { SCL_OP_IMPL(|, number); };
+ Number& operator|=(const Number& number) {
+ *this = *this | number;
+ return *this;
+ };
/**
* @brief operator |
@@ -227,9 +246,10 @@ class Number {
* @param number
* @return
*/
- Number& operator&=(const Number& number) { SCL_OP_IMPL(&, number); };
-
-#undef SCL_OP_IMPL
+ Number& operator&=(const Number& number) {
+ *this = *this & number;
+ return *this;
+ };
/**
* @brief operator &
@@ -320,13 +340,17 @@ class Number {
* @brief Test if this Number is odd.
* @return true if this Number is odd.
*/
- bool Odd() const { return TestBit(0); };
+ bool Odd() const {
+ return TestBit(0);
+ };
/**
* @brief Test if this Number is even.
* @return true if this Number is even.
*/
- bool Even() const { return !Odd(); };
+ bool Even() const {
+ return !Odd();
+ };
/**
* @brief Return a string representation of this Number.
diff --git a/include/scl/math/ring.h b/include/scl/math/ring.h
index a14deef..4224588 100644
--- a/include/scl/math/ring.h
+++ b/include/scl/math/ring.h
@@ -35,7 +35,7 @@ struct RingBase {
/**
* @brief Add two elements and return their sum.
*/
- friend T operator+(const T &lhs, const T &rhs) {
+ friend T operator+(const T& lhs, const T& rhs) {
T temp(lhs);
return temp += rhs;
};
@@ -43,7 +43,7 @@ struct RingBase {
/**
* @brief Subtract two elements and return their difference.
*/
- friend T operator-(const T &lhs, const T &rhs) {
+ friend T operator-(const T& lhs, const T& rhs) {
T temp(lhs);
return temp -= rhs;
};
@@ -51,7 +51,7 @@ struct RingBase {
/**
* @brief Return the negation of an element.
*/
- friend T operator-(const T &elem) {
+ friend T operator-(const T& elem) {
T temp(elem);
return temp.Negate();
};
@@ -59,7 +59,7 @@ struct RingBase {
/**
* @brief Multiply two elements and return their product.
*/
- friend T operator*(const T &lhs, const T &rhs) {
+ friend T operator*(const T& lhs, const T& rhs) {
T temp(lhs);
return temp *= rhs;
};
@@ -67,7 +67,7 @@ struct RingBase {
/**
* @brief Divide two elements and return their quotient.
*/
- friend T operator/(const T &lhs, const T &rhs) {
+ friend T operator/(const T& lhs, const T& rhs) {
T temp(lhs);
return temp /= rhs;
};
@@ -75,17 +75,21 @@ struct RingBase {
/**
* @brief Compare two elements for equality.
*/
- friend bool operator==(const T &lhs, const T &rhs) { return lhs.Equal(rhs); };
+ friend bool operator==(const T& lhs, const T& rhs) {
+ return lhs.Equal(rhs);
+ };
/**
* @brief Compare two elements for inequality.
*/
- friend bool operator!=(const T &lhs, const T &rhs) { return !(lhs == rhs); };
+ friend bool operator!=(const T& lhs, const T& rhs) {
+ return !(lhs == rhs);
+ };
/**
* @brief Write a string representation of an element to a stream.
*/
- friend std::ostream &operator<<(std::ostream &os, const T &r) {
+ friend std::ostream& operator<<(std::ostream& os, const T& r) {
return os << r.ToString();
};
};
diff --git a/include/scl/math/vec.h b/include/scl/math/vec.h
index ab5e89a..f527b86 100644
--- a/include/scl/math/vec.h
+++ b/include/scl/math/vec.h
@@ -29,7 +29,7 @@
#include
#include "scl/math/mat.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
namespace scl {
@@ -41,8 +41,8 @@ namespace details {
* @param xe end of the first iterator
* @param yb start of the second iterator
*/
-template
-T UncheckedInnerProd(It xb, It xe, It yb) {
+template
+T UncheckedInnerProd(I0 xb, I0 xe, I1 yb) {
T v;
while (xb != xe) {
v += *xb++ * *yb++;
@@ -161,17 +161,23 @@ class Vec {
/**
* @brief The size of the Vec.
*/
- std::size_t Size() const { return mValues.size(); };
+ std::size_t Size() const {
+ return mValues.size();
+ };
/**
* @brief Mutable access to vector elements.
*/
- T& operator[](std::size_t idx) { return mValues[idx]; };
+ T& operator[](std::size_t idx) {
+ return mValues[idx];
+ };
/**
* @brief Read only access to vector elements.
*/
- T operator[](std::size_t idx) const { return mValues[idx]; };
+ T operator[](std::size_t idx) const {
+ return mValues[idx];
+ };
/**
* @brief Add two Vec objects entry-wise.
@@ -305,22 +311,30 @@ class Vec {
/**
* @brief Convert this vector into a 1-by-N row matrix.
*/
- Mat ToRowMatrix() const { return Mat{1, Size(), mValues}; };
+ Mat ToRowMatrix() const {
+ return Mat{1, Size(), mValues};
+ };
/**
* @brief Convert this vector into a N-by-1 column matrix.
*/
- Mat ToColumnMatrix() const { return Mat{Size(), 1, mValues}; };
+ Mat ToColumnMatrix() const {
+ return Mat{Size(), 1, mValues};
+ };
/**
* @brief Convert this Vec object to an std::vector.
*/
- std::vector& ToStlVector() { return mValues; };
+ std::vector& ToStlVector() {
+ return mValues;
+ };
/**
* @brief Convert this Vec object to a const std::vector.
*/
- const std::vector& ToStlVector() const { return mValues; };
+ const std::vector& ToStlVector() const {
+ return mValues;
+ };
/**
* @brief Extract a sub-vector
@@ -343,7 +357,9 @@ class Vec {
* @param end the end index, exclusive
* @return a sub-vector.
*/
- Vec SubVector(std::size_t end) { return SubVector(0, end); };
+ Vec SubVector(std::size_t end) {
+ return SubVector(0, end);
+ };
/**
* @brief Return a string representation of this vector.
@@ -366,67 +382,93 @@ class Vec {
/**
* @brief Returns the number of bytes that Write writes.
*/
- std::size_t ByteSize() const { return Size() * T::ByteSize(); };
+ std::size_t ByteSize() const {
+ return Size() * T::ByteSize();
+ };
/**
* @brief Return an iterator pointing to the start of this Vec.
*/
- iterator begin() { return mValues.begin(); };
+ iterator begin() {
+ return mValues.begin();
+ };
/**
* @brief Provides a const iterator to the start of this Vec.
*/
- const_iterator begin() const { return mValues.begin(); };
+ const_iterator begin() const {
+ return mValues.begin();
+ };
/**
* @brief Provides a const iterator to the start of this Vec.
*/
- const_iterator cbegin() const { return mValues.cbegin(); };
+ const_iterator cbegin() const {
+ return mValues.cbegin();
+ };
/**
* @brief Provides an iterator pointing to the end of this Vec.
*/
- iterator end() { return mValues.end(); };
+ iterator end() {
+ return mValues.end();
+ };
/**
* @brief Provides a const iterator pointing to the end of this Vec.
*/
- const_iterator end() const { return mValues.end(); };
+ const_iterator end() const {
+ return mValues.end();
+ };
/**
* @brief Provides a const iterator pointing to the end of this Vec.
*/
- const_iterator cend() const { return mValues.cend(); };
+ const_iterator cend() const {
+ return mValues.cend();
+ };
/**
* @brief Provides a reverse iterator pointing to the end of this Vec.
*/
- reverse_iterator rbegin() { return mValues.rbegin(); };
+ reverse_iterator rbegin() {
+ return mValues.rbegin();
+ };
/**
* @brief Provides a reverse const iterator pointing to the end of this Vec.
*/
- const_reverse_iterator rbegin() const { return mValues.rbegin(); };
+ const_reverse_iterator rbegin() const {
+ return mValues.rbegin();
+ };
/**
* @brief Provides a reverse const iterator pointing to the end of this Vec.
*/
- const_reverse_iterator crbegin() const { return mValues.crbegin(); };
+ const_reverse_iterator crbegin() const {
+ return mValues.crbegin();
+ };
/**
* @brief Provides a reverse iterator pointing to the start of this Vec.
*/
- reverse_iterator rend() { return mValues.rend(); };
+ reverse_iterator rend() {
+ return mValues.rend();
+ };
/**
* @brief Provides a reverse const iterator pointing to the start of this Vec.
*/
- const_reverse_iterator rend() const { return mValues.rend(); };
+ const_reverse_iterator rend() const {
+ return mValues.rend();
+ };
/**
* @brief Provides a reverse const iterator pointing to the start of this Vec.
*/
- const_reverse_iterator crend() const { return mValues.crend(); };
+ const_reverse_iterator crend() const {
+ return mValues.crend();
+ };
private:
void EnsureCompatible(const Vec& other) const {
diff --git a/include/scl/math/z2k.h b/include/scl/math/z2k.h
index 6610075..5fa5ad2 100644
--- a/include/scl/math/z2k.h
+++ b/include/scl/math/z2k.h
@@ -25,7 +25,7 @@
#include "scl/math/ring.h"
#include "scl/math/z2k_ops.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
namespace scl {
@@ -50,22 +50,30 @@ class Z2k final : public details::RingBase> {
/**
* @brief The bit size of the ring. Identical to BitSize().
*/
- constexpr static std::size_t SpecifiedBitSize() { return K; };
+ constexpr static std::size_t SpecifiedBitSize() {
+ return K;
+ };
/**
* @brief The number of bytes needed to store a ring element.
*/
- constexpr static std::size_t ByteSize() { return (K - 1) / 8 + 1; };
+ constexpr static std::size_t ByteSize() {
+ return (K - 1) / 8 + 1;
+ };
/**
* @brief The bit size of the ring. Identical to SpecifiedBitSize().
*/
- constexpr static std::size_t BitSize() { return SpecifiedBitSize(); };
+ constexpr static std::size_t BitSize() {
+ return SpecifiedBitSize();
+ };
/**
* @brief A short string representation of this ring.
*/
- constexpr static const char* Name() { return "Z2k"; };
+ constexpr static const char* Name() {
+ return "Z2k";
+ };
/**
* @brief Read a ring from a buffer.
@@ -193,7 +201,9 @@ class Z2k final : public details::RingBase> {
* particular, an element x is invertible if x.Lsb() ==
* 1. That is, if it is odd.
*/
- unsigned Lsb() const { return details::LsbZ2k(mValue); };
+ unsigned Lsb() const {
+ return details::LsbZ2k(mValue);
+ };
/**
* @brief Check if this element is equal to another element.
diff --git a/include/scl/math/z2k_ops.h b/include/scl/math/z2k_ops.h
index c767bad..16b1c57 100644
--- a/include/scl/math/z2k_ops.h
+++ b/include/scl/math/z2k_ops.h
@@ -26,7 +26,7 @@
#include
#include
-#include "scl/math/str.h"
+#include "scl/util/str.h"
namespace scl {
namespace details {
diff --git a/include/scl/net/channel.h b/include/scl/net/channel.h
index 1a4becb..b3b0732 100644
--- a/include/scl/net/channel.h
+++ b/include/scl/net/channel.h
@@ -92,14 +92,17 @@ class Channel {
* @return true if this channel has data and false otherwise.
* @note the default implementation always returns true.
*/
- virtual bool HasData() { return true; };
+ virtual bool HasData() {
+ return true;
+ };
/**
* @brief Send a trivially copyable item.
* @param src the thing to send
*/
- template , bool> = true>
+ template <
+ typename T,
+ typename std::enable_if_t, bool> = true>
void Send(const T& src) {
Send(SCL_CC(&src), sizeof(T));
}
@@ -112,8 +115,9 @@ class Channel {
*
* @param src an STL vector of things to send
*/
- template , bool> = true>
+ template <
+ typename T,
+ typename std::enable_if_t, bool> = true>
void Send(const std::vector& src) {
Send(src.size());
Send(SCL_CC(src.data()), sizeof(T) * src.size());
@@ -168,8 +172,9 @@ class Channel {
* @brief Receive a trivially copyable item.
* @param dst where to store the received item
*/
- template , bool> = true>
+ template <
+ typename T,
+ typename std::enable_if_t, bool> = true>
void Recv(T& dst) {
Recv(SCL_C(&dst), sizeof(T));
}
@@ -179,8 +184,9 @@ class Channel {
* @param dst where to store the received items
* @note any existing content in \p dst is overwritten.
*/
- template , bool> = true>
+ template <
+ typename T,
+ typename std::enable_if_t, bool> = true>
void Recv(std::vector& dst) {
std::size_t size;
Recv(size);
diff --git a/include/scl/net/config.h b/include/scl/net/config.h
index 8831779..529b675 100644
--- a/include/scl/net/config.h
+++ b/include/scl/net/config.h
@@ -112,22 +112,30 @@ class NetworkConfig {
/**
* @brief Gets the identity of this party.
*/
- int Id() const { return mId; };
+ int Id() const {
+ return mId;
+ };
/**
* @brief Gets the size of the network.
*/
- std::size_t NetworkSize() const { return mParties.size(); };
+ std::size_t NetworkSize() const {
+ return mParties.size();
+ };
/**
* @brief Get a list of connection information for parties in this network.
*/
- std::vector Parties() const { return mParties; };
+ std::vector Parties() const {
+ return mParties;
+ };
/**
* @brief Get information about a party.
*/
- Party GetParty(unsigned id) const { return mParties[id]; };
+ Party GetParty(unsigned id) const {
+ return mParties[id];
+ };
/**
* @brief Return a string representation of this network config.
diff --git a/include/scl/net/mem_channel.h b/include/scl/net/mem_channel.h
index 5e552b0..66faa82 100644
--- a/include/scl/net/mem_channel.h
+++ b/include/scl/net/mem_channel.h
@@ -80,7 +80,11 @@ class InMemoryChannel final : public Channel {
void Send(const unsigned char* src, std::size_t n) override;
std::size_t Recv(unsigned char* dst, std::size_t n) override;
- bool HasData() override { return mIn->Size() > 0 || !mOverflow.empty(); };
+
+ bool HasData() override {
+ return mIn->Size() > 0 || !mOverflow.empty();
+ };
+
void Close() override{};
private:
diff --git a/include/scl/net/network.h b/include/scl/net/network.h
index b1ed6f3..e15feea 100644
--- a/include/scl/net/network.h
+++ b/include/scl/net/network.h
@@ -65,12 +65,16 @@ class Network {
* @brief Get a channel to a particular party.
* @param id the id of the party
*/
- Channel* Party(unsigned id) { return mChannels[id].get(); }
+ Channel* Party(unsigned id) {
+ return mChannels[id].get();
+ }
/**
* @brief The size of the network.
*/
- std::size_t Size() const { return mChannels.size(); };
+ std::size_t Size() const {
+ return mChannels.size();
+ };
/**
* @brief Closes all channels in the network.
@@ -153,8 +157,8 @@ Network Network::Create(const scl::NetworkConfig& config) {
channels[config.Id()] = scl::details::CreateChannelConnectingToSelf();
- std::thread server(scl::details::SCL_AcceptConnections, std::ref(channels),
- config);
+ std::thread server(
+ scl::details::SCL_AcceptConnections, std::ref(channels), config);
for (std::size_t i = 0; i < static_cast(config.Id()); ++i) {
const auto party = config.GetParty(i);
diff --git a/include/scl/net/shared_deque.h b/include/scl/net/shared_deque.h
index 862c371..55c62a6 100644
--- a/include/scl/net/shared_deque.h
+++ b/include/scl/net/shared_deque.h
@@ -44,7 +44,7 @@ class SharedDeque {
/**
* @brief Read the top element from the queue.
*/
- T &Peek();
+ T& Peek();
/**
* @brief Remove and return the top element from the queue.
@@ -54,12 +54,12 @@ class SharedDeque {
/**
* @brief Insert an item to the back of the queue.
*/
- void PushBack(const T &item);
+ void PushBack(const T& item);
/**
* @brief Move an item to the back of the queue.
*/
- void PushBack(T &&item);
+ void PushBack(T&& item);
/**
* @brief Number of elements currently in the queue.
@@ -82,7 +82,7 @@ void SharedDeque::PopFront() {
}
template
-T &SharedDeque::Peek() {
+T& SharedDeque::Peek() {
std::unique_lock lock(mMutex);
while (mDeck.empty()) {
mCond.wait(lock);
@@ -102,7 +102,7 @@ T SharedDeque::Pop() {
}
template
-void SharedDeque::PushBack(const T &item) {
+void SharedDeque::PushBack(const T& item) {
std::unique_lock lock(mMutex);
mDeck.push_back(item);
lock.unlock();
@@ -110,7 +110,7 @@ void SharedDeque::PushBack(const T &item) {
}
template
-void SharedDeque::PushBack(T &&item) {
+void SharedDeque::PushBack(T&& item) {
std::unique_lock lock(mMutex);
mDeck.push_back(std::move(item));
lock.unlock();
diff --git a/include/scl/net/tcp_channel.h b/include/scl/net/tcp_channel.h
index f5af2fe..05e2d24 100644
--- a/include/scl/net/tcp_channel.h
+++ b/include/scl/net/tcp_channel.h
@@ -44,12 +44,16 @@ class TcpChannel final : public Channel {
/**
* @brief Destroying a TCP channel closes the connection.
*/
- ~TcpChannel() { Close(); };
+ ~TcpChannel() {
+ Close();
+ };
/**
* @brief Tells whether this channel is alive or not.
*/
- bool Alive() const { return mAlive; };
+ bool Alive() const {
+ return mAlive;
+ };
void Send(const unsigned char* src, std::size_t n) override;
std::size_t Recv(unsigned char* dst, std::size_t n) override;
diff --git a/include/scl/net/tcp_utils.h b/include/scl/net/tcp_utils.h
index fd75e70..3bf92e3 100644
--- a/include/scl/net/tcp_utils.h
+++ b/include/scl/net/tcp_utils.h
@@ -21,11 +21,11 @@
#ifndef SCL_NET_TCP_UTILS_H
#define SCL_NET_TCP_UTILS_H
-#include
-
#include
#include
+#include
+
namespace scl {
namespace details {
diff --git a/include/scl/net/threaded_sender.h b/include/scl/net/threaded_sender.h
index f06e1b5..e9e6710 100644
--- a/include/scl/net/threaded_sender.h
+++ b/include/scl/net/threaded_sender.h
@@ -56,7 +56,9 @@ class ThreadedSenderChannel final : public Channel {
return mChannel.Recv(dst, n);
};
- bool HasData() override { return mChannel.HasData(); };
+ bool HasData() override {
+ return mChannel.HasData();
+ };
private:
TcpChannel mChannel;
diff --git a/include/scl/p/simple.h b/include/scl/p/simple.h
index 04188de..9f372f2 100644
--- a/include/scl/p/simple.h
+++ b/include/scl/p/simple.h
@@ -33,12 +33,14 @@ struct ProtocolStep {
/**
* @brief Evaluate a step of the protocol.
*/
- auto Run(Ctx& context) { return static_cast(this)->Run(context); };
+ auto Run(Ctx& context) {
+ return static_cast(this)->Run(context);
+ };
};
/**
* @brief The final step of a protocol.
- */
+o */
template
struct LastProtocolStep {
/**
@@ -53,7 +55,8 @@ struct LastProtocolStep {
* @brief Recursively evaluate all internal steps of a protocol.
*/
template <
- typename S, typename Ctx,
+ typename S,
+ typename Ctx,
std::enable_if_t, S>, bool> = true>
auto Evaluate(S& step, Ctx& context) {
auto next = step.Run(context);
@@ -63,7 +66,8 @@ auto Evaluate(S& step, Ctx& context) {
/**
* @brief Evaluate the last step of a protocol.
*/
-template , S>,
bool> = true>
auto Evaluate(S& last, Ctx& context) {
diff --git a/include/scl/primitives.h b/include/scl/primitives.h
index 0235577..15bb05d 100644
--- a/include/scl/primitives.h
+++ b/include/scl/primitives.h
@@ -21,7 +21,7 @@
#ifndef SCL_PRIMITIVES_H
#define SCL_PRIMITIVES_H
-#include "scl/hash.h"
-#include "scl/prg.h"
+#include "scl/primitives/hash.h"
+#include "scl/primitives/prg.h"
#endif // SCL_PRIMITIVES_H
diff --git a/include/scl/primitives/digest.h b/include/scl/primitives/digest.h
new file mode 100644
index 0000000..eae103a
--- /dev/null
+++ b/include/scl/primitives/digest.h
@@ -0,0 +1,62 @@
+/**
+ * @file digest.h
+ *
+ * SCL --- Secure Computation Library
+ * Copyright (C) 2022 Anders Dalskov
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+#ifndef SCL_PRIMITIVES_DIGEST_H
+#define SCL_PRIMITIVES_DIGEST_H
+
+#include
+#include
+#include
+#include
+
+#include "scl/util/str.h"
+
+namespace scl {
+namespace details {
+
+/**
+ * @brief A digest of some bitsize.
+ *
+ * This type is effectively std::array.
+ *
+ * @tparam Bits the bitsize of the digest
+ */
+template
+struct Digest {
+ /**
+ * @brief The actual type of a digest.
+ */
+ using Type = std::array;
+};
+
+/**
+ * @brief Convert a digest to a string.
+ * @param digest the digest
+ * @return a hex representation of the digest.
+ */
+template
+std::string DigestToString(const D& digest) {
+ return ToHexString(digest.begin(), digest.end());
+}
+
+} // namespace details
+} // namespace scl
+
+#endif // SCL_PRIMITIVES_DIGEST_H
diff --git a/include/scl/primitives/hash.h b/include/scl/primitives/hash.h
new file mode 100644
index 0000000..86278ee
--- /dev/null
+++ b/include/scl/primitives/hash.h
@@ -0,0 +1,41 @@
+/**
+ * @file hash.h
+ *
+ * SCL --- Secure Computation Library
+ * Copyright (C) 2022 Anders Dalskov
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+#ifndef SCL_PRIMITIVES_HASH_H
+#define SCL_PRIMITIVES_HASH_H
+
+#include
+
+#include "scl/primitives/sha3.h"
+
+namespace scl {
+
+/**
+ * @brief A default hash function given a digest size.
+ *
+ * This type defults to one of the three instantiations of SHA3 that SCL
+ * provides.
+ */
+template
+using Hash = details::Sha3;
+
+} // namespace scl
+
+#endif // SCL_PRIMITIVES_HASH_H
diff --git a/include/scl/primitives/iuf_hash.h b/include/scl/primitives/iuf_hash.h
new file mode 100644
index 0000000..ee57821
--- /dev/null
+++ b/include/scl/primitives/iuf_hash.h
@@ -0,0 +1,88 @@
+/**
+ * @file iuf_hash.h
+ *
+ * SCL --- Secure Computation Library
+ * Copyright (C) 2022 Anders Dalskov
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+#ifndef SCL_PRIMITIVES_IUF_HASH_H
+#define SCL_PRIMITIVES_IUF_HASH_H
+
+#include
+#include
+#include
+#include
+#include
+
+namespace scl {
+namespace details {
+
+/**
+ * @brief IUF interface for hash functions.
+ *
+ * HashInterface defines an IUF (Initialize-Update-Finalize) style interface
+ * for a hash function.
+ *
+ * @tparam H interfaceementation.
+ */
+template
+struct IUFHash {
+ /**
+ * @brief Update the hash function with a set of bytes.
+ *
+ * @param bytes a pointer to a number of bytes.
+ * @param n the number of bytes.
+ * @return the updated Hash object.
+ */
+ IUFHash& Update(const unsigned char* bytes, std::size_t n) {
+ static_cast(this)->Hash(bytes, n);
+ return *this;
+ };
+
+ /**
+ * @brief Update the hash function with the content from a byte vector.
+ *
+ * @param data a vector of bytes.
+ * @return the updated Hash object.
+ */
+ IUFHash& Update(const std::vector& data) {
+ return Update(data.data(), data.size());
+ };
+
+ /**
+ * @brief Update the hash function with the content from a byte STL array.
+ *
+ * @param data the array
+ * @return the updated Hash object.
+ */
+ template
+ IUFHash& Update(const std::array& data) {
+ return Update(data.data(), N);
+ }
+
+ /**
+ * @brief Finalize and return the digest.
+ */
+ auto Finalize() {
+ auto digest = static_cast(this)->Write();
+ return digest;
+ };
+};
+
+} // namespace details
+} // namespace scl
+
+#endif // SCL_PRIMITIVES_IUF_HASH_H
diff --git a/include/scl/prg.h b/include/scl/primitives/prg.h
similarity index 51%
rename from include/scl/prg.h
rename to include/scl/primitives/prg.h
index 5699aa8..6fc94b5 100644
--- a/include/scl/prg.h
+++ b/include/scl/primitives/prg.h
@@ -18,15 +18,17 @@
* along with this program. If not, see .
*/
-#ifndef SCL_PRG_H
-#define SCL_PRG_H
-
-#include
+#ifndef SCL_PRIMITIVES_PRG_H
+#define SCL_PRIMITIVES_PRG_H
+#include
#include
#include
#include
+#include
+#include
+
/**
* @brief 64 bit nonce which is prepended to the counter in the PRG.
*/
@@ -61,103 +63,108 @@ namespace scl {
* as a macro. It defaults to 0x0123456789ABCDEF.
*/
class PRG {
- public:
- /**
- * @brief The size of an output block.
- */
- static constexpr std::size_t BlockSize() { return sizeof(BlockType); };
+ private:
+ using BlockType = __m128i;
+ static constexpr std::size_t kBlockSize = sizeof(BlockType);
+ public:
/**
- * @brief The size of the seed.
+ * @brief Size of the seed.
*/
- static constexpr std::size_t SeedSize() { return BlockSize(); };
+ static constexpr std::size_t SeedSize() {
+ return kBlockSize;
+ };
/**
- * @brief Construct a new PRG object with seed of 0.
+ * @brief Create a new PRG with seed 0.
*/
- PRG();
+ static PRG Create() {
+ PRG prg(nullptr);
+ prg.Init();
+ return prg;
+ };
/**
- * @brief Construct a new PRG with a given seed.
- *
+ * @brief Create a new PRG with a provided seed.
* @param seed the seed.
- *
- * @pre This constructor reads seed_size() bytes from
- * seed so the latter must point to that much allocated space.
*/
- PRG(const unsigned char *seed);
+ static PRG Create(const unsigned char* seed) {
+ PRG prg(seed);
+ prg.Init();
+ return prg;
+ };
/**
- * @brief Reset the PRG to its initial state.
+ * @brief Reset the PRG.
+ *
+ * This method allows resetting a PRG object to its initial state.
*/
void Reset();
/**
- * @brief Generate random data and store it in a supplied location.
- *
- * @param dest the destination of the generated random bytes.
- * @param nbytes how many bytes of random data to generate.
- *
- * @pre dest must point to nbytes of allocated
- * space.
- *
- * @throws std::runtime_error if allocation of roughly nbytes
- * bytes of memory fails.
+ * @brief Generate random data and store it in a supplied buffer.
+ * @param buffer the buffer
+ * @param n how many bytes of random data to generate
*/
- void Next(unsigned char *dest, std::size_t nbytes);
+ void Next(unsigned char* buffer, std::size_t n);
/**
- * @brief Generate random data.
- * @param dest the destination of the random data
+ * @brief Generate random data and store it in a supplied buffer.
+ * @param buffer the buffer with space pre-allocated
+ *
+ * How many bytes of random data to generate is decided based on the output of
+ * buffer.size().
*/
- void Next(std::vector &dest) {
- Next(dest.data(), dest.size());
+ void Next(std::vector& buffer) {
+ Next(buffer.data(), buffer.size());
};
/**
- * @brief Generate random data.
- * @param dest the destination of the random data
- * @param nbytes how many bytes to generate.
+ * @brief Generate random data and store in in a supplied buffer.
+ * @param buffer the buffer
+ * @param n how many random bytes to generate
+ * @throws std::invalid_argument if \p n is greater than
+ * buffer.size().
+ *
+ * The capacity of \p buffer is not affected in any way by this method and it
+ * requires that it has room for at least \p n elements.
*/
- void Next(std::vector &dest, std::size_t nbytes) {
- if (dest.size() < nbytes) {
+ void Next(std::vector& buffer, std::size_t n) {
+ if (buffer.size() < n) {
throw std::invalid_argument("requested more randomness than dest.size()");
}
- Next(dest.data(), nbytes);
+ Next(buffer.data(), n);
};
/**
- * @brief Generate and return some random data.
- * @param nbytes the number of random bytes to generate
+ * @brief Generate and return random data.
+ * @param n the number of random bytes to generate
* @return the random bytes.
*/
- std::vector Next(std::size_t nbytes) {
- auto buffer = std::make_unique(nbytes);
- Next(buffer.get(), nbytes);
- return std::vector(buffer.get(), buffer.get() + nbytes);
+ std::vector Next(std::size_t n) {
+ auto buffer = std::make_unique(n);
+ Next(buffer.get(), n);
+ return std::vector(buffer.get(), buffer.get() + n);
};
/**
- * @brief The seed of the PRG.
+ * @brief The seed.
*/
- const unsigned char *Seed() const { return mSeed; };
-
- /**
- * @brief The current counter of the PRG.
- */
- long Counter() const { return mCounter; };
+ std::array Seed() const {
+ return mSeed;
+ };
private:
- using BlockType = __m128i;
-
- void Update(void);
- void Init(void);
+ PRG(const unsigned char* seed);
- unsigned char mSeed[sizeof(BlockType)] = {0};
+ std::array mSeed = {0};
long mCounter = PRG_INITIAL_COUNTER;
BlockType mState[11];
+
+ void Update();
+ void Init();
};
} // namespace scl
-#endif // SCL_PRG_H
+#endif // SCL_PRIMITIVES_PRG_H
diff --git a/include/scl/primitives/sha256.h b/include/scl/primitives/sha256.h
new file mode 100644
index 0000000..04d87eb
--- /dev/null
+++ b/include/scl/primitives/sha256.h
@@ -0,0 +1,80 @@
+/**
+ * @file sha256.h
+ *
+ * SCL --- Secure Computation Library
+ * Copyright (C) 2022 Anders Dalskov
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+#ifndef SCL_PRIMITIVES_SHA256_H
+#define SCL_PRIMITIVES_SHA256_H
+
+#include
+#include
+#include
+#include
+
+#include "scl/primitives/digest.h"
+#include "scl/primitives/iuf_hash.h"
+
+namespace scl {
+namespace details {
+
+/**
+ * @brief SHA256 hash function.
+ */
+class Sha256 final : public details::IUFHash {
+ public:
+ /**
+ * @brief The type of a SHA256 digest.
+ */
+ using DigestType = typename details::Digest<256>::Type;
+
+ /**
+ * @brief Update the hash function with a set of bytes.
+ *
+ * @param bytes a pointer to a number of bytes.
+ * @param nbytes the number of bytes.
+ * @return the updated Hash object.
+ */
+ void Hash(const unsigned char* bytes, std::size_t nbytes);
+
+ /**
+ * @brief Finalize and return the digest.
+ */
+ DigestType Write();
+
+ private:
+ std::array mChunk;
+ std::uint32_t mChunkPos = 0;
+ std::size_t mTotalLen = 0;
+ std::array mState = {0x6a09e667,
+ 0xbb67ae85,
+ 0x3c6ef372,
+ 0xa54ff53a,
+ 0x510e527f,
+ 0x9b05688c,
+ 0x1f83d9ab,
+ 0x5be0cd19};
+
+ void Transform();
+ void Pad();
+ DigestType WriteDigest();
+};
+
+} // namespace details
+} // namespace scl
+
+#endif // SCL_PRIMITIVES_SHA256_H
diff --git a/include/scl/hash.h b/include/scl/primitives/sha3.h
similarity index 56%
rename from include/scl/hash.h
rename to include/scl/primitives/sha3.h
index 9b25dee..5d7d7f9 100644
--- a/include/scl/hash.h
+++ b/include/scl/primitives/sha3.h
@@ -1,5 +1,5 @@
/**
- * @file hash.h
+ * @file sha3.h
*
* SCL --- Secure Computation Library
* Copyright (C) 2022 Anders Dalskov
@@ -18,75 +18,47 @@
* along with this program. If not, see .
*/
-#ifndef SCL_HASH_H
-#define SCL_HASH_H
+#ifndef SCL_PRIMITIVES_SHA3_H
+#define SCL_PRIMITIVES_SHA3_H
#include
+#include
#include
-#include
-#include
#include
+#include "scl/primitives/digest.h"
+#include "scl/primitives/iuf_hash.h"
+
namespace scl {
+namespace details {
/**
- * @brief A hash function.
- *
- * Hash defines an IUF (Initialize-Update-Finalize) style interface for a
- * hash function. The current implementation is based on SHA3 and supports
- * digest sizes of either 256, 384 or 512 bits.
- *
- * @code
- * // define a hash function object with 256-bit output.
- * using Hash = scl::Hash<256>;
- *
- * unsigned char data[] = {'d', 'a', 't', 'a'};
- * Hash hash;
- * hash.Update(data, 4);
- * auto digest = hash.Finalize();
- * @endcode
- *
+ * @brief SHA3 hash function.
* @tparam DigestSize the output size in bits. Must be either 256, 384 or 512
*/
template
-class Hash {
+class Sha3 final : public details::IUFHash> {
static_assert(DigestSize == 256 || DigestSize == 384 || DigestSize == 512,
- "B must be one of 256, 384 or 512");
+ "Invalid SHA3 digest size. Must be 256, 384 or 512");
public:
/**
- * @brief The type of the final digest.
- */
- using DigestType = std::array;
-
- /**
- * @brief Initialize the hash function.
+ * @brief The type of a SHA3 digest.
*/
- Hash(){};
+ using DigestType = typename details::Digest::Type;
/**
* @brief Update the hash function with a set of bytes.
- *
- * @param[in] bytes a pointer to a number of bytes.
- * @param[in] nbytes the number of bytes.
- * @return the updated Hash object.
- */
- Hash &Update(const unsigned char *bytes, std::size_t nbytes);
-
- /**
- * @brief Update the hash function with the content from a byte vector.
- *
- * @param bytes a vector of bytes.
+ * @param bytes a pointer to a number of bytes.
+ * @param nbytes the number of bytes.
* @return the updated Hash object.
*/
- Hash &Update(const std::vector &bytes) {
- return Update(bytes.data(), bytes.size());
- };
+ void Hash(const unsigned char* bytes, std::size_t nbytes);
/**
* @brief Finalize and return the digest.
*/
- DigestType Finalize();
+ DigestType Write();
private:
static const std::size_t kStateSize = 25;
@@ -100,40 +72,22 @@ class Hash {
unsigned int mWordIndex = 0;
};
-static const uint64_t keccakf_rndc[24] = {
- 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL,
- 0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL,
- 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL,
- 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL,
- 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL,
- 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL,
- 0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL,
- 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL};
-
-static const unsigned int keccakf_rotc[24] = {1, 3, 6, 10, 15, 21, 28, 36,
- 45, 55, 2, 14, 27, 41, 56, 8,
- 25, 43, 62, 18, 39, 61, 20, 44};
-
-static const unsigned int keccakf_piln[24] = {10, 7, 11, 17, 18, 3, 5, 16,
- 8, 21, 24, 4, 15, 23, 19, 13,
- 12, 2, 20, 14, 22, 9, 6, 1};
-
/**
* @brief Keccak function.
* @param state the current state
*/
void Keccakf(uint64_t state[25]);
-template
-Hash &Hash::Update(const unsigned char *bytes, std::size_t nbytes) {
+template
+void Sha3::Hash(const unsigned char* bytes, std::size_t nbytes) {
unsigned int old_tail = (8 - mByteIndex) & 7;
- const unsigned char *p = bytes;
+ const unsigned char* p = bytes;
if (nbytes < old_tail) {
while (nbytes-- > 0) {
mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8);
}
- return *this;
+ return;
}
if (old_tail != 0) {
@@ -174,12 +128,10 @@ Hash &Hash::Update(const unsigned char *bytes, std::size_t nbytes) {
while (tail-- > 0) {
mSaved |= (uint64_t)(*(p++)) << ((mByteIndex++) * 8);
}
-
- return *this;
}
-template
-auto Hash::Finalize() -> DigestType {
+template
+auto Sha3::Write() -> Sha3::DigestType {
uint64_t t = (uint64_t)(((uint64_t)(0x02 | (1 << 2))) << ((mByteIndex)*8));
mState[mWordIndex] ^= mSaved ^ t;
mState[kCuttoff - 1] ^= 0x8000000000000000ULL;
@@ -207,21 +159,7 @@ auto Hash::Finalize() -> DigestType {
return digest;
}
-/**
- * @brief Convert a digest to a string.
- * @param digest the digest
- * @return a hex representation of the digest.
- */
-template
-std::string DigestToString(const D &digest) {
- std::stringstream ss;
- ss << std::setw(2) << std::setfill('0') << std::hex;
- for (const auto &c : digest) {
- ss << (int)c;
- }
- return ss.str();
-}
-
+} // namespace details
} // namespace scl
-#endif // SCL_HASH_H
+#endif // SCL_PRIMITIVES_SHA3_H
diff --git a/include/scl/ss/additive.h b/include/scl/ss/additive.h
index 9d935af..0c598fb 100644
--- a/include/scl/ss/additive.h
+++ b/include/scl/ss/additive.h
@@ -24,7 +24,7 @@
#include
#include "scl/math/vec.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
namespace scl {
diff --git a/include/scl/ss/feldman.h b/include/scl/ss/feldman.h
new file mode 100644
index 0000000..a3e731f
--- /dev/null
+++ b/include/scl/ss/feldman.h
@@ -0,0 +1,159 @@
+/**
+ * @file feldman.h
+ *
+ * SCL --- Secure Computation Library
+ * Copyright (C) 2022 Anders Dalskov
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+#ifndef SCL_SS_FELDMAN_H
+#define SCL_SS_FELDMAN_H
+
+#include
+#include
+#include
+#include
+
+#include "scl/math/vec.h"
+#include "scl/primitives/prg.h"
+#include "scl/ss/shamir.h"
+
+namespace scl {
+
+/**
+ * @brief Class for working with Feldman secret-shares.
+ *
+ * This class allows creation, verification and reconstruction of
+ * secret-shares from Feldman's verifiable secret-sharing scheme over a suitable
+ * elliptic curve.
+ *
+ * The factory is instantiated with a type \p V which should be \ref EC
+ * for some curve definition. In particular, V::Order is assumed to
+ * be an \ref FF type.
+ */
+template
+class FeldmanSSFactory {
+ /**
+ * @brief The type of secret-shares.
+ */
+ using T = typename V::Order;
+
+ public:
+ /**
+ * @brief Create a new FeldmanSSFactory object.
+ * @param threshold the privacy threshold
+ * @param prg a prg to use for generating randomness
+ */
+ static FeldmanSSFactory Create(std::size_t threshold, PRG& prg) {
+ return FeldmanSSFactory(
+ threshold,
+ ShamirSSFactory::Create(threshold, prg, SecurityLevel::PASSIVE));
+ };
+
+ /**
+ * @brief The result of calling Share
+ */
+ struct ShareBundle {
+ /**
+ * @brief The secret shares.
+ */
+ Vec shares;
+ /**
+ * @brief Commitments of the shares plus the secret.
+ */
+ Vec commitments;
+ };
+
+ /**
+ * @brief Secret share a value.
+ * @param secret the secret to share
+ * @param number_of_shares the number of shares to generate
+ * @return shares and commitments of shares.
+ */
+ ShareBundle Share(const T& secret, std::size_t number_of_shares) const {
+ auto shares = mBase.Share(secret, number_of_shares);
+ return {shares, ComputeCommitments(shares)};
+ };
+
+ /**
+ * @brief Check that a share is consistent with a set of commitments.
+ */
+ bool Verify(const T& share, const Vec& commitments, int party_index) const;
+
+ /**
+ * @brief Check that a secret is consistent with a set of commitments.
+ */
+ bool Verify(const T& secret, const Vec& commitments) const {
+ return Verify(secret, commitments, -1);
+ };
+
+ /**
+ * @brief Recover the value corresponding to a particular index
+ * @param shares the shares to recover
+ * @param index the index. Defaults to 0
+ */
+ T Recover(const Vec& shares, int index = 0) const {
+ return mBase.Recover(shares, index);
+ };
+
+ /**
+ * @brief Recover a share of a particular party
+ * @param shares the shares
+ * @param party_index the index of a the party whose share to recover
+ */
+ T RecoverShare(const Vec& shares, int party_index = 0) const {
+ return mBase.RecoverShare(shares, party_index);
+ };
+
+ private:
+ FeldmanSSFactory(std::size_t threshold, ShamirSSFactory base)
+ : mThreshold(threshold), mBase(base){};
+
+ std::size_t mThreshold;
+ // mutable because calling Recover on this might trigger computation of new
+ // lagrange coefficients.
+ mutable ShamirSSFactory mBase;
+
+ Vec ComputeCommitments(const Vec& shares) const;
+};
+
+template
+bool FeldmanSSFactory::Verify(const T& share,
+ const Vec& commitments,
+ int party_index) const {
+ if (commitments.Size() < mThreshold + 1) {
+ throw std::invalid_argument("insufficient commitments for verification");
+ }
+ // coefficients are indexed one-off.
+ auto& coeff = mBase.GetLagrangeCoefficients(party_index + 1);
+ auto v = details::UncheckedInnerProd(
+ coeff.begin(), coeff.end(), commitments.begin());
+ return v == V::Generator() * share;
+}
+
+template
+Vec FeldmanSSFactory::ComputeCommitments(const Vec& shares) const {
+ std::vector c;
+ c.reserve(mThreshold + 1);
+ auto gen = V::Generator();
+ for (std::size_t i = 0; i < mThreshold + 1; ++i) {
+ c.emplace_back(shares[i] * gen);
+ }
+ return Vec{c};
+}
+
+} // namespace scl
+
+#endif // SCL_SS_FELDMAN_H
diff --git a/include/scl/ss/poly.h b/include/scl/ss/poly.h
index 04340a9..679eb60 100644
--- a/include/scl/ss/poly.h
+++ b/include/scl/ss/poly.h
@@ -70,12 +70,16 @@ class Polynomial {
/**
* @brief Access coefficients, with the constant term at position 0.
*/
- T& operator[](std::size_t idx) { return mCoefficients[idx]; };
+ T& operator[](std::size_t idx) {
+ return mCoefficients[idx];
+ };
/**
* @brief Access coefficients, with the constant term at position 0.
*/
- T operator[](std::size_t idx) const { return mCoefficients[idx]; };
+ T operator[](std::size_t idx) const {
+ return mCoefficients[idx];
+ };
/**
* @brief Add two polynomials.
@@ -101,22 +105,30 @@ class Polynomial {
/**
* @brief Returns true if this is the 0 polynomial.
*/
- bool IsZero() const { return Degree() == 0 && ConstantTerm() == T(0); };
+ bool IsZero() const {
+ return Degree() == 0 && ConstantTerm() == T(0);
+ };
/**
* @brief Get the constant term of this polynomial.
*/
- T ConstantTerm() const { return operator[](0); };
+ T ConstantTerm() const {
+ return operator[](0);
+ };
/**
* @brief Get the leading term of this polynomial.
*/
- T LeadingTerm() const { return operator[](Degree()); };
+ T LeadingTerm() const {
+ return operator[](Degree());
+ };
/**
* @brief Degree of this polynomial.
*/
- std::size_t Degree() const { return mCoefficients.Size() - 1; };
+ std::size_t Degree() const {
+ return mCoefficients.Size() - 1;
+ };
/**
* @brief Get a string representation of this polynomial.
@@ -135,7 +147,9 @@ class Polynomial {
* @brief Get a string representation of this polynomial.
* @note Equivalent to ToString("f", "x").
*/
- std::string ToString() const { return ToString("f", "x"); };
+ std::string ToString() const {
+ return ToString("f", "x");
+ };
/**
* @brief Write a string representation of this polynomial to a stream.
diff --git a/include/scl/ss/shamir.h b/include/scl/ss/shamir.h
index b38a4c2..dc77378 100644
--- a/include/scl/ss/shamir.h
+++ b/include/scl/ss/shamir.h
@@ -30,7 +30,7 @@
#include "scl/math/la.h"
#include "scl/math/vec.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
#include "scl/ss/poly.h"
namespace scl {
@@ -55,7 +55,8 @@ namespace details {
* @return a pair of polynomials.
*/
template
-auto ReconstructShamirRobust(const Vec& shares, const Vec& alphas,
+auto ReconstructShamirRobust(const Vec& shares,
+ const Vec& alphas,
std::size_t t) {
std::size_t n = 3 * t + 1;
if (n > shares.Size()) {
@@ -123,115 +124,148 @@ T ReconstructShamirRobust(const Vec& shares, std::size_t t) {
ReconstructShamirRobust(shares, Vec::Range(1, shares.Size() + 1), t);
return std::get<0>(p).Evaluate(T(0));
}
+} // namespace details
/**
- * @brief Defines some common security levels.
+ * @brief Defines some commonly used security levels for Shamir secret-sharing.
*
- * These security levels are used to derive some common defaults. For example,
- * SecurityLevel::DETECT is used to indicate that ShamirSSFactory, when
- * instantiated with a threshold of \f$t\f$, should generate \f$2t + 1\f$ shares
- * if no other value is specified.
+ * The SecurityLevel enum describes different threat models that Shamir
+ * secret-sharing is typically used with. Each value should be interpreted
+ * relative to an implicit threshold \f$t\f$.
*/
enum class SecurityLevel {
/**
- * @brief \f$t + 1\f$.
+ * @brief Passive security.
+ *
+ * This level describes a passive adversary threat model. When this level is
+ * used, only \f$t+1\f$ shares are created, and secrets are recovered with
+ * standard Lagrange interpolation without any form of checks.
*/
PASSIVE,
/**
- * @brief \f$2t + 1\f$. Enough shares to detect errors.
+ * @brief Active security with error detection.
+ *
+ * This level describes an active adversary threat model. When this level is
+ * used, \f$n = 2t + 1\f$ shares are created, and secrets are recovered by
+ * using the first \f$t + 1\f$ shares to recover the remaining \f$n - t\f$
+ * shares, which are checked for equality. If this check fails, then
+ * reconstruction could not proceed. Otherwise the secret is recovered as in
+ * the passive case (in which case the secret is guaranteed to be the right
+ * one).
*/
DETECT,
/**
- * @brief \f$3t + 1\f$. Enough shares to correct errors.
+ * @brief Active security with error correction.
+ *
+ * This level describes an active adversary threat model. When this level is
+ * used, \f$3t+1\f$ shares are generated. Recovery proceeds via. the
+ * Berelkamp-Welch error correction algorithm and guarantees that the right
+ * secret is recovered, provided the input shares contain at most \f$t\f$
+ * errors.
*/
CORRECT
};
/**
- * @brief Class for reconstructing secrets from Shamir secret-shares.
+ * @brief Class for working with Shamir secret-shares.
*
- * The main purpose of this class is to cache the lagrange coefficients used
- * when performing interpolation. These coefficients are computed when they are
- * first needed. The only exception is the coefficients needed to reconstruct
- * the point at 0, which is usually where the secret is placed.
+ * This class can be used to create Shamir secret-shares of provided secrets,
+ * and to recover secrets from Shamir secret-shares. On creation, this class
+ * takes a threshold \f$t\f$ and a security level. The security level is used to
+ * determine (1) how many shares to create when secret-sharing a secret, and (2)
+ * how to recover a secret from a set of secret-shares. For example, if
+ * SecurityLevel::DETECT is passed as the security level and \f$t = 3\f$, then
+ * the factory object will create \f$2t + 1 = 7\f$ shares when secret-sharing
+ * something, and reconstruction uses an algorithm which detects errors in the
+ * provided shares (i.e., reconstruction in this case, if it succeeds,
+ * guarantees that the recovered secret is the right one).
*
- * A Reconstructor contains essentially two methods for reconstructing:
- *
- * - Reconstruct reconstructs a particular evaluation of the implied
- * polynomial.
- * - ReconstructShare reconstructs a particular share of a party
- *
- * Both methods do effectively the same, and the only difference is how the
- * index passed to them is interpreted. For the first, the index is interpreted
- * as-is. I.e., if we pass an index of 3, then it will recover the point
- * \f$f(3)\f$, where \f$f\f$ is the polynomial implied by the shares we provide.
- * For the latter method, passing an index of \f$i\f$ implies that we wish to
- * interpolate the point \f$f(i + 1)\f$. This is because we assume parties are
- * indexed from 0, but their evaluation indices are indexed from 1 (because the
- * 0'th index is reserved for the secret).
+ * @tparam T the share type.
+ * @see SecurityLevel
*/
template
-class Reconstructor {
+class ShamirSSFactory {
public:
/**
- * @brief Create a new Reconstructor instance.
- *
- * This creator method creates a new instance and pre-computes the
- * coefficients needed to recover a secret from a set of shamir shares. The
- * SecurityLevel given is used to infer the default reconstruction method:
- *
- * - SecurityLevel::PASSIVE: use passive reconstruction. I.e., reconstruct
- * the value we ask for, using the minimal amount of \f$t + 1\f$ shares
- * needed.
- *
- * - SecurityLevel::DETECT: use the first \f$t+1\f$ shares to reconstruct
- * the remaning \f$t\f$ shares. If all of these check out, then reconstruct
- * the value we seek.
- *
- * - SecurityLevel::CORRECT: use an error correcting
- * algorithm to recover the polynomial \f$f\f$ from \f$3t+1\f$ shares,
- * assuming that at most \f$t\f$ shares are faulty.
- *
- *
+ * @brief Create a new object for working with Shamir secret-shares.
* @param threshold the threshold \f$t\f$
+ * @param prg the prg
* @param default_security_level the default security level
*/
- static Reconstructor Create(std::size_t threshold,
- SecurityLevel default_security_level);
+ static ShamirSSFactory Create(std::size_t threshold,
+ PRG& prg,
+ SecurityLevel default_security_level);
/**
- * @brief Interpolate a set of shares according to a given security level
+ * @brief Create a Shamir secret-sharing of a secret.
+ * @param secret the secret
+ * @param number_of_shares the number of shares to generate. If empty, then
+ * the number of shares is determined by GetDefaultNumberOfShares
+ * @return a vector of shares.
+ */
+ Vec Share(const T& secret,
+ std::optional number_of_shares = {});
+
+ /**
+ * @brief Recover a secret from a vector of secret-shares.
+ *
+ * This method technically recovers a particular coefficient in the original
+ * polynomial used to generate the input shares. By default, the coefficient
+ * to recover is the constant term, which is where the secret lies. However,
+ * this method can also be used to recover the share of another party
+ * (although, \ref RecoverShare has a clearer syntax).
+ *
* @param shares the shares
- * @param security_level the security level
+ * @param security_level the security level to use
* @param index the index to interpolate. Defaults to 0
+ * @return the recovered secret.
+ * @throws std::invalid_argument if not enough shares is given for the
+ * provided security level.
+ * @throws std::logic_error if reconstruction fails (depends on the security
+ * level).
*/
- T Reconstruct(const Vec& shares, SecurityLevel security_level,
- int index = 0) const;
+ T Recover(const Vec& shares,
+ SecurityLevel security_level,
+ int index = 0) const;
/**
- * @brief Interpolate a set of shares with the default security level
+ * @brief Recover a secret from a vector of secret-shares.
+ *
+ * This works like \ref Recover with the security level set to the one
+ * passed during creation of the factory.
+ *
* @param shares the shares
* @param index the index to interpolate. Defaults to 0
*/
- T Reconstruct(const Vec& shares, int index = 0) const {
- return Reconstruct(shares, mSecurityLevel, index);
+ T Recover(const Vec& shares, int index = 0) const {
+ return Recover(shares, mSecurityLevel, index);
};
/**
- * @brief Reconstruct the share of a party.
+ * @brief Recover the share of a particular party.
+ *
+ * Because parties' indices are counted from 1 when creating the
+ * secret-shares, this method is simply a shorthand for
+ * Reconstruct(shares, level, party_index + 1).
+ *
+ * @param shares the shares
+ * @param level the security level to use
+ * @param party_index the index of a party
+ * @return the recovered secret.
*/
- T ReconstructShare(const Vec& shares, SecurityLevel level,
- int party_index = 0) const {
- return Reconstruct(shares, level, party_index + 1);
+ T RecoverShare(const Vec& shares,
+ SecurityLevel level,
+ int party_index = 0) const {
+ return Recover(shares, level, party_index + 1);
};
/**
- * @brief Reconstruct the share of a party.
+ * @brief Recover the share of a party.
*/
- T ReconstructShare(const Vec& shares, int party_index = 0) const {
- return Reconstruct(shares, party_index + 1);
+ T RecoverShare(const Vec& shares, int party_index = 0) const {
+ return RecoverShare(shares, mSecurityLevel, party_index);
};
/**
@@ -247,28 +281,48 @@ class Reconstructor {
return mLagrangeCoeff[index];
};
+ /**
+ * @brief Get the number of shares to create based on set security level
+ *
+ * The returned number is determined based on the \ref scl::SecurityLevel
+ * provided during construction. Given a threshold \f$t\f$, then the number of
+ * shares is computed based on the following rules:
+ */
+ std::size_t GetDefaultNumberOfShares() const {
+ switch (mSecurityLevel) {
+ case SecurityLevel::PASSIVE:
+ return mThreshold + 1;
+ case SecurityLevel::DETECT:
+ return 2 * mThreshold + 1;
+ default: // SecurityLevel::ROBUST:
+ return 3 * mThreshold + 1;
+ }
+ };
+
private:
- Reconstructor(std::size_t threshold, SecurityLevel security_level)
- : mThreshold(threshold), mSecurityLevel(security_level){};
+ ShamirSSFactory(std::size_t threshold, PRG& prg, SecurityLevel security_level)
+ : mThreshold(threshold), mPrg(prg), mSecurityLevel(security_level){};
void ComputeLagrangeCoefficients(int index) const;
std::size_t mThreshold;
+ PRG mPrg;
SecurityLevel mSecurityLevel;
mutable std::unordered_map> mLagrangeCoeff;
};
template
-Reconstructor Reconstructor::Create(
- std::size_t threshold, SecurityLevel default_security_level) {
- Reconstructor intr(threshold, default_security_level);
- intr.ComputeLagrangeCoefficients(0);
- return intr;
+ShamirSSFactory ShamirSSFactory::Create(
+ std::size_t threshold, PRG& prg, SecurityLevel default_security_level) {
+ ShamirSSFactory ssf(threshold, prg, default_security_level);
+ ssf.ComputeLagrangeCoefficients(0);
+ return ssf;
}
template
-T Reconstructor::Reconstruct(const Vec& shares,
- SecurityLevel security_level, int index) const {
+T ShamirSSFactory::Recover(const Vec& shares,
+ SecurityLevel security_level,
+ int index) const {
if (security_level == SecurityLevel::CORRECT) {
// TODO(anders): Currently only supports index = 0
return details::ReconstructShamirRobust(shares, mThreshold);
@@ -302,7 +356,7 @@ T Reconstructor::Reconstruct(const Vec& shares,
}
template
-void Reconstructor::ComputeLagrangeCoefficients(int index) const {
+void ShamirSSFactory::ComputeLagrangeCoefficients(int index) const {
Vec coeff(mThreshold + 1);
const auto x = T(index);
for (std::size_t j = 0; j <= mThreshold; ++j) {
@@ -319,73 +373,31 @@ void Reconstructor::ComputeLagrangeCoefficients(int index) const {
mLagrangeCoeff[index] = coeff;
}
+namespace details {
+
/**
- * @brief A factory object for creating Shamir secret shares.
- * @tparam T the finite field to use.
+ * @brief Create a polynomial for Shamir SS.
+ * @param secret the secret
+ * @param prg a PRG to use for generating the coefficients
+ * @param degree the degree of the polynomial
+ * @return a degree \p degree polynomial.
*/
template
-class ShamirSSFactory {
- public:
- /**
- * @brief Create a new Shamir secret share factory
- * @param t the privacy threshold to use
- * @param prg the PRG to use when generating new shares
- * @param security_level the SecurityLevel to use. Default to
- * SecurityLevel::DETECT
- */
- ShamirSSFactory(std::size_t t, PRG& prg,
- SecurityLevel security_level = SecurityLevel::DETECT)
- : mThreshold(t), mPrg(prg), mDefaultSecurityLevel(security_level){};
-
- /**
- * @brief Create a new set of Shamir secret shares of a secret
- * @param secret the secret
- * @param number_of_shares the number of shares to generate. If empty, then
- * the number of shares is determined by GetDefaultNumberOfShares
- * @return a vector of shares.
- */
- Vec Share(const T& secret,
- std::optional number_of_shares = {});
-
- /**
- * @brief Get the number of shares to create based on set security level
- *
- * The returned number is determined based on the \ref SecurityLevel
- * provided during construction. Given a threshold \f$t\f$, then the number of
- * shares is computed based on the following rules:
- */
- std::size_t GetDefaultNumberOfShares() const {
- switch (mDefaultSecurityLevel) {
- case SecurityLevel::PASSIVE:
- return mThreshold + 1;
- case SecurityLevel::DETECT:
- return 2 * mThreshold + 1;
- default: // SecurityLevel::ROBUST:
- return 3 * mThreshold + 1;
- }
- };
-
- /**
- * @brief Get an scl::Interpolator suitable for this factory.
- */
- Reconstructor GetInterpolator() const {
- return Reconstructor::Create(mThreshold, mDefaultSecurityLevel);
- };
+Polynomial CreateSharePolynomial(const T& secret,
+ PRG& prg,
+ std::size_t degree) {
+ auto coeff = Vec::PartialRandom(
+ degree + 1, [](std::size_t i) { return i > 0; }, prg);
+ coeff[0] = secret;
+ return Polynomial::Create(coeff);
+}
- private:
- std::size_t mThreshold;
- PRG mPrg;
- SecurityLevel mDefaultSecurityLevel;
-};
+} // namespace details
template
Vec ShamirSSFactory::Share(const T& secret,
std::optional number_of_shares) {
- auto coeff = Vec::PartialRandom(
- mThreshold + 1, [](std::size_t i) { return i > 0; }, mPrg);
- coeff[0] = secret;
- auto p = details::Polynomial::Create(coeff);
-
+ auto p = details::CreateSharePolynomial(secret, mPrg, mThreshold);
auto n = number_of_shares.value_or(GetDefaultNumberOfShares());
std::vector shares;
shares.reserve(n);
@@ -395,7 +407,6 @@ Vec ShamirSSFactory::Share(const T& secret,
return Vec(shares);
}
-} // namespace details
} // namespace scl
#endif // SCL_SS_SHAMIR_H
diff --git a/include/scl/math/str.h b/include/scl/util/str.h
similarity index 84%
rename from include/scl/math/str.h
rename to include/scl/util/str.h
index f337729..2e9a73e 100644
--- a/include/scl/math/str.h
+++ b/include/scl/util/str.h
@@ -18,8 +18,8 @@
* along with this program. If not, see .
*/
-#ifndef SCL_MATH_STR_H
-#define SCL_MATH_STR_H
+#ifndef SCL_UTIL_STR_H
+#define SCL_UTIL_STR_H
#include
#include
@@ -77,6 +77,22 @@ std::string ToHexString(const T& v) {
return ss.str();
}
+/**
+ * @brief Convert a list of bytes to a string.
+ * @param begin the start of an iterator
+ * @param end the end of an iterator
+ * @return a hex representation of the digest.
+ */
+template
+std::string ToHexString(It begin, It end) {
+ std::stringstream ss;
+ ss << std::setfill('0') << std::hex;
+ while (begin != end) {
+ ss << std::setw(2) << static_cast(*begin++);
+ }
+ return ss.str();
+}
+
/**
* @brief ToHexString specialization for __uint128_t.
*/
@@ -86,4 +102,4 @@ std::string ToHexString(const __uint128_t& v);
} // namespace details
} // namespace scl
-#endif // SCL_MATH_STR_H
+#endif // SCL_UTIL_STR_H
diff --git a/scripts/check_formatting.sh b/scripts/check_formatting.sh
index 0a9aed7..0647985 100755
--- a/scripts/check_formatting.sh
+++ b/scripts/check_formatting.sh
@@ -1,3 +1,3 @@
#!/usr/bin/bash
-find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n --style=Google {} \;
+find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n {} \;
diff --git a/src/scl/math/mersenne127.cc b/src/scl/math/mersenne127.cc
index d024a5a..5cacb14 100644
--- a/src/scl/math/mersenne127.cc
+++ b/src/scl/math/mersenne127.cc
@@ -25,7 +25,7 @@
#include "./ops_small_fp.h"
#include "scl/math/ff_ops.h"
-#include "scl/math/str.h"
+#include "scl/util/str.h"
using u64 = std::uint64_t;
using u128 = __uint128_t;
@@ -111,6 +111,12 @@ void scl::details::FieldFromBytes(u128& dest,
dest = dest % p;
}
+template <>
+void scl::details::FieldToBytes(unsigned char* dest,
+ const u128& src) {
+ std::memcpy(dest, &src, sizeof(u128));
+}
+
template <>
std::string scl::details::FieldToString(const u128& in) {
return ToHexString(in);
diff --git a/src/scl/math/mersenne61.cc b/src/scl/math/mersenne61.cc
index 2866028..f916e0f 100644
--- a/src/scl/math/mersenne61.cc
+++ b/src/scl/math/mersenne61.cc
@@ -26,7 +26,7 @@
#include "./ops_small_fp.h"
#include "scl/math/ff_ops.h"
-#include "scl/math/str.h"
+#include "scl/util/str.h"
using u64 = std::uint64_t;
using u128 = __uint128_t;
@@ -84,6 +84,12 @@ void scl::details::FieldFromBytes(u64& dest,
dest = dest % p;
}
+template <>
+void scl::details::FieldToBytes(unsigned char* dest,
+ const u64& src) {
+ std::memcpy(dest, &src, sizeof(u64));
+}
+
template <>
std::string scl::details::FieldToString(const u64& in) {
return ToHexString(in);
diff --git a/src/scl/math/number.cc b/src/scl/math/number.cc
index 0134768..148b110 100644
--- a/src/scl/math/number.cc
+++ b/src/scl/math/number.cc
@@ -25,7 +25,9 @@
#include
#include
-scl::Number::Number() { mpz_init(mValue); }
+scl::Number::Number() {
+ mpz_init(mValue);
+}
scl::Number::Number(const Number& number) : Number() {
mpz_set(mValue, number.mValue);
@@ -35,7 +37,9 @@ scl::Number::Number(Number&& number) noexcept : Number() {
mpz_set(mValue, number.mValue);
}
-scl::Number::~Number() { mpz_clear(mValue); }
+scl::Number::~Number() {
+ mpz_clear(mValue);
+}
scl::Number scl::Number::Random(std::size_t bits, PRG& prg) {
auto len = (bits - 1) / 8 + 2;
@@ -59,7 +63,9 @@ scl::Number scl::Number::FromString(const std::string& str) {
return num;
} // LCOV_EXCL_LINE
-scl::Number::Number(int value) : Number() { mpz_set_si(mValue, value); }
+scl::Number::Number(int value) : Number() {
+ mpz_set_si(mValue, value);
+}
scl::Number scl::Number::operator+(const Number& number) const {
scl::Number sum;
@@ -139,7 +145,9 @@ int scl::Number::Compare(const Number& number) const {
return mpz_cmp(mValue, number.mValue);
}
-std::size_t scl::Number::BitSize() const { return mpz_sizeinbase(mValue, 2); }
+std::size_t scl::Number::BitSize() const {
+ return mpz_sizeinbase(mValue, 2);
+}
bool scl::Number::TestBit(std::size_t index) const {
return mpz_tstbit(mValue, index);
diff --git a/src/scl/math/ops_gmp_ff.cc b/src/scl/math/ops_gmp_ff.cc
index 8cc3d18..7ef0aba 100644
--- a/src/scl/math/ops_gmp_ff.cc
+++ b/src/scl/math/ops_gmp_ff.cc
@@ -20,16 +20,7 @@
#include "./ops_gmp_ff.h"
-void scl::details::ReadLimb(mp_limb_t &lmb, const unsigned char *bytes,
- std::size_t bits_per_limbs) {
- std::size_t c = 0;
- lmb = 0;
- for (std::size_t i = 0; i < bits_per_limbs; i += 8) {
- lmb |= static_cast(bytes[c++]) << i;
- }
-}
-
-std::size_t scl::details::FindFirstNonZero(const std::string &s) {
+std::size_t scl::details::FindFirstNonZero(const std::string& s) {
int n = 0;
for (const auto c : s) {
if (c != '0') {
diff --git a/src/scl/math/ops_gmp_ff.h b/src/scl/math/ops_gmp_ff.h
index 704d14a..d712041 100644
--- a/src/scl/math/ops_gmp_ff.h
+++ b/src/scl/math/ops_gmp_ff.h
@@ -21,14 +21,15 @@
#ifndef SCL_MATH_OPS_GMP_FF_H
#define SCL_MATH_OPS_GMP_FF_H
-#include
-
#include
#include
+#include
#include
#include
-#include "scl/math/str.h"
+#include
+
+#include "scl/util/str.h"
namespace scl {
namespace details {
@@ -138,7 +139,9 @@ void ModSub(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod) {
*/
template
void ModNeg(mp_limb_t* out, const mp_limb_t* mod) {
- mpn_sub_n(out, mod, out, N);
+ mp_limb_t t[N] = {0};
+ ModSub(t, out, mod);
+ SCL_COPY(out, t, N);
}
/**
@@ -150,7 +153,9 @@ void ModNeg(mp_limb_t* out, const mp_limb_t* mod) {
* @see Redc
*/
template
-void ModMul(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod,
+void ModMul(mp_limb_t* out,
+ const mp_limb_t* op,
+ const mp_limb_t* mod,
const mp_limb_t* np) {
mp_limb_t res[2 * N];
mpn_mul_n(res, out, op, N);
@@ -166,7 +171,9 @@ void ModMul(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod,
* @param np a constant used for montgomery reduction
*/
template
-void ModSqr(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod,
+void ModSqr(mp_limb_t* out,
+ const mp_limb_t* op,
+ const mp_limb_t* mod,
const mp_limb_t* np) {
mp_limb_t res[2 * N];
mpn_sqr(res, op, N);
@@ -196,8 +203,11 @@ inline bool TestBit(const mp_limb_t* v, std::size_t pos) {
* @param np a constant used for montgomery reduction
*/
template
-void ModExp(mp_limb_t* out, const mp_limb_t* x, const mp_limb_t* e,
- const mp_limb_t* mod, const mp_limb_t* np) {
+void ModExp(mp_limb_t* out,
+ const mp_limb_t* x,
+ const mp_limb_t* e,
+ const mp_limb_t* mod,
+ const mp_limb_t* np) {
auto n = mpn_sizeinbase(e, N, 2);
for (std::size_t i = n; i-- > 0;) {
ModSqr(out, out, mod, np);
@@ -208,8 +218,11 @@ void ModExp(mp_limb_t* out, const mp_limb_t* x, const mp_limb_t* e,
}
template
-void ModInvFermat(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod,
- const mp_limb_t* mod_minus_2, const mp_limb_t* np) {
+void ModInvFermat(mp_limb_t* out,
+ const mp_limb_t* op,
+ const mp_limb_t* mod,
+ const mp_limb_t* mod_minus_2,
+ const mp_limb_t* np) {
if (mpn_zero_p(op, N)) {
throw std::invalid_argument("0 not invertible modulo prime");
}
@@ -226,16 +239,34 @@ int CompareValues(const mp_limb_t* lhs, const mp_limb_t* rhs) {
return mpn_cmp(lhs, rhs, N);
}
-void ReadLimb(mp_limb_t& lmb, const unsigned char* bytes,
- std::size_t bits_per_limb);
+template
+void ValueFromBytes(mp_limb_t* out,
+ const unsigned char* src,
+ const mp_limb_t* mod) {
+ for (int i = N - 1; i >= 0; --i) {
+ for (int j = BYTES_PER_LIMB - 1; j >= 0; --j) {
+ out[i] |= static_cast(*src++) << (j * 8);
+ }
+ }
+
+ ToMonty(out, mod);
+}
-/**
- * @brief Read a value from a byte array.
- */
template
-void ValueFromBytes(mp_limb_t* out, const unsigned char* src) {
- for (std::size_t i = 0; i < N; ++i) {
- ReadLimb(out[i], src + i * BYTES_PER_LIMB, BITS_PER_LIMB);
+void ValueToBytes(unsigned char* dest,
+ const mp_limb_t* src,
+ const mp_limb_t* mod,
+ const mp_limb_t* np) {
+ mp_limb_t padded[2 * N] = {0};
+ SCL_COPY(padded, src, N);
+ Redc(padded, mod, np);
+
+ std::size_t c = 0;
+ for (int i = N - 1; i >= 0; --i) {
+ const auto v = padded[i];
+ for (int j = BYTES_PER_LIMB - 1; j >= 0; --j) {
+ dest[c++] = v >> (j * 8);
+ }
}
}
@@ -245,7 +276,8 @@ std::size_t FindFirstNonZero(const std::string& s);
* @brief Print a value.
*/
template
-std::string ToString(const mp_limb_t* val, const mp_limb_t* mod,
+std::string ToString(const mp_limb_t* val,
+ const mp_limb_t* mod,
const mp_limb_t* np) {
mp_limb_t padded[2 * N] = {0};
SCL_COPY(padded, val, N);
diff --git a/src/scl/math/secp256k1_curve.cc b/src/scl/math/secp256k1_curve.cc
index afa5b9f..522ed28 100644
--- a/src/scl/math/secp256k1_curve.cc
+++ b/src/scl/math/secp256k1_curve.cc
@@ -19,6 +19,7 @@
*/
#include
+#include
#include
#include "./secp256k1_extras.h"
@@ -57,7 +58,8 @@ bool Valid(const Field& x, const Field& y) {
} // namespace
template <>
-void scl::details::CurveSetAffine(Point& out, const Field& x,
+void scl::details::CurveSetAffine(Point& out,
+ const Field& x,
const Field& y) {
if (Valid(x, y)) {
out = {x, y, Field(1)};
@@ -226,11 +228,16 @@ void scl::details::CurveScalarMultiply(Point& out,
}
}
-#define COMPRESSED_FLAG 0x04
+// Flag indicating that the point was serialized as an (X, Y) pair
+#define FULL_POINT_FLAG 0x04
+// Flag indicating that the serialized point was the point at infinity. If the
+// point was also serialized as a FULL_POINT, then we write the pair (0, 0).
#define POINT_AT_INFINITY_FLAG 0x02
+// Flag indicating which of Y or -Y to select in case we serialize the point in
+// compressed form.
#define SELECT_SMALLER_FLAG 0x01
-#define IS_COMPRESSED(flags) ((flags)&COMPRESSED_FLAG)
+#define IS_FULL_POINT(flags) ((flags)&FULL_POINT_FLAG)
#define IS_POINT_AT_INFINITY(flags) ((flags)&POINT_AT_INFINITY_FLAG)
#define SELECT_SMALLER(flags) ((flags)&SELECT_SMALLER_FLAG)
@@ -258,7 +265,11 @@ void scl::details::CurveFromBytes(Point& out, const unsigned char* src) {
// send the point-at-infinity.
CurveSetPointAtInfinity(out);
} else {
- if (IS_COMPRESSED(flags)) {
+ if (IS_FULL_POINT(flags)) {
+ out[0] = Field::Read(src + 1);
+ out[1] = Field::Read(src + 1 + Field::ByteSize());
+ out[2] = Field::One();
+ } else {
Field x = Field::Read(src + 1);
out[0] = x;
@@ -274,27 +285,24 @@ void scl::details::CurveFromBytes(Point& out, const unsigned char* src) {
} else {
out[1] = select_smaller == 0 ? y : yn;
}
- } else {
- out[0] = Field::Read(src + 1);
- out[1] = Field::Read(src + 1 + Field::ByteSize());
- out[2] = Field::One();
}
}
}
-#define MARK_COMPRESSED(buf) (*(buf) |= COMPRESSED_FLAG)
+#define MARK_FULL_POINT(buf) (*(buf) |= FULL_POINT_FLAG)
#define MARK_POINT_AT_INFINITY(buf) (*(buf) |= POINT_AT_INFINITY_FLAG)
#define MARK_SELECT_SMALLER(buf) (*(buf) |= SELECT_SMALLER_FLAG)
template <>
-void scl::details::CurveToBytes(unsigned char* dest, const Point& in,
+void scl::details::CurveToBytes(unsigned char* dest,
+ const Point& in,
bool compress) {
// Make sure flag byte is zeroed.
*dest = 0;
- // if point is compressed, mark it as such.
- if (compress) {
- MARK_COMPRESSED(dest);
+ // if point is un-compressed, mark it as such.
+ if (!compress) {
+ MARK_FULL_POINT(dest);
}
if (CurveIsPointAtInfinity(in)) {
diff --git a/src/scl/math/secp256k1_field.cc b/src/scl/math/secp256k1_field.cc
index c78ddcd..5d3323c 100644
--- a/src/scl/math/secp256k1_field.cc
+++ b/src/scl/math/secp256k1_field.cc
@@ -101,12 +101,12 @@ bool scl::details::FieldEqual(const Elem& in1, const Elem& in2) {
template <>
void scl::details::FieldFromBytes(Elem& dest, const unsigned char* src) {
- ValueFromBytes(PTR(dest), src);
+ ValueFromBytes(PTR(dest), src, kPrime);
}
template <>
-std::string scl::details::FieldToString(const Elem& in) {
- return ToString(PTR(in), kPrime, kMontyN);
+void scl::details::FieldToBytes(unsigned char* dest, const Elem& src) {
+ ValueToBytes(dest, PTR(src), kPrime, kMontyN);
}
template <>
@@ -115,6 +115,11 @@ void scl::details::FieldFromString(Elem& out, const std::string& src) {
FromString(PTR(out), kPrime, src);
}
+template <>
+std::string scl::details::FieldToString(const Elem& in) {
+ return ToString(PTR(in), kPrime, kMontyN);
+}
+
bool scl::SCL_FF_Extras::IsSmaller(
const scl::FF& lhs,
const scl::FF& rhs) {
diff --git a/src/scl/math/secp256k1_order.cc b/src/scl/math/secp256k1_order.cc
index 8a78a93..c5abbac 100644
--- a/src/scl/math/secp256k1_order.cc
+++ b/src/scl/math/secp256k1_order.cc
@@ -18,11 +18,11 @@
* along with this program. If not, see .
*/
-#include
-
#include
#include
+#include
+
#include "./ops_gmp_ff.h"
#include "./secp256k1_extras.h"
#include "scl/math/curves/secp256k1.h"
@@ -101,7 +101,12 @@ bool scl::details::FieldEqual(const Elem& in1, const Elem& in2) {
template <>
void scl::details::FieldFromBytes(Elem& dest, const unsigned char* src) {
- ValueFromBytes(PTR(dest), src);
+ ValueFromBytes(PTR(dest), src, kPrime);
+}
+
+template <>
+void scl::details::FieldToBytes(unsigned char* dest, const Elem& src) {
+ ValueToBytes(dest, PTR(src), kPrime, kMontyN);
}
template <>
diff --git a/src/scl/net/config.cc b/src/scl/net/config.cc
index 356752e..36e8364 100644
--- a/src/scl/net/config.cc
+++ b/src/scl/net/config.cc
@@ -73,7 +73,8 @@ scl::NetworkConfig scl::NetworkConfig::Load(int id,
return NetworkConfig(id, info);
}
-scl::NetworkConfig scl::NetworkConfig::Localhost(int id, int size,
+scl::NetworkConfig scl::NetworkConfig::Localhost(int id,
+ int size,
int port_base) {
ValidateIdAndSize(id, size);
diff --git a/src/scl/net/mem_channel.cc b/src/scl/net/mem_channel.cc
index 72f186c..e2caebf 100644
--- a/src/scl/net/mem_channel.cc
+++ b/src/scl/net/mem_channel.cc
@@ -57,7 +57,8 @@ std::size_t scl::InMemoryChannel::Recv(unsigned char* dst, std::size_t n) {
const auto old_size = mOverflow.size();
mOverflow.reserve(old_size + leftovers);
mOverflow.insert(mOverflow.begin() + DIFF_T(old_size),
- data.begin() + DIFF_T(to_copy), data.end());
+ data.begin() + DIFF_T(to_copy),
+ data.end());
}
}
diff --git a/src/scl/net/tcp_utils.cc b/src/scl/net/tcp_utils.cc
index bec9b62..50a1f55 100644
--- a/src/scl/net/tcp_utils.cc
+++ b/src/scl/net/tcp_utils.cc
@@ -20,10 +20,6 @@
#include "scl/net/tcp_utils.h"
-#include
-#include
-#include
-
#include
#include
#include
@@ -32,6 +28,10 @@
#include
#include
+#include
+#include
+#include
+
int scl::details::CreateServerSocket(int port, int backlog) {
int err;
int ssock = ::socket(AF_INET, SOCK_STREAM, 0);
@@ -117,14 +117,18 @@ int scl::details::ConnectAsClient(const std::string& hostname, int port) {
return sock;
}
-int scl::details::CloseSocket(int socket) { return ::close(socket); }
+int scl::details::CloseSocket(int socket) {
+ return ::close(socket);
+}
-ssize_t scl::details::ReadFromSocket(int socket, unsigned char* dst,
+ssize_t scl::details::ReadFromSocket(int socket,
+ unsigned char* dst,
std::size_t n) {
return ::read(socket, dst, n);
}
-ssize_t scl::details::WriteToSocket(int socket, const unsigned char* src,
+ssize_t scl::details::WriteToSocket(int socket,
+ const unsigned char* src,
std::size_t n) {
return ::write(socket, src, n);
}
diff --git a/src/scl/prg.cc b/src/scl/primitives/prg.cc
similarity index 51%
rename from src/scl/prg.cc
rename to src/scl/primitives/prg.cc
index d08150f..2e4bb19 100644
--- a/src/scl/prg.cc
+++ b/src/scl/primitives/prg.cc
@@ -18,17 +18,19 @@
* along with this program. If not, see .
*/
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
#include
#include
-/* https://github.com/sebastien-riou/aes-brute-force */
+#include
-using byte_t = unsigned char;
-using block_t = __m128i;
+/**
+ * PRG implementation based on AES-CTR with code from
+ * https://github.com/sebastien-riou/aes-brute-force
+ */
-using std::size_t;
+#define BLOCK_SIZE sizeof(__m128i)
#define DO_ENC_BLOCK(m, k) \
do { \
@@ -45,10 +47,12 @@ using std::size_t;
(m) = _mm_aesenclast_si128(m, (k)[10]); \
} while (0)
-#define AES_128_key_exp(k, rcon) \
- aes_128_key_expansion(k, _mm_aeskeygenassist_si128(k, rcon))
+#define AES_128_KEY_EXP(k, rcon) \
+ Aes128KeyExpansion(k, _mm_aeskeygenassist_si128(k, rcon))
-inline static block_t aes_128_key_expansion(block_t key, block_t keygened) {
+namespace {
+
+auto Aes128KeyExpansion(__m128i key, __m128i keygened) {
keygened = _mm_shuffle_epi32(keygened, _MM_SHUFFLE(3, 3, 3, 3));
key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
@@ -56,74 +60,72 @@ inline static block_t aes_128_key_expansion(block_t key, block_t keygened) {
return _mm_xor_si128(key, keygened);
}
-inline static void aes128_load_key(byte_t* enc_key, block_t* key_schedule) {
- key_schedule[0] = _mm_loadu_si128((const block_t*)enc_key);
- key_schedule[1] = AES_128_key_exp(key_schedule[0], 0x01);
- key_schedule[2] = AES_128_key_exp(key_schedule[1], 0x02);
- key_schedule[3] = AES_128_key_exp(key_schedule[2], 0x04);
- key_schedule[4] = AES_128_key_exp(key_schedule[3], 0x08);
- key_schedule[5] = AES_128_key_exp(key_schedule[4], 0x10);
- key_schedule[6] = AES_128_key_exp(key_schedule[5], 0x20);
- key_schedule[7] = AES_128_key_exp(key_schedule[6], 0x40);
- key_schedule[8] = AES_128_key_exp(key_schedule[7], 0x80);
- key_schedule[9] = AES_128_key_exp(key_schedule[8], 0x1B);
- key_schedule[10] = AES_128_key_exp(key_schedule[9], 0x36);
+void Aes128LoadKey(const unsigned char* enc_key, __m128i* key_schedule) {
+ const auto* k = reinterpret_cast(enc_key);
+ key_schedule[0] = _mm_loadu_si128(k);
+ key_schedule[1] = AES_128_KEY_EXP(key_schedule[0], 0x01);
+ key_schedule[2] = AES_128_KEY_EXP(key_schedule[1], 0x02);
+ key_schedule[3] = AES_128_KEY_EXP(key_schedule[2], 0x04);
+ key_schedule[4] = AES_128_KEY_EXP(key_schedule[3], 0x08);
+ key_schedule[5] = AES_128_KEY_EXP(key_schedule[4], 0x10);
+ key_schedule[6] = AES_128_KEY_EXP(key_schedule[5], 0x20);
+ key_schedule[7] = AES_128_KEY_EXP(key_schedule[6], 0x40);
+ key_schedule[8] = AES_128_KEY_EXP(key_schedule[7], 0x80);
+ key_schedule[9] = AES_128_KEY_EXP(key_schedule[8], 0x1B);
+ key_schedule[10] = AES_128_KEY_EXP(key_schedule[9], 0x36);
}
-inline static void aes128_enc(block_t* key_schedule, byte_t* pt, byte_t* ct) {
- block_t m = _mm_loadu_si128((block_t*)pt);
+void Aes128Enc(const __m128i* key_schedule, __m128i m, unsigned char* ct) {
DO_ENC_BLOCK(m, key_schedule);
- _mm_storeu_si128((block_t*)ct, m);
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(ct), m);
+}
+
+auto create_mask(long counter) {
+ return _mm_set_epi64x(PRG_NONCE, counter);
}
-scl::PRG::PRG() { Init(); }
+} // namespace
scl::PRG::PRG(const unsigned char* seed) {
- memcpy(mSeed, seed, SeedSize());
+ if (seed != nullptr) {
+ std::copy(seed, seed + SeedSize(), mSeed.begin());
+ }
Init();
}
-void scl::PRG::Update() { mCounter += 1; }
+void scl::PRG::Update() {
+ mCounter += 1;
+}
-void scl::PRG::Init() { aes128_load_key(mSeed, mState); }
+void scl::PRG::Init() {
+ Aes128LoadKey(mSeed.data(), mState);
+}
void scl::PRG::Reset() {
Init();
mCounter = PRG_INITIAL_COUNTER;
}
-static inline auto create_mask(const long counter) {
- return _mm_set_epi64x(PRG_NONCE, counter);
-}
-
-void scl::PRG::Next(byte_t* dest, size_t nbytes) {
- if (nbytes == 0) {
+void scl::PRG::Next(unsigned char* buffer, size_t n) {
+ if (n == 0) {
return;
}
- size_t nblocks = nbytes / BlockSize();
+ auto nblocks = n / BLOCK_SIZE;
- if ((nbytes % BlockSize()) != 0) {
+ if ((n % BLOCK_SIZE) != 0) {
nblocks++;
}
- block_t mask = create_mask(mCounter);
- byte_t* out = (byte_t*)malloc(nblocks * BlockSize());
- byte_t* p = out;
-
- // LCOV_EXCL_START
- if (out == nullptr) {
- throw std::runtime_error("Could not allocate memory for PRG.");
- }
- // LCOV_EXCL_STOP
-
+ auto mask = create_mask(mCounter);
+ auto out = std::make_unique(nblocks * BLOCK_SIZE);
+ auto* p = out.get();
for (size_t i = 0; i < nblocks; i++) {
- aes128_enc(mState, (byte_t*)(&mask), p);
+ Aes128Enc(mState, mask, p);
Update();
mask = create_mask(mCounter);
- p += BlockSize();
+ p += BLOCK_SIZE;
}
- memcpy(dest, out, nbytes);
- free(out);
+ std::copy(out.get(), out.get() + n, buffer);
}
diff --git a/src/scl/primitives/sha256.cc b/src/scl/primitives/sha256.cc
new file mode 100644
index 0000000..3f3c57a
--- /dev/null
+++ b/src/scl/primitives/sha256.cc
@@ -0,0 +1,171 @@
+/**
+ * @file sha256.cc
+ *
+ * SCL --- Secure Computation Library
+ * Copyright (C) 2022 Anders Dalskov
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+#include "scl/primitives/sha256.h"
+
+#include
+#include
+
+/**
+ * SHA-256 implementation based on https://github.com/System-Glitch/SHA256.
+ */
+
+namespace {
+
+auto RotR(uint32_t x, unsigned n) {
+ return (x >> n) | (x << (32 - n));
+}
+
+auto Sig0(uint32_t x) {
+ return RotR(x, 7) ^ RotR(x, 18) ^ (x >> 3);
+}
+
+auto Sig1(uint32_t x) {
+ return RotR(x, 17) ^ RotR(x, 19) ^ (x >> 10);
+}
+
+auto Split(std::array& chunk) {
+ std::array split;
+ for (std::size_t i = 0, j = 0; i < 16; ++i, j += 4) {
+ split[i] = (chunk[j] << 24) //
+ | (chunk[j + 1] << 16) //
+ | (chunk[j + 2] << 8) //
+ | (chunk[j + 3]);
+ }
+
+ for (std::size_t i = 16; i < 64; ++i) {
+ split[i] = Sig1(split[i - 2]) + Sig0(split[i - 15]);
+ split[i] += split[i - 7] + split[i - 16];
+ }
+
+ return split;
+}
+
+auto Majority(uint32_t x, uint32_t y, uint32_t z) {
+ return (x & (y | z)) | (y & z);
+}
+
+auto Choose(uint32_t x, uint32_t y, uint32_t z) {
+ return (x & y) ^ (~x & z);
+}
+
+} // namespace
+
+void scl::details::Sha256::Transform() {
+ // round constants.
+ static constexpr std::array k = {
+ 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1,
+ 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
+ 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786,
+ 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
+ 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147,
+ 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
+ 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
+ 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
+ 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a,
+ 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
+ 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2};
+
+ const auto m = Split(mChunk);
+ auto s = mState;
+
+ for (std::size_t i = 0; i < 64; ++i) {
+ const auto maj = Majority(s[0], s[1], s[2]);
+ const auto chs = Choose(s[4], s[5], s[6]);
+
+ const auto xor_a = RotR(s[0], 2) ^ RotR(s[0], 13) ^ RotR(s[0], 22);
+ const auto xor_e = RotR(s[4], 6) ^ RotR(s[4], 11) ^ RotR(s[4], 25);
+
+ const auto sum = m[i] + k[i] + s[7] + chs + xor_e;
+
+ const auto new_a = xor_a + maj + sum;
+ const auto new_e = s[3] + sum;
+
+ s[7] = s[6];
+ s[6] = s[5];
+ s[5] = s[4];
+ s[4] = new_e;
+ s[3] = s[2];
+ s[2] = s[1];
+ s[1] = s[0];
+ s[0] = new_a;
+ }
+
+ for (std::size_t i = 0; i < 8; ++i) {
+ mState[i] += s[i];
+ }
+}
+
+void scl::details::Sha256::Pad() {
+ auto i = mChunkPos;
+ const auto end = mChunkPos < 56U ? 56U : 64U;
+
+ mChunk[i++] = 0x80;
+ while (i < end) {
+ mChunk[i++] = 0;
+ }
+
+ if (mChunkPos >= 56) {
+ Transform();
+ std::fill(mChunk.begin(), mChunk.begin() + 56, 0);
+ }
+
+ mTotalLen += static_cast(mChunkPos) * 8;
+
+ mChunk[63] = mTotalLen;
+ mChunk[62] = mTotalLen >> 8;
+ mChunk[61] = mTotalLen >> 16;
+ mChunk[60] = mTotalLen >> 24;
+ mChunk[59] = mTotalLen >> 32;
+ mChunk[58] = mTotalLen >> 40;
+ mChunk[57] = mTotalLen >> 48;
+ mChunk[56] = mTotalLen >> 56;
+
+ Transform();
+}
+
+scl::details::Sha256::DigestType scl::details::Sha256::WriteDigest() {
+ scl::details::Sha256::DigestType digest;
+
+ for (std::size_t i = 0; i < 4; ++i) {
+ for (std::size_t j = 0; j < 8; ++j) {
+ digest[i + (j * 4)] = (mState[j] >> (24 - i * 8)) & 0xFF;
+ }
+ }
+
+ return digest;
+}
+
+void scl::details::Sha256::Hash(const unsigned char* bytes,
+ std::size_t nbytes) {
+ for (std::size_t i = 0; i < nbytes; ++i) {
+ mChunk[mChunkPos++] = bytes[i];
+ if (mChunkPos == 64) {
+ Transform();
+ mTotalLen += 512;
+ mChunkPos = 0;
+ }
+ }
+}
+
+scl::details::Sha256::DigestType scl::details::Sha256::Write() {
+ Pad();
+ return WriteDigest();
+}
diff --git a/src/scl/hash.cc b/src/scl/primitives/sha3.cc
similarity index 55%
rename from src/scl/hash.cc
rename to src/scl/primitives/sha3.cc
index 6c34d9b..d6bbd4b 100644
--- a/src/scl/hash.cc
+++ b/src/scl/primitives/sha3.cc
@@ -1,5 +1,5 @@
/**
- * @file hash.cc
+ * @file sha3.cc
*
* SCL --- Secure Computation Library
* Copyright (C) 2022 Anders Dalskov
@@ -18,13 +18,35 @@
* along with this program. If not, see .
*/
-#include "scl/hash.h"
+#include "scl/primitives/sha3.h"
-static inline uint64_t rotl64(uint64_t x, uint64_t y) {
+namespace {
+
+const uint64_t keccakf_rndc[24] = {
+ 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL,
+ 0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL,
+ 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL,
+ 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL,
+ 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL,
+ 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL,
+ 0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL,
+ 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL};
+
+const unsigned int keccakf_rotc[24] = {1, 3, 6, 10, 15, 21, 28, 36,
+ 45, 55, 2, 14, 27, 41, 56, 8,
+ 25, 43, 62, 18, 39, 61, 20, 44};
+
+const unsigned int keccakf_piln[24] = {10, 7, 11, 17, 18, 3, 5, 16,
+ 8, 21, 24, 4, 15, 23, 19, 13,
+ 12, 2, 20, 14, 22, 9, 6, 1};
+
+uint64_t RotLeft64(uint64_t x, uint64_t y) {
return (x << y) | (x >> ((sizeof(uint64_t) * 8) - y));
}
-void scl::Keccakf(uint64_t state[25]) {
+} // namespace
+
+void scl::details::Keccakf(uint64_t state[25]) {
uint64_t t;
uint64_t bc[5];
@@ -35,7 +57,7 @@ void scl::Keccakf(uint64_t state[25]) {
}
for (std::size_t i = 0; i < 5; ++i) {
- t = bc[(i + 4) % 5] ^ rotl64(bc[(i + 1) % 5], 1);
+ t = bc[(i + 4) % 5] ^ RotLeft64(bc[(i + 1) % 5], 1);
for (std::size_t j = 0; j < 25; j += 5) {
state[j + i] ^= t;
}
@@ -45,7 +67,7 @@ void scl::Keccakf(uint64_t state[25]) {
for (std::size_t i = 0; i < 24; ++i) {
const uint64_t v = keccakf_piln[i];
bc[0] = state[v];
- state[v] = rotl64(t, keccakf_rotc[i]);
+ state[v] = RotLeft64(t, keccakf_rotc[i]);
t = bc[0];
}
diff --git a/src/scl/math/str.cc b/src/scl/util/str.cc
similarity index 97%
rename from src/scl/math/str.cc
rename to src/scl/util/str.cc
index 32bdb74..cf7a5e3 100644
--- a/src/scl/math/str.cc
+++ b/src/scl/util/str.cc
@@ -18,7 +18,7 @@
* along with this program. If not, see .
*/
-#include "scl/math/str.h"
+#include "scl/util/str.h"
#include
diff --git a/test/scl/gf7.cc b/test/scl/gf7.cc
index 7148f3d..9cd1d7d 100644
--- a/test/scl/gf7.cc
+++ b/test/scl/gf7.cc
@@ -22,7 +22,7 @@
#include "scl/math/ff_ops.h"
-using GF7 = scl::details::GF7;
+using GF7 = scl_tests::GaloisField7;
template <>
void scl::details::FieldConvertIn(unsigned char& out, int v) {
@@ -95,6 +95,12 @@ void scl::details::FieldFromBytes(unsigned char& dest,
dest = dest % 7;
}
+template <>
+void scl::details::FieldToBytes(unsigned char* dest,
+ const unsigned char& src) {
+ *dest = src;
+}
+
template <>
std::string scl::details::FieldToString(const unsigned char& in) {
std::stringstream ss;
diff --git a/test/scl/gf7.h b/test/scl/gf7.h
index ec9e9c9..a4e2a1b 100644
--- a/test/scl/gf7.h
+++ b/test/scl/gf7.h
@@ -24,17 +24,15 @@
#include
#include
-namespace scl {
-namespace details {
+namespace scl_tests {
-struct GF7 {
+struct GaloisField7 {
using ValueType = unsigned char;
constexpr static const char* kName = "GF(7)";
constexpr static const std::size_t kByteSize = 1;
constexpr static const std::size_t kBitSize = 8;
};
-} // namespace details
-} // namespace scl
+} // namespace scl_tests
#endif /* TEST_SCL_GF7_H */
diff --git a/test/scl/math/fields.h b/test/scl/math/fields.h
new file mode 100644
index 0000000..e0277e9
--- /dev/null
+++ b/test/scl/math/fields.h
@@ -0,0 +1,43 @@
+/**
+ * @file fields.h
+ *
+ * SCL --- Secure Computation Library
+ * Copyright (C) 2022 Anders Dalskov
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+#include "../gf7.h"
+#include "scl/math/curves/secp256k1.h"
+#include "scl/math/fp.h"
+
+namespace scl_tests {
+
+using Mersenne61 = scl::Fp<61>;
+using Mersenne127 = scl::Fp<127>;
+using GF7 = scl::FF;
+
+#ifdef SCL_ENABLE_EC_TESTS
+using Secp256k1_Field = scl::FF;
+using Secp256k1_Order = scl::FF;
+#endif
+
+#ifdef SCL_ENABLE_EC_TESTS
+#define FIELD_DEFS \
+ Mersenne61, Mersenne127, GF7, Secp256k1_Field, Secp256k1_Order
+#else
+#define FIELD_DEFS Mersenne61, Mersenne127, GF7
+#endif
+
+} // namespace scl_tests
diff --git a/test/scl/math/test_ff.cc b/test/scl/math/test_ff.cc
index 2f22384..f684cc0 100644
--- a/test/scl/math/test_ff.cc
+++ b/test/scl/math/test_ff.cc
@@ -20,19 +20,12 @@
#include
-#include "../gf7.h"
-#include "scl/math/curves/secp256k1.h"
-#include "scl/math/fp.h"
-#include "scl/prg.h"
+#include "fields.h"
+#include "scl/primitives/prg.h"
-using Mersenne61 = scl::Fp<61>;
-using Mersenne127 = scl::Fp<127>;
-using GF7 = scl::FF;
+using namespace scl_tests;
-#ifdef SCL_ENABLE_EC_TESTS
-using Secp256k1_Field = scl::FF;
-using Secp256k1_Order = scl::FF;
-#endif
+namespace {
template
T RandomNonZero(scl::PRG& prg) {
@@ -60,16 +53,12 @@ GF7 RandomNonZero(scl::PRG& prg) {
return a;
}
-#define REPEAT for (std::size_t i = 0; i < 50; ++i)
+} // namespace
-#ifdef SCL_ENABLE_EC_TESTS
-#define ARG_LIST Mersenne61, Mersenne127, GF7, Secp256k1_Field, Secp256k1_Order
-#else
-#define ARG_LIST Mersenne61, Mersenne127, GF7
-#endif
+#define REPEAT for (std::size_t i = 0; i < 50; ++i)
-TEMPLATE_TEST_CASE("FF", "[math]", ARG_LIST) {
- scl::PRG prg;
+TEMPLATE_TEST_CASE("FF", "[math]", FIELD_DEFS) {
+ auto prg = scl::PRG::Create();
auto zero = TestType();
SECTION("random") {
@@ -106,6 +95,8 @@ TEMPLATE_TEST_CASE("FF", "[math]", ARG_LIST) {
REQUIRE(a == a_negated);
REQUIRE(a - zero == a);
}
+
+ REQUIRE(TestType(0) == -TestType(0));
}
SECTION("subtraction") {
@@ -136,7 +127,8 @@ TEMPLATE_TEST_CASE("FF", "[math]", ARG_LIST) {
SECTION("inverses") {
REQUIRE_THROWS_MATCHES(
- zero.Inverse(), std::logic_error,
+ zero.Inverse(),
+ std::logic_error,
Catch::Matchers::Message("0 not invertible modulo prime"));
REPEAT {
diff --git a/test/scl/math/test_la.cc b/test/scl/math/test_la.cc
index 8234d09..f485c5b 100644
--- a/test/scl/math/test_la.cc
+++ b/test/scl/math/test_la.cc
@@ -25,7 +25,7 @@
#include "scl/math/la.h"
#include "scl/math/mat.h"
-using F = scl::FF;
+using F = scl::FF;
using Mat = scl::Mat;
using Vec = scl::Vec;
@@ -37,10 +37,8 @@ TEST_CASE("LinearAlgebra", "[math]") {
// [1 0 1]
// [0 1 0]
// [0 0 0]
- Mat A = Mat::FromVector(3, 3,
- {one, zero, one, //
- zero, one, zero, //
- zero, zero, zero});
+ Mat A = Mat::FromVector(
+ 3, 3, {one, zero, one, zero, one, zero, zero, zero, zero});
REQUIRE(scl::details::GetPivotInColumn(A, 2) == -1);
REQUIRE(scl::details::GetPivotInColumn(A, 1) == 1);
REQUIRE(scl::details::GetPivotInColumn(A, 0) == 0);
@@ -52,27 +50,27 @@ TEST_CASE("LinearAlgebra", "[math]") {
}
SECTION("FindFirstNonZeroRow") {
- Mat A = Mat::FromVector(3, 3,
- {one, zero, one, //
- zero, one, zero, //
- zero, zero, zero});
+ Mat A = Mat::FromVector(
+ 3, 3, {one, zero, one, zero, one, zero, zero, zero, zero});
REQUIRE(scl::details::FindFirstNonZeroRow(A) == 1);
A(2, 1) = one;
REQUIRE(scl::details::FindFirstNonZeroRow(A) == 2);
}
SECTION("ExtractSolution") {
- Mat A = Mat::FromVector(3, 4,
- {one, zero, zero, F(3), //
- zero, one, zero, F(5), //
- zero, zero, one, F(2)});
+ Mat A = Mat::FromVector(
+ 3,
+ 4,
+ {one, zero, zero, F(3), zero, one, zero, F(5), zero, zero, one, F(2)});
auto x = scl::details::ExtractSolution(A);
REQUIRE(x.Equals(Vec{F(3), F(5), F(2)}));
+ // clang-format off
Mat B = Mat::FromVector(3, 4,
- {F(1), F(3), F(1), F(2), //
- F(0), F(0), F(1), F(4), //
+ {F(1), F(3), F(1), F(2),
+ F(0), F(0), F(1), F(4),
F(0), F(0), F(0), F(0)});
+ // clang-format on
auto y = scl::details::ExtractSolution(B);
REQUIRE(y.Equals(Vec{F(4), F(4), F(0)}));
@@ -84,7 +82,7 @@ TEST_CASE("LinearAlgebra", "[math]") {
SECTION("RandomSolve") {
auto n = 10;
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
Mat A = Mat::Random(n, n, prg);
Vec b = Vec::Random(n, prg);
@@ -99,7 +97,8 @@ TEST_CASE("LinearAlgebra", "[math]") {
Mat A(2, 2);
Vec b(3);
REQUIRE_THROWS_MATCHES(
- scl::details::SolveLinearSystem(x, A, b), std::invalid_argument,
+ scl::details::SolveLinearSystem(x, A, b),
+ std::invalid_argument,
Catch::Matchers::Message("malformed system of equations"));
}
@@ -113,7 +112,7 @@ TEST_CASE("LinearAlgebra", "[math]") {
SECTION("Inverse") {
std::size_t n = 10;
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
Mat A = Mat::Random(n, n, prg);
Mat I = Mat::Identity(n);
diff --git a/test/scl/math/test_mat.cc b/test/scl/math/test_mat.cc
index 4410098..c2cb475 100644
--- a/test/scl/math/test_mat.cc
+++ b/test/scl/math/test_mat.cc
@@ -26,7 +26,9 @@
using F = scl::Fp<61>;
using Mat = scl::Mat;
-inline void Populate(Mat& m, const int* values) {
+namespace {
+
+void Populate(Mat& m, const int* values) {
for (std::size_t i = 0; i < m.Rows(); i++) {
for (std::size_t j = 0; j < m.Cols(); j++) {
m(i, j) = F(values[i * m.Cols() + j]);
@@ -34,6 +36,8 @@ inline void Populate(Mat& m, const int* values) {
}
}
+} // namespace
+
TEST_CASE("Matrix", "[math]") {
Mat m0(2, 2);
int v0[] = {1, 2, 5, 6};
@@ -79,9 +83,11 @@ TEST_CASE("Matrix", "[math]") {
// matrices are 0 initialized, so the above matrices are equal
REQUIRE(a.Equals(b));
- REQUIRE_THROWS_MATCHES(Mat(0, 1), std::invalid_argument,
+ REQUIRE_THROWS_MATCHES(Mat(0, 1),
+ std::invalid_argument,
Catch::Matchers::Message("n or m cannot be 0"));
- REQUIRE_THROWS_MATCHES(Mat(1, 0), std::invalid_argument,
+ REQUIRE_THROWS_MATCHES(Mat(1, 0),
+ std::invalid_argument,
Catch::Matchers::Message("n or m cannot be 0"));
}
@@ -119,7 +125,8 @@ TEST_CASE("Matrix", "[math]") {
SECTION("Incompatible") {
Mat m2(3, 2);
- REQUIRE_THROWS_MATCHES(m2.Add(m0), std::invalid_argument,
+ REQUIRE_THROWS_MATCHES(m2.Add(m0),
+ std::invalid_argument,
Catch::Matchers::Message("incompatible matrices"));
}
@@ -147,7 +154,8 @@ TEST_CASE("Matrix", "[math]") {
REQUIRE(m5.Equals(m4));
REQUIRE_THROWS_MATCHES(
- m3.Multiply(m0), std::invalid_argument,
+ m3.Multiply(m0),
+ std::invalid_argument,
Catch::Matchers::Message("invalid matrix dimensions for multiply"));
}
@@ -181,7 +189,7 @@ TEST_CASE("Matrix", "[math]") {
}
SECTION("Random") {
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
Mat mr = Mat::Random(4, 5, prg);
REQUIRE(mr.Rows() == 4);
REQUIRE(mr.Cols() == 5);
@@ -194,7 +202,7 @@ TEST_CASE("Matrix", "[math]") {
REQUIRE(not_zero);
// check stability
- scl::PRG prg1;
+ auto prg1 = scl::PRG::Create();
Mat mr1 = Mat::Random(4, 5, prg1);
REQUIRE(mr1.Equals(mr));
}
@@ -212,7 +220,7 @@ TEST_CASE("Matrix", "[math]") {
}
SECTION("resize") {
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
Mat m = Mat::Random(2, 4, prg);
auto copy = m;
copy.Resize(1, 8);
@@ -225,12 +233,13 @@ TEST_CASE("Matrix", "[math]") {
}
}
- REQUIRE_THROWS_MATCHES(m.Resize(42, 4), std::invalid_argument,
+ REQUIRE_THROWS_MATCHES(m.Resize(42, 4),
+ std::invalid_argument,
Catch::Matchers::Message("cannot resize matrix"));
}
SECTION("IsSquare") {
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
Mat sq = Mat::Random(2, 2, prg);
REQUIRE(sq.IsSquare());
Mat nsq = Mat::Random(4, 2, prg);
@@ -279,7 +288,8 @@ TEST_CASE("Matrix", "[math]") {
REQUIRE(m1(2, 2) == F(64));
xs.emplace_back(F(55));
- REQUIRE_THROWS_MATCHES(Mat::Vandermonde(3, 3, xs), std::invalid_argument,
+ REQUIRE_THROWS_MATCHES(Mat::Vandermonde(3, 3, xs),
+ std::invalid_argument,
Catch::Matchers::Message("|xs| != number of rows"));
}
}
diff --git a/test/scl/math/test_mersenne127.cc b/test/scl/math/test_mersenne127.cc
index 5adc8d3..dd47263 100644
--- a/test/scl/math/test_mersenne127.cc
+++ b/test/scl/math/test_mersenne127.cc
@@ -19,6 +19,7 @@
*/
#include
+#include
#include "scl/math/fp.h"
@@ -31,7 +32,9 @@ TEST_CASE("Mersenne127", "[math]") {
Field x(0x7b);
Field big = Field::FromString("58797a14d0653d22a05c11c60e1aacf4");
- SECTION("Name") { REQUIRE(std::string(Field::Name()) == "Mersenne127"); }
+ SECTION("Name") {
+ REQUIRE(std::string(Field::Name()) == "Mersenne127");
+ }
SECTION("ToString") {
REQUIRE(zero.ToString() == "0");
diff --git a/test/scl/math/test_mersenne61.cc b/test/scl/math/test_mersenne61.cc
index 6f84d52..c827e09 100644
--- a/test/scl/math/test_mersenne61.cc
+++ b/test/scl/math/test_mersenne61.cc
@@ -19,6 +19,7 @@
*/
#include
+#include
#include "scl/math/fp.h"
@@ -30,7 +31,9 @@ TEST_CASE("Mersenne61", "[math]") {
Field x(0x7b);
Field big(0x41621e);
- SECTION("Name") { REQUIRE(std::string(Field::Name()) == "Mersenne61"); }
+ SECTION("Name") {
+ REQUIRE(std::string(Field::Name()) == "Mersenne61");
+ }
SECTION("ToString") {
REQUIRE(zero.ToString() == "0");
@@ -53,10 +56,12 @@ TEST_CASE("Mersenne61", "[math]") {
}
SECTION("FromString") {
- REQUIRE_THROWS_MATCHES(Field::FromString("012"), std::invalid_argument,
+ REQUIRE_THROWS_MATCHES(Field::FromString("012"),
+ std::invalid_argument,
Catch::Matchers::Message("odd-length hex string"));
REQUIRE_THROWS_MATCHES(
- Field::FromString("1g"), std::invalid_argument,
+ Field::FromString("1g"),
+ std::invalid_argument,
Catch::Matchers::Message("encountered invalid hex character"));
auto y = Field::FromString("7b");
REQUIRE(x == y);
diff --git a/test/scl/math/test_number.cc b/test/scl/math/test_number.cc
index 700ec2d..085f507 100644
--- a/test/scl/math/test_number.cc
+++ b/test/scl/math/test_number.cc
@@ -22,13 +22,13 @@
#include
#include "scl/math/number.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
#define REPEAT_I(I) for (std::size_t i = 0; i < (I); ++i)
#define REPEAT REPEAT_I(50)
TEST_CASE("Number", "[math]") {
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
SECTION("Construction") {
scl::Number n0(27);
REQUIRE(n0.ToString() == "Number{1b}");
@@ -129,7 +129,9 @@ TEST_CASE("Number", "[math]") {
}
}
- SECTION("Negation") { REQUIRE(-scl::Number(1234) == scl::Number(-1234)); }
+ SECTION("Negation") {
+ REQUIRE(-scl::Number(1234) == scl::Number(-1234));
+ }
SECTION("Multiplication") {
scl::Number a(444);
diff --git a/test/scl/math/test_secp256k1.cc b/test/scl/math/test_secp256k1.cc
index d9053b5..9432085 100644
--- a/test/scl/math/test_secp256k1.cc
+++ b/test/scl/math/test_secp256k1.cc
@@ -26,13 +26,15 @@
#include "scl/math/ec_ops.h"
#include "scl/math/fp.h"
#include "scl/math/number.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
using Curve = scl::EC;
using Field = Curve::Field;
TEST_CASE("secp256k1_field", "[math]") {
- SECTION("name") { REQUIRE(std::string(Field::Name()) == "secp256k1_field"); }
+ SECTION("name") {
+ REQUIRE(std::string(Field::Name()) == "secp256k1_field");
+ }
SECTION("Strings") {
REQUIRE(Field(0).ToString() == "0");
@@ -76,7 +78,8 @@ TEST_CASE("secp256k1_field", "[math]") {
REQUIRE(as_affine[1] == y);
REQUIRE_THROWS_MATCHES(
- Curve::FromAffine(Field(0), Field(0)), std::invalid_argument,
+ Curve::FromAffine(Field(0), Field(0)),
+ std::invalid_argument,
Catch::Matchers::Message("provided (x, y) not on curve"));
}
@@ -126,7 +129,8 @@ TEST_CASE("secp256k1_field", "[math]") {
SECTION("inversion") {
REQUIRE_THROWS_MATCHES(
- Field(0).Inverse(), std::invalid_argument,
+ Field(0).Inverse(),
+ std::invalid_argument,
Catch::Matchers::Message("0 not invertible modulo prime"));
Field one(1);
REQUIRE(one * one.Inverse() == one);
@@ -148,13 +152,18 @@ TEST_CASE("secp256k1_field", "[math]") {
using Scalar = scl::FF;
-static Curve RandomPoint(scl::PRG& prg) {
+namespace {
+
+Curve RandomPoint(scl::PRG& prg) {
auto r = scl::Number::Random(100, prg);
return Curve::Generator() * r;
}
+} // namespace
TEST_CASE("secp256k1", "[math]") {
- SECTION("name") { REQUIRE(std::string(Curve::Name()) == "secp256k1"); }
+ SECTION("name") {
+ REQUIRE(std::string(Curve::Name()) == "secp256k1");
+ }
auto ord = scl::Number::FromString(
"fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
@@ -186,7 +195,7 @@ TEST_CASE("secp256k1", "[math]") {
REQUIRE(poi.ToString() == "EC{POINT_AT_INFINITY}");
}
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
SECTION("addition") {
auto a = RandomPoint(prg);
@@ -263,23 +272,23 @@ TEST_CASE("secp256k1", "[math]") {
auto a = RandomPoint(prg);
auto buffer = std::make_unique(Curve::ByteSize(false));
a.Write(buffer.get(), false);
- REQUIRE(buffer[0] == 0);
+ REQUIRE(buffer[0] == 0x04);
auto c = Curve::Read(buffer.get());
REQUIRE(a == c);
a.Write(buffer.get(), true);
auto d = Curve::Read(buffer.get());
- REQUIRE(buffer[0] == 0x05);
+ REQUIRE(buffer[0] == 0x01);
REQUIRE(a == d);
Curve poi;
poi.Write(buffer.get(), false);
- REQUIRE(buffer[0] == 0x02);
+ REQUIRE(buffer[0] == 0x06);
auto e = Curve::Read(buffer.get());
REQUIRE(e.PointAtInfinity());
poi.Write(buffer.get(), true);
- REQUIRE(buffer[0] == (0x02 | 0x04));
+ REQUIRE(buffer[0] == 0x02);
auto f = Curve::Read(buffer.get());
REQUIRE(f.PointAtInfinity());
@@ -290,7 +299,7 @@ TEST_CASE("secp256k1", "[math]") {
"a2eee86009ff") //
);
g.Write(buffer.get());
- REQUIRE(buffer[0] == (0x04 | 0x01));
+ REQUIRE(buffer[0] == 0x01);
auto h = Curve::Read(buffer.get());
REQUIRE(h == g);
diff --git a/test/scl/math/test_vec.cc b/test/scl/math/test_vec.cc
index a126e39..d3d1049 100644
--- a/test/scl/math/test_vec.cc
+++ b/test/scl/math/test_vec.cc
@@ -112,7 +112,8 @@ TEST_CASE("Vector", "[math]") {
SECTION("Incompatible") {
auto v2 = Vec{F(2), F(3)};
REQUIRE(!v2.Equals(v1));
- REQUIRE_THROWS_MATCHES(v2.Add(v1), std::invalid_argument,
+ REQUIRE_THROWS_MATCHES(v2.Add(v1),
+ std::invalid_argument,
Catch::Matchers::Message("Vec sizes mismatch"));
}
@@ -122,7 +123,7 @@ TEST_CASE("Vector", "[math]") {
}
SECTION("Random") {
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
auto r = Vec::Random(3, prg);
auto zero = F();
REQUIRE(r.Size() == 3);
diff --git a/test/scl/math/test_z2k.cc b/test/scl/math/test_z2k.cc
index 954d810..07bf345 100644
--- a/test/scl/math/test_z2k.cc
+++ b/test/scl/math/test_z2k.cc
@@ -22,14 +22,14 @@
#include
#include "scl/math/z2k.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
// use sizes that ensure masking is needed
using Z2k1 = scl::Z2k<62>;
using Z2k2 = scl::Z2k<123>;
TEMPLATE_TEST_CASE("Z2k", "[math]", Z2k1, Z2k2) {
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
auto zero = TestType();
REQUIRE(std::string("Z2k") == TestType::Name());
@@ -89,7 +89,8 @@ TEMPLATE_TEST_CASE("Z2k", "[math]", Z2k1, Z2k2) {
auto a_inverse = a.Inverse();
REQUIRE(a * a_inverse == TestType(1));
REQUIRE_THROWS_MATCHES(
- zero.Inverse(), std::logic_error,
+ zero.Inverse(),
+ std::logic_error,
Catch::Matchers::Message("value not invertible modulo 2^K"));
}
diff --git a/test/scl/net/test_config.cc b/test/scl/net/test_config.cc
index c79a45c..5070d2a 100644
--- a/test/scl/net/test_config.cc
+++ b/test/scl/net/test_config.cc
@@ -53,13 +53,15 @@ TEST_CASE("Config", "[network]") {
const auto* invalid_entry = SCL_TEST_DATA_DIR "invalid_entry.txt";
REQUIRE_THROWS_MATCHES(
- scl::NetworkConfig::Load(0, invalid_entry), std::invalid_argument,
+ scl::NetworkConfig::Load(0, invalid_entry),
+ std::invalid_argument,
Catch::Matchers::Message("invalid entry in config file"));
const auto* invalid_non_existing_file = "";
REQUIRE_THROWS_MATCHES(
scl::NetworkConfig::Load(0, invalid_non_existing_file),
- std::invalid_argument, Catch::Matchers::Message("could not open file"));
+ std::invalid_argument,
+ Catch::Matchers::Message("could not open file"));
}
SECTION("All local") {
diff --git a/test/scl/net/test_discover.cc b/test/scl/net/test_discover.cc
index 7d7d09c..4e039e6 100644
--- a/test/scl/net/test_discover.cc
+++ b/test/scl/net/test_discover.cc
@@ -30,7 +30,9 @@
namespace {
-bool VerifyParty(scl::Party& party, int id, const std::string& hostname,
+bool VerifyParty(scl::Party& party,
+ int id,
+ const std::string& hostname,
int port) {
return party.id == id && party.hostname == hostname && party.port == port;
}
@@ -105,7 +107,8 @@ TEST_CASE("Discovery Server", "[network]") {
scl::DiscoveryServer::Ctx ctx{me, fake.my_network};
REQUIRE_THROWS_MATCHES(
- prot.Run(ctx), std::logic_error,
+ prot.Run(ctx),
+ std::logic_error,
Catch::Matchers::Message("received invalid party ID"));
}
@@ -195,11 +198,13 @@ TEST_CASE("Discovery", "[network]") {
SECTION("Too many parties") {
REQUIRE_THROWS_MATCHES(
- Server(9999, 256), std::invalid_argument,
+ Server(9999, 256),
+ std::invalid_argument,
Catch::Matchers::Message("number_of_parties exceeds max"));
REQUIRE_THROWS_MATCHES(
- Server(256), std::invalid_argument,
+ Server(256),
+ std::invalid_argument,
Catch::Matchers::Message("number_of_parties exceeds max"));
}
diff --git a/test/scl/net/test_mem_channel.cc b/test/scl/net/test_mem_channel.cc
index 2c01b89..7dbf19f 100644
--- a/test/scl/net/test_mem_channel.cc
+++ b/test/scl/net/test_mem_channel.cc
@@ -25,7 +25,7 @@
#include "scl/math.h"
#include "scl/net/mem_channel.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
#include "util.h"
void PrintBuf(const unsigned char* b, std::size_t n) {
@@ -41,7 +41,7 @@ TEST_CASE("InMemoryChannel", "[network]") {
auto chl0 = channels[0];
auto chl1 = channels[1];
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
unsigned char data_in[200] = {0};
prg.Next(data_in, 200);
diff --git a/test/scl/net/test_tcp_channel.cc b/test/scl/net/test_tcp_channel.cc
index 0efa63a..9e73011 100644
--- a/test/scl/net/test_tcp_channel.cc
+++ b/test/scl/net/test_tcp_channel.cc
@@ -23,7 +23,7 @@
#include "scl/net/tcp_channel.h"
#include "scl/net/tcp_utils.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
#include "util.h"
TEST_CASE("TcpChannel", "[network]") {
@@ -79,7 +79,7 @@ TEST_CASE("TcpChannel", "[network]") {
clt.join();
srv.join();
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
unsigned char send[200] = {0};
unsigned char recv[200] = {0};
prg.Next(send, 200);
diff --git a/test/scl/net/test_threaded_sender.cc b/test/scl/net/test_threaded_sender.cc
index c23b432..125cc2b 100644
--- a/test/scl/net/test_threaded_sender.cc
+++ b/test/scl/net/test_threaded_sender.cc
@@ -25,7 +25,7 @@
#include "scl/net/tcp_utils.h"
#include "scl/net/threaded_sender.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
#include "util.h"
TEST_CASE("ThreadedSender", "[network]") {
@@ -50,7 +50,7 @@ TEST_CASE("ThreadedSender", "[network]") {
clt.join();
srv.join();
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
unsigned char send[200] = {0};
unsigned char recv[200] = {0};
prg.Next(send, 200);
diff --git a/test/scl/net/util.cc b/test/scl/net/util.cc
index 6b1bbb7..0bb4941 100644
--- a/test/scl/net/util.cc
+++ b/test/scl/net/util.cc
@@ -22,9 +22,12 @@
int test_port = SCL_DEFAULT_TEST_PORT;
-int scl_tests::GetPort() { return test_port++; }
+int scl_tests::GetPort() {
+ return test_port++;
+}
-bool scl_tests::BufferEquals(const unsigned char *a, const unsigned char *b,
+bool scl_tests::BufferEquals(const unsigned char* a,
+ const unsigned char* b,
int n) {
while (n-- > 0 && *a++ == *b++) {
;
diff --git a/test/scl/p/test_simple.cc b/test/scl/p/test_simple.cc
index 663e3f4..1a0e705 100644
--- a/test/scl/p/test_simple.cc
+++ b/test/scl/p/test_simple.cc
@@ -87,8 +87,10 @@ class BeaverMul : public scl::ProtocolStep {
FF y;
};
-static inline std::vector RandomTriple() {
- scl::PRG prg;
+namespace {
+
+std::vector RandomTriple() {
+ auto prg = scl::PRG::Create();
auto a = FF::Random(prg);
auto b = FF::Random(prg);
auto c = a * b;
@@ -100,8 +102,10 @@ static inline std::vector RandomTriple() {
return std::vector{{as[0], bs[0], cs[0]}, {as[1], bs[1], cs[1]}};
}
+} // namespace
+
TEST_CASE("protocol") {
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
auto xs = scl::CreateAdditiveShares(FF(42), 2, prg);
auto ys = scl::CreateAdditiveShares(FF(11), 2, prg);
auto ts = RandomTriple();
diff --git a/test/scl/test_prg.cc b/test/scl/primitives/test_prg.cc
similarity index 73%
rename from test/scl/test_prg.cc
rename to test/scl/primitives/test_prg.cc
index e156824..11fc117 100644
--- a/test/scl/test_prg.cc
+++ b/test/scl/primitives/test_prg.cc
@@ -18,23 +18,14 @@
* along with this program. If not, see .
*/
+#include
#include
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
-inline bool BufferCmp(const unsigned char* b0, const unsigned char* b1,
- unsigned len) {
- const auto* p0 = b0;
- const auto* p1 = b1;
- while (len-- > 0) {
- if (*p0++ != *p1++) {
- return false;
- }
- }
- return true;
-}
+namespace {
-inline bool BufferLooksRandom(const unsigned char* p, unsigned len) {
+bool BufferLooksRandom(const unsigned char* p, unsigned len) {
std::vector buckets(256);
for (std::size_t i = 0; i < len; i++) {
@@ -43,16 +34,17 @@ inline bool BufferLooksRandom(const unsigned char* p, unsigned len) {
bool all_in_interval = true;
for (std::size_t i = 0; i < 256; i++) {
- auto p = 100 * ((float)buckets[i] / (float)len);
+ auto p = 100 * (static_cast(buckets[i]) / static_cast(len));
all_in_interval &= p >= 0.2 || p <= 6.0;
}
return all_in_interval;
}
+} // namespace
+
TEST_CASE("PRG", "[misc]") {
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
- REQUIRE(scl::PRG::BlockSize() == 16);
REQUIRE(scl::PRG::SeedSize() == 16);
SECTION("SanityCheck") {
@@ -65,22 +57,19 @@ TEST_CASE("PRG", "[misc]") {
SECTION("Stable") {
unsigned char seed[scl::PRG::SeedSize()] = "1234567890abcde";
- scl::PRG prg0(seed);
- scl::PRG prg1(seed);
-
- REQUIRE(prg0.Counter() == prg1.Counter());
- auto counter_before = prg1.Counter();
- REQUIRE(BufferCmp(prg0.Seed(), seed, scl::PRG::SeedSize()));
+ auto prg0 = scl::PRG::Create(seed);
+ auto prg1 = scl::PRG::Create(seed);
auto rand0 = prg0.Next(100);
auto rand1 = prg1.Next(100);
REQUIRE(rand0 == rand1);
- REQUIRE(counter_before != prg1.Counter());
prg0.Reset();
auto rand00 = prg0.Next(100);
REQUIRE(rand00 == rand0);
+ auto rand10 = prg1.Next(100);
+ REQUIRE(rand00 != rand10);
}
SECTION("Fill") {
@@ -93,7 +82,8 @@ TEST_CASE("PRG", "[misc]") {
REQUIRE(last_is_zero);
REQUIRE_THROWS_MATCHES(
- prg.Next(buffer, 101), std::invalid_argument,
+ prg.Next(buffer, 101),
+ std::invalid_argument,
Catch::Matchers::Message("requested more randomness than dest.size()"));
prg.Next(buffer);
diff --git a/test/scl/primitives/test_sha256.cc b/test/scl/primitives/test_sha256.cc
new file mode 100644
index 0000000..fad4069
--- /dev/null
+++ b/test/scl/primitives/test_sha256.cc
@@ -0,0 +1,89 @@
+/**
+ * @file test_sha256.cc
+ *
+ * SCL --- Secure Computation Library
+ * Copyright (C) 2022 Anders Dalskov
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+#include
+#include
+
+#include "scl/math/curves/secp256k1.h"
+#include "scl/math/ec.h"
+#include "scl/primitives/digest.h"
+#include "scl/primitives/sha256.h"
+
+const static std::array SHA256_empty = {
+ 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4,
+ 0xc8, 0x99, 0x6f, 0xb9, 0x24, 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b,
+ 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55};
+
+const static std::array SHA256_abc = {
+ 0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, 0x41, 0x41, 0x40,
+ 0xde, 0x5d, 0xae, 0x22, 0x23, 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17,
+ 0x7a, 0x9c, 0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, 0x15, 0xad};
+
+using Sha256 = scl::details::Sha256;
+
+TEST_CASE("Sha256") {
+ SECTION("SHA256 abc") {
+ Sha256 hash;
+ hash.Update({'a', 'b', 'c'});
+ auto digest = hash.Finalize();
+ REQUIRE(digest.size() == 32);
+ REQUIRE(digest == SHA256_abc);
+ }
+
+ SECTION("SHA256 empty") {
+ Sha256 hash;
+ auto digest = hash.Finalize();
+ REQUIRE(digest.size() == 32);
+ REQUIRE(digest == SHA256_empty);
+ }
+
+ SECTION("SHA256 chunked") {
+ Sha256 hash;
+ hash.Update({'a', 'b'});
+ hash.Update({'c'});
+ auto digest = hash.Finalize();
+ REQUIRE(digest.size() == 32);
+ REQUIRE(digest == SHA256_abc);
+ }
+
+ SECTION("Hash of curve point") {
+ // Reference test showing that serialization + hashing is the same as
+ // bouncycastle in Java.
+
+ using Curve = scl::EC;
+ auto pk = Curve::Generator() * scl::Number::FromString("a");
+
+ const auto n = Curve::ByteSize(false);
+ unsigned char buf[n] = {0};
+ pk.Write(buf, false);
+
+ Sha256 hash;
+ hash.Update(buf, n);
+
+ auto d = hash.Finalize();
+
+ std::array target = {
+ 0xde, 0xc1, 0x6a, 0xc2, 0x78, 0x99, 0xeb, 0xdf, 0x76, 0x0e, 0xaf,
+ 0x0a, 0x9f, 0x30, 0x95, 0xd1, 0x6a, 0x55, 0xea, 0x59, 0xef, 0x2a,
+ 0xe1, 0x8e, 0x9d, 0x22, 0x33, 0xd6, 0xbe, 0x82, 0x58, 0x38};
+
+ REQUIRE(d == target);
+ }
+}
diff --git a/test/scl/test_hash.cc b/test/scl/primitives/test_sha3.cc
similarity index 80%
rename from test/scl/test_hash.cc
rename to test/scl/primitives/test_sha3.cc
index 30eb57f..0eb1710 100644
--- a/test/scl/test_hash.cc
+++ b/test/scl/primitives/test_sha3.cc
@@ -1,5 +1,5 @@
/**
- * @file test_hash.cc
+ * @file test_sha3.cc
*
* SCL --- Secure Computation Library
* Copyright (C) 2022 Anders Dalskov
@@ -20,30 +20,31 @@
#include
-#include "scl/hash.h"
+#include "scl/primitives/digest.h"
+#include "scl/primitives/hash.h"
-const static std::array SHA3_256_empty = {
+static const scl::details::Digest<256>::Type SHA3_256_empty = {
0xa7, 0xff, 0xc6, 0xf8, 0xbf, 0x1e, 0xd7, 0x66, 0x51, 0xc1, 0x47,
0x56, 0xa0, 0x61, 0xd6, 0x62, 0xf5, 0x80, 0xff, 0x4d, 0xe4, 0x3b,
0x49, 0xfa, 0x82, 0xd8, 0x0a, 0x4b, 0x80, 0xf8, 0x43, 0x4a};
-static const std::array SHA3_256_abc = {
+static const scl::details::Digest<256>::Type SHA3_256_abc = {
0x3a, 0x98, 0x5d, 0xa7, 0x4f, 0xe2, 0x25, 0xb2, 0x04, 0x5c, 0x17,
0x2d, 0x6b, 0xd3, 0x90, 0xbd, 0x85, 0x5f, 0x08, 0x6e, 0x3e, 0x9d,
0x52, 0x5b, 0x46, 0xbf, 0xe2, 0x45, 0x11, 0x43, 0x15, 0x32};
-static const std::array SHA3_256_0xa3_200_times = {
+static const scl::details::Digest<256>::Type SHA3_256_0xa3_200_times = {
0x79, 0xf3, 0x8a, 0xde, 0xc5, 0xc2, 0x03, 0x07, 0xa9, 0x8e, 0xf7,
0x6e, 0x83, 0x24, 0xaf, 0xbf, 0xd4, 0x6c, 0xfd, 0x81, 0xb2, 0x2e,
0x39, 0x73, 0xc6, 0x5f, 0xa1, 0xbd, 0x9d, 0xe3, 0x17, 0x87};
-static const std::array SHA3_384_0xa3_200_times = {
+static const scl::details::Digest<384>::Type SHA3_384_0xa3_200_times = {
0x18, 0x81, 0xde, 0x2c, 0xa7, 0xe4, 0x1e, 0xf9, 0x5d, 0xc4, 0x73, 0x2b,
0x8f, 0x5f, 0x00, 0x2b, 0x18, 0x9c, 0xc1, 0xe4, 0x2b, 0x74, 0x16, 0x8e,
0xd1, 0x73, 0x26, 0x49, 0xce, 0x1d, 0xbc, 0xdd, 0x76, 0x19, 0x7a, 0x31,
0xfd, 0x55, 0xee, 0x98, 0x9f, 0x2d, 0x70, 0x50, 0xdd, 0x47, 0x3e, 0x8f};
-std::array SHA3_512_0xa3_200_times = {
+static const std::array SHA3_512_0xa3_200_times = {
0xe7, 0x6d, 0xfa, 0xd2, 0x20, 0x84, 0xa8, 0xb1, 0x46, 0x7f, 0xcf,
0x2f, 0xfa, 0x58, 0x36, 0x1b, 0xec, 0x76, 0x28, 0xed, 0xf5, 0xf3,
0xfd, 0xc0, 0xe4, 0x80, 0x5d, 0xc4, 0x8c, 0xae, 0xec, 0xa8, 0x1b,
@@ -51,7 +52,7 @@ std::array SHA3_512_0xa3_200_times = {
0x9a, 0x2d, 0xf4, 0x6b, 0xe5, 0x89, 0xc5, 0x1c, 0xa1, 0xa4, 0xa8,
0x41, 0x6d, 0xf6, 0x54, 0x5a, 0x1c, 0xe8, 0xba, 0x00};
-TEST_CASE("Hash", "[misc]") {
+TEST_CASE("Sha3", "[misc]") {
SECTION("SHA3-256 empty") {
scl::Hash<256> hash;
auto digest = hash.Finalize();
@@ -60,7 +61,8 @@ TEST_CASE("Hash", "[misc]") {
SECTION("SHA3-256 abc") {
scl::Hash<256> hash;
- auto digest = hash.Update((const unsigned char *)"abc", 3).Finalize();
+ unsigned char abc[] = "abc";
+ auto digest = hash.Update(abc, 3).Finalize();
REQUIRE(digest == SHA3_256_abc);
}
@@ -91,6 +93,7 @@ TEST_CASE("Hash", "[misc]") {
scl::Hash<384> hash0;
auto digest = hash0.Update(buf, 200).Finalize();
+ REQUIRE(digest.size() == 48);
REQUIRE(digest == SHA3_384_0xa3_200_times);
scl::Hash<384> hash1;
@@ -109,6 +112,7 @@ TEST_CASE("Hash", "[misc]") {
scl::Hash<512> hash0;
auto digest = hash0.Update(buf, 200).Finalize();
+ REQUIRE(digest.size() == 64);
REQUIRE(digest == SHA3_512_0xa3_200_times);
scl::Hash<512> hash1;
@@ -124,9 +128,17 @@ TEST_CASE("Hash", "[misc]") {
auto ref = hash_ref.Update(ref_buf, 12).Finalize();
scl::Hash<256> hash1;
- std::vector v = {'h', 'e', 'l', 'l', 'o', ',',
- ' ', 'w', 'o', 'r', 'l', 'd'};
+ std::vector v = {
+ 'h', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd'};
auto from_vec = hash1.Update(v).Finalize();
REQUIRE(ref == from_vec);
}
+
+ SECTION("Hash array") {
+ unsigned char abc[] = "abc";
+ std::array abc_arr = {'a', 'b', 'c'};
+ auto ref = scl::Hash<256>{}.Update(abc, 3).Finalize();
+ auto act = scl::Hash<256>{}.Update(abc_arr).Finalize();
+ REQUIRE(ref == act);
+ }
}
diff --git a/test/scl/ss/test_additive.cc b/test/scl/ss/test_additive.cc
index cd905a9..fda0c0e 100644
--- a/test/scl/ss/test_additive.cc
+++ b/test/scl/ss/test_additive.cc
@@ -21,12 +21,12 @@
#include
#include "scl/math.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
#include "scl/ss/additive.h"
TEST_CASE("AdditiveSS", "[ss]") {
using FF = scl::Fp<61>;
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
auto secret = FF(12345);
@@ -41,6 +41,7 @@ TEST_CASE("AdditiveSS", "[ss]") {
REQUIRE(scl::ReconstructAdditive(sum) == secret + x);
REQUIRE_THROWS_MATCHES(
- scl::CreateAdditiveShares(secret, 0, prg), std::invalid_argument,
+ scl::CreateAdditiveShares(secret, 0, prg),
+ std::invalid_argument,
Catch::Matchers::Message("cannot create shares for 0 people"));
}
diff --git a/test/scl/ss/test_feldman.cc b/test/scl/ss/test_feldman.cc
new file mode 100644
index 0000000..47027ef
--- /dev/null
+++ b/test/scl/ss/test_feldman.cc
@@ -0,0 +1,45 @@
+/**
+ * @file test_feldman.cc
+ *
+ * SCL --- Secure Computation Library
+ * Copyright (C) 2022 Anders Dalskov
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+#include
+
+#include "scl/math/curves/secp256k1.h"
+#include "scl/math/ec.h"
+#include "scl/primitives/prg.h"
+#include "scl/ss/feldman.h"
+
+TEST_CASE("Feldman", "[ss]") {
+ using EC = scl::EC;
+ using FF = EC::Order;
+
+ auto prg = scl::PRG::Create();
+ std::size_t t = 4;
+
+ SECTION("Share") {
+ auto factory = scl::FeldmanSSFactory::Create(t, prg);
+ auto secret = FF(123);
+ auto sb = factory.Share(secret, 24);
+ REQUIRE(sb.shares.Size() == 24);
+ REQUIRE(sb.commitments.Size() == t + 1);
+ REQUIRE(factory.Verify(sb.shares[22], sb.commitments, 22));
+ REQUIRE(factory.Verify(secret, sb.commitments));
+ REQUIRE(factory.Recover(sb.shares) == secret);
+ }
+}
diff --git a/test/scl/ss/test_poly.cc b/test/scl/ss/test_poly.cc
index e9296a1..c80440b 100644
--- a/test/scl/ss/test_poly.cc
+++ b/test/scl/ss/test_poly.cc
@@ -19,16 +19,21 @@
*/
#include
+#include
+#include "../math/fields.h"
#include "scl/math.h"
-#include "scl/prg.h"
+#include "scl/math/curves/secp256k1.h"
+#include "scl/primitives/prg.h"
#include "scl/ss/poly.h"
-using FF = scl::Fp<61>;
-using Poly = scl::details::Polynomial;
-using Vec = scl::Vec;
+using namespace scl_tests;
+
+TEMPLATE_TEST_CASE("Polynomials", "[ss][math]", FIELD_DEFS) {
+ using FF = TestType;
+ using Poly = scl::details::Polynomial;
+ using Vec = scl::Vec;
-TEST_CASE("Polynomials", "[ss][math]") {
SECTION("DefaultConstruct") {
Poly p;
REQUIRE(p.Degree() == 0);
@@ -149,10 +154,15 @@ TEST_CASE("Polynomials", "[ss][math]") {
for (std::size_t i = 0; i < x.Degree(); ++i) {
REQUIRE(x[i] == q[i]);
}
+
+ Poly z;
+ REQUIRE_THROWS_MATCHES(p.Divide(z),
+ std::invalid_argument,
+ Catch::Matchers::Message("division by 0"));
}
SECTION("DivideRandom") {
- scl::PRG prg;
+ auto prg = scl::PRG::Create();
auto c0 = Vec::Random(10, prg);
auto c1 = Vec::Random(9, prg);
auto a = Poly::Create(c0);
diff --git a/test/scl/ss/test_shamir.cc b/test/scl/ss/test_shamir.cc
index 9a44801..07b45d4 100644
--- a/test/scl/ss/test_shamir.cc
+++ b/test/scl/ss/test_shamir.cc
@@ -23,7 +23,7 @@
#include "../gf7.h"
#include "scl/math.h"
-#include "scl/prg.h"
+#include "scl/primitives/prg.h"
#include "scl/ss/shamir.h"
TEST_CASE("Shamir", "[ss]") {
@@ -32,48 +32,46 @@ TEST_CASE("Shamir", "[ss]") {
const std::size_t t = 2;
- SECTION("Reconstruct") {
- scl::PRG prg;
- scl::details::ShamirSSFactory factory(
- t, prg, scl::details::SecurityLevel::PASSIVE);
- auto intr = factory.GetInterpolator();
-
+ SECTION("Recover") {
+ auto prg = scl::PRG::Create();
+ auto factory =
+ scl::ShamirSSFactory::Create(t, prg, scl::SecurityLevel::PASSIVE);
auto secret = FF(123);
auto shares = factory.Share(secret);
- auto s = intr.Reconstruct(shares);
+ auto s = factory.Recover(shares);
REQUIRE(s == secret);
REQUIRE_THROWS_MATCHES(
- intr.Reconstruct(shares.SubVector(1)), std::invalid_argument,
+ factory.Recover(shares.SubVector(1)),
+ std::invalid_argument,
Catch::Matchers::Message("not enough shares to reconstruct"));
}
SECTION("Detection") {
- scl::PRG prg;
- scl::details::ShamirSSFactory factory(
- t, prg, scl::details::SecurityLevel::DETECT);
- auto intr = scl::details::Reconstructor::Create(
- t, scl::details::SecurityLevel::DETECT);
+ auto prg = scl::PRG::Create();
+ auto factory =
+ scl::ShamirSSFactory::Create(t, prg, scl::SecurityLevel::DETECT);
auto secret = FF(555);
auto shares = factory.Share(secret);
REQUIRE(shares.Size() == 2 * t + 1);
- REQUIRE(intr.Reconstruct(shares) == secret);
+ REQUIRE(factory.Recover(shares) == secret);
REQUIRE_THROWS_MATCHES(
- intr.Reconstruct(shares.SubVector(2)), std::invalid_argument,
+ factory.Recover(shares.SubVector(2)),
+ std::invalid_argument,
Catch::Matchers::Message("not enough shares to reconstruct"));
- auto ss = intr.ReconstructShare(shares, 2);
+ auto ss = factory.RecoverShare(shares, 2);
REQUIRE(ss == shares[2]);
- REQUIRE(intr.Reconstruct(shares, 3) == intr.ReconstructShare(shares, 2));
+ REQUIRE(factory.Recover(shares, 3) == factory.RecoverShare(shares, 2));
}
SECTION("Robust") {
- scl::PRG prg;
- scl::details::ShamirSSFactory factory(
- t, prg, scl::details::SecurityLevel::CORRECT);
+ auto prg = scl::PRG::Create();
+ auto factory =
+ scl::ShamirSSFactory::Create(t, prg, scl::SecurityLevel::CORRECT);
// no errors
auto secret = FF(123);
@@ -83,9 +81,7 @@ TEST_CASE("Shamir", "[ss]") {
REQUIRE(reconstructed == secret);
// can also reconstruct with an interpolator
- auto intr = scl::details::Reconstructor::Create(
- t, scl::details::SecurityLevel::CORRECT);
- REQUIRE(intr.Reconstruct(shares) == secret);
+ REQUIRE(factory.Recover(shares) == secret);
// one error
shares[0] = FF(63212);
@@ -100,7 +96,8 @@ TEST_CASE("Shamir", "[ss]") {
// three errors -- that's one too many
shares[1] = FF(123);
REQUIRE_THROWS_MATCHES(
- scl::details::ReconstructShamirRobust(shares, t), std::logic_error,
+ scl::details::ReconstructShamirRobust(shares, t),
+ std::logic_error,
Catch::Matchers::Message("could not correct shares"));
REQUIRE_THROWS_MATCHES(
@@ -119,7 +116,7 @@ TEST_CASE("Shamir", "[ss]") {
TEST_CASE("BerlekampWelch", "[ss][math]") {
// https://en.wikipedia.org/wiki/Berlekamp%E2%80%93Welch_algorithm#Example
- using FF = scl::FF;
+ using FF = scl::FF;
using Vec = scl::Vec;
Vec bs = {FF(1), FF(5), FF(3), FF(6), FF(3), FF(2), FF(2)};