Skip to content

Commit

Permalink
Add msm to BLS + trait (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
benr-ml committed May 24, 2023
1 parent 0e99e58 commit 8a8a0c5
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 61 deletions.
55 changes: 52 additions & 3 deletions fastcrypto/src/groups/bls12381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

use crate::bls12381::min_pk::DST_G2;
use crate::bls12381::min_sig::DST_G1;
use crate::error::FastCryptoError;
use crate::groups::{GroupElement, HashToGroupElement, Pairing, Scalar as ScalarType};
use crate::error::{FastCryptoError, FastCryptoResult};
use crate::groups::{
GroupElement, HashToGroupElement, MultiScalarMul, Pairing, Scalar as ScalarType,
};
use crate::serde_helpers::BytesRepresentation;
use crate::serde_helpers::ToFromByteArray;
use crate::traits::AllowedRng;
Expand All @@ -19,7 +21,8 @@ use blst::{
blst_p1_mult, blst_p1_to_affine, blst_p2, blst_p2_add_or_double, blst_p2_affine, blst_p2_cneg,
blst_p2_compress, blst_p2_deserialize, blst_p2_from_affine, blst_p2_in_g2, blst_p2_mult,
blst_p2_to_affine, blst_scalar, blst_scalar_from_bendian, blst_scalar_from_fr,
blst_scalar_from_lendian, Pairing as BlstPairing, BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
blst_scalar_from_lendian, p1_affines, p2_affines, Pairing as BlstPairing, BLS12_381_G1,
BLS12_381_G2, BLST_ERROR,
};
use derive_more::From;
use fastcrypto_derive::GroupOpsExtend;
Expand All @@ -29,10 +32,12 @@ use std::ptr;

/// Elements of the group G_1 in BLS 12-381.
#[derive(Debug, From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
#[repr(transparent)]
pub struct G1Element(blst_p1);

/// Elements of the group G_2 in BLS 12-381.
#[derive(Debug, From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
#[repr(transparent)]
pub struct G2Element(blst_p2);

/// Elements of the subgroup G_T of F_q^{12} in BLS 12-381. Note that it is written in additive notation here.
Expand Down Expand Up @@ -131,6 +136,28 @@ impl Mul<Scalar> for G1Element {
}
}

impl MultiScalarMul for G1Element {
fn multi_scalar_mul(scalars: &[Self::ScalarType], points: &[Self]) -> FastCryptoResult<Self> {
if scalars.len() != points.len() || scalars.is_empty() {
return Err(FastCryptoError::InvalidInput);
}
// Inspired by blstrs.
let points =
unsafe { std::slice::from_raw_parts(points.as_ptr() as *const blst_p1, points.len()) };
let points = p1_affines::from(points);
let mut scalar_bytes: Vec<u8> = Vec::with_capacity(scalars.len() * 32);
for a in scalars.iter().map(|s| s.0) {
let mut scalar: blst_scalar = blst_scalar::default();
unsafe {
blst_scalar_from_fr(&mut scalar, &a);
}
scalar_bytes.extend_from_slice(&scalar.b);
}
let res = points.mult(scalar_bytes.as_slice(), 255);
Ok(Self::from(res))
}
}

impl GroupElement for G1Element {
type ScalarType = Scalar;

Expand Down Expand Up @@ -279,6 +306,28 @@ impl Mul<Scalar> for G2Element {
}
}

impl MultiScalarMul for G2Element {
fn multi_scalar_mul(scalars: &[Self::ScalarType], points: &[Self]) -> FastCryptoResult<Self> {
if scalars.len() != points.len() || scalars.is_empty() {
return Err(FastCryptoError::InvalidInput);
}
// Inspired by blstrs.
let points =
unsafe { std::slice::from_raw_parts(points.as_ptr() as *const blst_p2, points.len()) };
let points = p2_affines::from(points);
let mut scalar_bytes: Vec<u8> = Vec::with_capacity(scalars.len() * 32);
for a in scalars.iter().map(|s| s.0) {
let mut scalar: blst_scalar = blst_scalar::default();
unsafe {
blst_scalar_from_fr(&mut scalar, &a);
}
scalar_bytes.extend_from_slice(&scalar.b);
}
let res = points.mult(scalar_bytes.as_slice(), 255);
Ok(Self::from(res))
}
}

impl GroupElement for G2Element {
type ScalarType = Scalar;

Expand Down
7 changes: 6 additions & 1 deletion fastcrypto/src/groups/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) 2022, Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use crate::error::FastCryptoError;
use crate::error::{FastCryptoError, FastCryptoResult};
use crate::traits::AllowedRng;
use core::ops::{Add, Div, Mul, Neg, Sub};
use std::fmt::Debug;
Expand Down Expand Up @@ -64,3 +64,8 @@ pub trait HashToGroupElement {
/// Hashes the given message and maps the result to a group element.
fn hash_to_group_element(msg: &[u8]) -> Self;
}

