diff --git a/Cargo.lock b/Cargo.lock index ed028e4cd..645ae1da7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1851,9 +1851,9 @@ dependencies = [ [[package]] name = "hex-literal" -version = "0.3.4" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ebdb29d2ea9ed0083cd8cece49bbd968021bd99b0849edb4a9a7ee0fdf6a4e0" +checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" [[package]] name = "hkdf" diff --git a/fastcrypto/Cargo.toml b/fastcrypto/Cargo.toml index 902fb7ccd..57c19551a 100644 --- a/fastcrypto/Cargo.toml +++ b/fastcrypto/Cargo.toml @@ -15,6 +15,7 @@ bs58 = "0.4.0" ed25519-consensus = { version = "2.1.0", features = ["serde"] } eyre = "0.6.8" hex = "0.4.3" +hex-literal = "0.4.1" hkdf = { version = "0.12.3", features = ["std"] } rand.workspace = true rust_secp256k1 = { version = "0.27.0", package = "secp256k1", features = ["recovery", "rand-std", "bitcoin_hashes", "global-context"] } @@ -103,7 +104,6 @@ experimental = [] [dev-dependencies] criterion = "0.4.0" -hex-literal = "0.3.4" k256 = { version = "0.11.6", features = ["ecdsa", "sha256", "keccak256"] } proptest = "1.1.0" serde-reflection = "0.3.6" diff --git a/fastcrypto/src/groups/bls12381.rs b/fastcrypto/src/groups/bls12381.rs index bcb516c70..bb803132c 100644 --- a/fastcrypto/src/groups/bls12381.rs +++ b/fastcrypto/src/groups/bls12381.rs @@ -28,6 +28,7 @@ use blst::{ }; use derive_more::From; use fastcrypto_derive::GroupOpsExtend; +use hex_literal::hex; use once_cell::sync::OnceCell; use serde::{de, Deserialize}; use std::fmt::Debug; @@ -57,6 +58,7 @@ pub const SCALAR_LENGTH: usize = 32; pub const G1_ELEMENT_BYTE_LENGTH: usize = 48; pub const G2_ELEMENT_BYTE_LENGTH: usize = 96; pub const GT_ELEMENT_BYTE_LENGTH: usize = 576; +pub const FP_BYTE_LENGTH: usize = 48; impl Add for G1Element { type Output = Self; @@ -160,8 +162,7 @@ impl MultiScalarMul for G1Element { return Err(FastCryptoError::InvalidInput); } // Inspired by blstrs. - let points = - unsafe { std::slice::from_raw_parts(points.as_ptr() as *const blst_p1, points.len()) }; + let points = to_blst_type_slice(points); let points = p1_affines::from(points); let mut scalar_bytes: Vec = Vec::with_capacity(scalars.len() * 32); for a in scalars.iter().map(|s| s.0) { @@ -177,6 +178,14 @@ impl MultiScalarMul for G1Element { } } +// Bound the lifetime of points to the output slice. +fn to_blst_type_slice(points: &[From]) -> &[To] { + // SAFETY: the cast from `&[G1Element]` to `&[blst_p1]` is safe because + // G1Element is a transparent wrapper around blst_p1. The lifetime of + // output slice is the same as the input slice. + unsafe { std::slice::from_raw_parts(points.as_ptr() as *const To, points.len()) } +} + impl GroupElement for G1Element { type ScalarType = Scalar; @@ -348,8 +357,7 @@ impl MultiScalarMul for G2Element { return Err(FastCryptoError::InvalidInput); } // Inspired by blstrs. - let points = - unsafe { std::slice::from_raw_parts(points.as_ptr() as *const blst_p2, points.len()) }; + let points = to_blst_type_slice(points); let points = p2_affines::from(points); let mut scalar_bytes: Vec = Vec::with_capacity(scalars.len() * 32); for a in scalars.iter().map(|s| s.0) { @@ -535,6 +543,8 @@ impl GTElement { } } +const P_AS_BYTES: [u8; FP_BYTE_LENGTH] = hex!("1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab"); + // Note that the serialization below is uncompressed, i.e. it uses 576 bytes. impl ToFromByteArray for GTElement { fn from_byte_array(bytes: &[u8; GT_ELEMENT_BYTE_LENGTH]) -> Result { @@ -546,14 +556,21 @@ impl ToFromByteArray for GTElement { for j in 0..2 { for k in 0..2 { let mut fp = blst_fp::default(); + let slice = &bytes[current..current + FP_BYTE_LENGTH]; + // We compare with P_AS_BYTES to ensure that we process a canonical representation + // which is uses mod p elements. + if *slice >= P_AS_BYTES[..] { + return Err(FastCryptoError::InvalidInput); + } unsafe { - blst_fp_from_bendian(&mut fp, bytes[current..current + 48].as_ptr()); + blst_fp_from_bendian(&mut fp, slice.as_ptr()); } gt.fp6[j].fp2[i].fp[k] = fp; - current += 48; + current += FP_BYTE_LENGTH; } } } + match gt.in_group() { true => Ok(Self::from(gt)), false => Err(FastCryptoError::InvalidInput), diff --git a/fastcrypto/src/tests/bls12381_group_tests.rs b/fastcrypto/src/tests/bls12381_group_tests.rs index fd46c0744..e065bf66a 100644 --- a/fastcrypto/src/tests/bls12381_group_tests.rs +++ b/fastcrypto/src/tests/bls12381_group_tests.rs @@ -370,3 +370,31 @@ fn test_reduce_mod_uniform_buffer() { hex::decode("21015212b5c7a44c04c39447bf7d2addc5035a9b118f07a29956bf00fa65bd74").unwrap(); assert_eq!(expected, reduce_mod_uniform_buffer(&bytes).to_byte_array()); } + +#[test] +fn test_serialization_gt() { + // All zero serialization for GT should fail. + let bytes = [0u8; 576]; + assert!(GTElement::from_byte_array(&bytes).is_err()); + + // to and from_byte_array should be inverses. + let bytes = GTElement::generator().to_byte_array(); + assert_eq!( + GTElement::generator(), + GTElement::from_byte_array(&bytes).unwrap() + ); + + // reject if one of the elements >= P + let mut bytes = GTElement::generator().to_byte_array(); + let p = hex::decode("1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab").unwrap(); + let mut carry = 0; + let mut target = [0; 48]; + for i in (0..48).rev() { + let sum = (bytes[i] as u16) + (p[i] as u16) + carry; + target[i] = (sum % 256) as u8; + carry = sum / 256; + } + assert_eq!(carry, 0); + bytes[0..48].copy_from_slice(&target); + assert!(GTElement::from_byte_array(&bytes).is_err()); +}