Skip to content

Commit

Permalink
BLS12-381 group elements - make the code safer (#714)
Browse files Browse the repository at this point in the history
- Safer lifetimes
- guarantee a unique bytes representation of GT elements
  • Loading branch information
benr-ml committed Jan 8, 2024
1 parent c18dc38 commit dd5adb6
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion fastcrypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ bs58 = "0.4.0"
ed25519-consensus = { version = "2.1.0", features = ["serde"] }
eyre = "0.6.8"
hex = "0.4.3"
hex-literal = "0.4.1"
hkdf = { version = "0.12.3", features = ["std"] }
rand.workspace = true
rust_secp256k1 = { version = "0.27.0", package = "secp256k1", features = ["recovery", "rand-std", "bitcoin_hashes", "global-context"] }
Expand Down Expand Up @@ -103,7 +104,6 @@ experimental = []

[dev-dependencies]
criterion = "0.4.0"
hex-literal = "0.3.4"
k256 = { version = "0.11.6", features = ["ecdsa", "sha256", "keccak256"] }
proptest = "1.1.0"
serde-reflection = "0.3.6"
Expand Down
29 changes: 23 additions & 6 deletions fastcrypto/src/groups/bls12381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use blst::{
};
use derive_more::From;
use fastcrypto_derive::GroupOpsExtend;
use hex_literal::hex;
use once_cell::sync::OnceCell;
use serde::{de, Deserialize};
use std::fmt::Debug;
Expand Down Expand Up @@ -57,6 +58,7 @@ pub const SCALAR_LENGTH: usize = 32;
pub const G1_ELEMENT_BYTE_LENGTH: usize = 48;
pub const G2_ELEMENT_BYTE_LENGTH: usize = 96;
pub const GT_ELEMENT_BYTE_LENGTH: usize = 576;
pub const FP_BYTE_LENGTH: usize = 48;

impl Add for G1Element {
type Output = Self;
Expand Down Expand Up @@ -160,8 +162,7 @@ impl MultiScalarMul for G1Element {
return Err(FastCryptoError::InvalidInput);
}
// Inspired by blstrs.
let points =
unsafe { std::slice::from_raw_parts(points.as_ptr() as *const blst_p1, points.len()) };
let points = to_blst_type_slice(points);
let points = p1_affines::from(points);
let mut scalar_bytes: Vec<u8> = Vec::with_capacity(scalars.len() * 32);
for a in scalars.iter().map(|s| s.0) {
Expand All @@ -177,6 +178,14 @@ impl MultiScalarMul for G1Element {
}
}

// Bound the lifetime of points to the output slice.
fn to_blst_type_slice<From, To>(points: &[From]) -> &[To] {
// SAFETY: the cast from `&[G1Element]` to `&[blst_p1]` is safe because
// G1Element is a transparent wrapper around blst_p1. The lifetime of
// output slice is the same as the input slice.
unsafe { std::slice::from_raw_parts(points.as_ptr() as *const To, points.len()) }
}

impl GroupElement for G1Element {
type ScalarType = Scalar;

Expand Down Expand Up @@ -348,8 +357,7 @@ impl MultiScalarMul for G2Element {
return Err(FastCryptoError::InvalidInput);
}
// Inspired by blstrs.
let points =
unsafe { std::slice::from_raw_parts(points.as_ptr() as *const blst_p2, points.len()) };
let points = to_blst_type_slice(points);
let points = p2_affines::from(points);
let mut scalar_bytes: Vec<u8> = Vec::with_capacity(scalars.len() * 32);
for a in scalars.iter().map(|s| s.0) {
Expand Down Expand Up @@ -535,6 +543,8 @@ impl GTElement {
}
}

const P_AS_BYTES: [u8; FP_BYTE_LENGTH] = hex!("1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab");

// Note that the serialization below is uncompressed, i.e. it uses 576 bytes.
impl ToFromByteArray<GT_ELEMENT_BYTE_LENGTH> for GTElement {
fn from_byte_array(bytes: &[u8; GT_ELEMENT_BYTE_LENGTH]) -> Result<Self, FastCryptoError> {
Expand All @@ -546,14 +556,21 @@ impl ToFromByteArray<GT_ELEMENT_BYTE_LENGTH> for GTElement {
for j in 0..2 {
for k in 0..2 {
let mut fp = blst_fp::default();
let slice = &bytes[current..current + FP_BYTE_LENGTH];
// We compare with P_AS_BYTES to ensure that we process a canonical representation
// which is uses mod p elements.
if *slice >= P_AS_BYTES[..] {
return Err(FastCryptoError::InvalidInput);
}
unsafe {
blst_fp_from_bendian(&mut fp, bytes[current..current + 48].as_ptr());
blst_fp_from_bendian(&mut fp, slice.as_ptr());
}
gt.fp6[j].fp2[i].fp[k] = fp;
current += 48;
current += FP_BYTE_LENGTH;
}
}
}

match gt.in_group() {
true => Ok(Self::from(gt)),
false => Err(FastCryptoError::InvalidInput),
Expand Down
28 changes: 28 additions & 0 deletions fastcrypto/src/tests/bls12381_group_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,31 @@ fn test_reduce_mod_uniform_buffer() {
hex::decode("21015212b5c7a44c04c39447bf7d2addc5035a9b118f07a29956bf00fa65bd74").unwrap();
assert_eq!(expected, reduce_mod_uniform_buffer(&bytes).to_byte_array());
}

#[test]
fn test_serialization_gt() {
// All zero serialization for GT should fail.
let bytes = [0u8; 576];
assert!(GTElement::from_byte_array(&bytes).is_err());

// to and from_byte_array should be inverses.
let bytes = GTElement::generator().to_byte_array();
assert_eq!(
GTElement::generator(),
GTElement::from_byte_array(&bytes).unwrap()
);

// reject if one of the elements >= P
let mut bytes = GTElement::generator().to_byte_array();
let p = hex::decode("1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab").unwrap();
let mut carry = 0;
let mut target = [0; 48];
for i in (0..48).rev() {
let sum = (bytes[i] as u16) + (p[i] as u16) + carry;
target[i] = (sum % 256) as u8;
carry = sum / 256;
}
assert_eq!(carry, 0);
bytes[0..48].copy_from_slice(&target);
assert!(GTElement::from_byte_array(&bytes).is_err());
}

0 comments on commit dd5adb6

Please sign in to comment.