/// Trait for groups that support multi-scalar multiplication.
pub trait MultiScalarMul: GroupElement {
fn multi_scalar_mul(scalars: &[Self::ScalarType], points: &[Self]) -> FastCryptoResult<Self>;
}
23 changes: 8 additions & 15 deletions fastcrypto/src/groups/ristretto255.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
//! Implementations of the [ristretto255 group](https://www.ietf.org/archive/id/draft-irtf-cfrg-ristretto255-decaf448-03.html) which is a group of
//! prime order 2^{252} + 27742317777372353535851937790883648493 built over Curve25519.

use crate::groups::{GroupElement, HashToGroupElement, Scalar};
use crate::error::FastCryptoResult;
use crate::groups::{GroupElement, HashToGroupElement, MultiScalarMul, Scalar};
use crate::hash::Sha512;
use crate::serde_helpers::ToFromByteArray;
use crate::traits::AllowedRng;
Expand Down Expand Up @@ -51,26 +52,18 @@ impl RistrettoPoint {
pub fn decompress(bytes: &[u8; 32]) -> Result<Self, FastCryptoError> {
RistrettoPoint::try_from(bytes.as_slice())
}
}

/// Compute the linear combination of the given scalars and points. An error will be returned if
/// the sizes do not match.
pub fn multiscalar_mul<I, J>(scalars: I, points: J) -> Result<Self, FastCryptoError>
where
I: IntoIterator,
I::Item: Into<RistrettoScalar>,
J: IntoIterator<Item = Self>,
{
let scalars_iter = scalars.into_iter();
let points_iter = points.into_iter();

if scalars_iter.size_hint() != points_iter.size_hint() {
impl MultiScalarMul for RistrettoPoint {
fn multi_scalar_mul(scalars: &[Self::ScalarType], points: &[Self]) -> FastCryptoResult<Self> {
if scalars.len() != points.len() {
return Err(FastCryptoError::InvalidInput);
}

Ok(RistrettoPoint(
ExternalRistrettoPoint::vartime_multiscalar_mul(
scalars_iter.map(|s| s.into().0),
points_iter.map(|g| g.0),
scalars.iter().map(|s| s.0),
points.iter().map(|g| g.0),
),
))
}
Expand Down
99 changes: 65 additions & 34 deletions fastcrypto/src/tests/bls12381_group_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@
// SPDX-License-Identifier: Apache-2.0

use crate::bls12381::min_pk::{BLS12381KeyPair, BLS12381Signature};
use crate::groups::bls12381::{
G1Element, G2Element, GTElement, Scalar, G1_ELEMENT_BYTE_LENGTH, G2_ELEMENT_BYTE_LENGTH,
use crate::groups::bls12381::{G1Element, G2Element, GTElement, Scalar};
use crate::groups::{
GroupElement, HashToGroupElement, MultiScalarMul, Pairing, Scalar as ScalarTrait,
};
use crate::groups::{GroupElement, HashToGroupElement, Pairing};
use crate::test_helpers::verify_serialization;
use crate::traits::Signer;
use crate::traits::VerifyingKey;
use crate::traits::{KeyPair, ToFromBytes};
use rand::{rngs::StdRng, SeedableRng as _};
use rand::{rngs::StdRng, thread_rng, SeedableRng as _};

const MSG: &[u8] = b"test message";

// TODO: add regression tests with test vectors.
// TODO: add test vectors.

#[test]
fn test_g1_arithmetic() {
Expand Down Expand Up @@ -43,6 +44,26 @@ fn test_g1_arithmetic() {
assert_eq!(G1Element::zero(), g - g);
}

#[test]
fn test_g1_msm() {
let mut scalars = Vec::new();
let mut points = Vec::new();
let mut expected = G1Element::zero();
for _ in 0..50 {
let s = Scalar::rand(&mut thread_rng());
let e = Scalar::rand(&mut thread_rng());
let g = G1Element::generator() * e;
expected += g * s;
scalars.push(s);
points.push(g);
}
let actual = G1Element::multi_scalar_mul(&scalars, &points).unwrap();
assert_eq!(expected, actual);

assert!(G1Element::multi_scalar_mul(&scalars[1..], &points).is_err());
assert!(G1Element::multi_scalar_mul(&[], &[]).is_err());
}

#[test]
fn test_g2_arithmetic() {
// Test that different ways of computing [5]G gives the expected result
Expand Down Expand Up @@ -71,6 +92,26 @@ fn test_g2_arithmetic() {
assert_eq!(G2Element::zero(), g - g);
}

#[test]
fn test_g2_msm() {
let mut scalars = Vec::new();
let mut points = Vec::new();
let mut expected = G2Element::zero();
for _ in 0..50 {
let s = Scalar::rand(&mut thread_rng());
let e = Scalar::rand(&mut thread_rng());
let g = G2Element::generator() * e;
expected += g * s;
scalars.push(s);
points.push(g);
}
let actual = G2Element::multi_scalar_mul(&scalars, &points).unwrap();
assert_eq!(expected, actual);

assert!(G2Element::multi_scalar_mul(&scalars[1..], &points).is_err());
assert!(G2Element::multi_scalar_mul(&[], &[]).is_err());
}

#[test]
fn test_gt_arithmetic() {
// Test that different ways of computing [5]G gives the expected result
Expand Down Expand Up @@ -115,35 +156,25 @@ fn test_pairing_and_hash_to_curve() {
}

#[test]
fn test_g1_serialize_deserialize() {
// Serialize and deserialize 7*G1
let p = G1Element::generator() * Scalar::from(7);
let serialized = bincode::serialize(&p).unwrap();
assert_eq!(serialized.len(), G1_ELEMENT_BYTE_LENGTH);
let deserialized: G1Element = bincode::deserialize(&serialized).unwrap();
assert_eq!(deserialized, p);

// Serialize and deserialize O
let p = G1Element::zero();
let serialized = bincode::serialize(&p).unwrap();
let deserialized: G1Element = bincode::deserialize(&serialized).unwrap();
assert_eq!(deserialized, p);
}

#[test]
fn test_g2_serialize_deserialize() {
// Serialize and deserialize 7*G1
let p = G2Element::generator() * Scalar::from(7);
let serialized = bincode::serialize(&p).unwrap();
assert_eq!(serialized.len(), G2_ELEMENT_BYTE_LENGTH);
let deserialized: G2Element = bincode::deserialize(&serialized).unwrap();
assert_eq!(deserialized, p);

// Serialize and deserialize O
let p = G2Element::zero();
let serialized = bincode::serialize(&p).unwrap();
let deserialized: G2Element = bincode::deserialize(&serialized).unwrap();
assert_eq!(deserialized, p);
fn test_serde_and_regression() {
let s1 = Scalar::from(1);
let g1 = G1Element::generator();
let g2 = G2Element::generator();
let id1 = G1Element::zero();
let id2 = G2Element::zero();

verify_serialization(
&s1,
Some(
hex::decode("0000000000000000000000000000000000000000000000000000000000000001")
.unwrap()
.as_slice(),
),
);
verify_serialization(&g1, Some(hex::decode("97f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb").unwrap().as_slice()));
verify_serialization(&g2, Some(hex::decode("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8").unwrap().as_slice()));
verify_serialization(&id1, Some(hex::decode("c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000").unwrap().as_slice()));
verify_serialization(&id2, Some(hex::decode("c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000").unwrap().as_slice()));
}

#[test]
Expand Down
18 changes: 15 additions & 3 deletions fastcrypto/src/tests/ristretto255_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::groups::ristretto255::RistrettoPoint;
use crate::groups::ristretto255::RistrettoScalar;
use crate::groups::GroupElement;
use crate::groups::{GroupElement, MultiScalarMul};

#[test]
fn test_arithmetic() {
Expand Down Expand Up @@ -161,9 +161,21 @@ fn test_vectors() {
#[test]
fn test_multiscalar_mul() {
let g = RistrettoPoint::generator();
let h = RistrettoPoint::multiscalar_mul([1, 2, 3], [g, g, g]).unwrap();
let h = RistrettoPoint::multi_scalar_mul(
&[
RistrettoScalar::from(1),
RistrettoScalar::from(2),
RistrettoScalar::from(3),
],
&[g, g, g],
)
.unwrap();
assert_eq!(g * RistrettoScalar::from(6), h);

// Invalid lengths
assert!(RistrettoPoint::multiscalar_mul([1, 2], [g, g, g]).is_err());
assert!(RistrettoPoint::multi_scalar_mul(
&[RistrettoScalar::from(1), RistrettoScalar::from(2)],
&[g, g, g]
)
.is_err());
}
10 changes: 5 additions & 5 deletions fastcrypto/src/vrf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub trait VRFProof<const OUTPUT_SIZE: usize> {
pub mod ecvrf {
use crate::error::FastCryptoError;
use crate::groups::ristretto255::{RistrettoPoint, RistrettoScalar};
use crate::groups::{GroupElement, Scalar};
use crate::groups::{GroupElement, MultiScalarMul, Scalar};
use crate::hash::{HashFunction, ReverseWrapper, Sha512};
use crate::serde_helpers::ToFromByteArray;
use crate::traits::AllowedRng;
Expand Down Expand Up @@ -245,11 +245,11 @@ pub mod ecvrf {
let h = public_key.ecvrf_encode_to_curve(alpha_string);

let challenge = RistrettoScalar::from(&self.c);
let u = RistrettoPoint::multiscalar_mul(
[self.s, -challenge],
[RistrettoPoint::generator(), public_key.0],
let u = RistrettoPoint::multi_scalar_mul(
&[self.s, -challenge],
&[RistrettoPoint::generator(), public_key.0],
)?;
let v = RistrettoPoint::multiscalar_mul([self.s, -challenge], [h, self.gamma])?;
let v = RistrettoPoint::multi_scalar_mul(&[self.s, -challenge], &[h, self.gamma])?;

let c_prime = ecvrf_challenge_generation([&public_key.0, &h, &self.gamma, &u, &v]);

Expand Down

0 comments on commit 8a8a0c5

Please sign in to comment.