From 7175313157075dcf4e8bb7adac8179d4c245f26b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Lindstr=C3=B8m?= Date: Mon, 27 May 2024 14:35:46 +0200 Subject: [PATCH] Refactoring class group code (#795) * Small clean ups * refactor * Reorder computations --- fastcrypto-cli/src/vdf.rs | 8 +- fastcrypto-vdf/src/class_group/hash.rs | 6 +- fastcrypto-vdf/src/class_group/mod.rs | 144 +++++++++--------- fastcrypto-vdf/src/class_group/reduction.rs | 5 + fastcrypto-vdf/src/math/extended_gcd.rs | 19 +-- .../src/math/parameterized_group.rs | 2 +- fastcrypto-vdf/src/vdf/wesolowski/mod.rs | 8 +- 7 files changed, 93 insertions(+), 99 deletions(-) diff --git a/fastcrypto-cli/src/vdf.rs b/fastcrypto-cli/src/vdf.rs index d1fa3beaf6..581bed2c5b 100644 --- a/fastcrypto-cli/src/vdf.rs +++ b/fastcrypto-cli/src/vdf.rs @@ -116,7 +116,7 @@ fn execute(cmd: Command) -> Result { "Invalid input point or discriminant.", ) })?; - if !g.has_parameter(&DISCRIMINANT_3072) { + if !g.is_in_group(&DISCRIMINANT_3072) { return Err(Error::new( ErrorKind::InvalidInput, "Input point does not match discriminant.", @@ -144,7 +144,7 @@ fn execute(cmd: Command) -> Result { Error::new(ErrorKind::InvalidInput, "Invalid output hex string.") })?) .map_err(|_| Error::new(ErrorKind::InvalidInput, "Invalid input."))?; - if !input.has_parameter(&DISCRIMINANT_3072) { + if !input.is_in_group(&DISCRIMINANT_3072) { return Err(Error::new( ErrorKind::InvalidInput, "Input has wrong discriminant.", @@ -156,7 +156,7 @@ fn execute(cmd: Command) -> Result { Error::new(ErrorKind::InvalidInput, "Invalid output hex string.") })?) .map_err(|_| Error::new(ErrorKind::InvalidInput, "Invalid output."))?; - if !output.has_parameter(&DISCRIMINANT_3072) { + if !output.is_in_group(&DISCRIMINANT_3072) { return Err(Error::new( ErrorKind::InvalidInput, "Output has wrong discriminant.", @@ -168,7 +168,7 @@ fn execute(cmd: Command) -> Result { Error::new(ErrorKind::InvalidInput, "Invalid proof hex string.") })?) .map_err(|_| Error::new(ErrorKind::InvalidInput, "Invalid proof."))?; - if !proof.has_parameter(&DISCRIMINANT_3072) { + if !proof.is_in_group(&DISCRIMINANT_3072) { return Err(Error::new( ErrorKind::InvalidInput, "Proof has wrong discriminant.", diff --git a/fastcrypto-vdf/src/class_group/hash.rs b/fastcrypto-vdf/src/class_group/hash.rs index 8d13ed3fef..6c82df18bb 100644 --- a/fastcrypto-vdf/src/class_group/hash.rs +++ b/fastcrypto-vdf/src/class_group/hash.rs @@ -185,14 +185,14 @@ mod tests { for _ in 0..10 { let qf = QuadraticForm::hash_to_group(&seed, &discriminant, 1).unwrap(); assert!(qf.is_reduced_assuming_normal()); - assert!(qf.has_parameter(&discriminant)); + assert!(qf.is_in_group(&discriminant)); seed[0] += 1; } for _ in 0..10 { let qf = QuadraticForm::hash_to_group(&seed, &discriminant, 4).unwrap(); assert!(qf.is_reduced_assuming_normal()); - assert!(qf.has_parameter(&discriminant)); + assert!(qf.is_in_group(&discriminant)); seed[0] += 1; } } @@ -201,7 +201,7 @@ mod tests { fn qf_from_seed_sanity_tests() { let discriminant = Discriminant::from_seed(b"discriminant seed", 800).unwrap(); let base_qf = QuadraticForm::hash_to_group(b"qf seed", &discriminant, 6).unwrap(); - assert!(base_qf.has_parameter(&discriminant)); + assert!(base_qf.is_in_group(&discriminant)); // Same seed, same discriminant, same k let other_qf = QuadraticForm::hash_to_group(b"qf seed", &discriminant, 6).unwrap(); diff --git a/fastcrypto-vdf/src/class_group/mod.rs b/fastcrypto-vdf/src/class_group/mod.rs index af0127f56b..e0f34af189 100644 --- a/fastcrypto-vdf/src/class_group/mod.rs +++ b/fastcrypto-vdf/src/class_group/mod.rs @@ -145,52 +145,34 @@ impl QuadraticForm { a_divided_by_gcd: h, b_divided_by_gcd, } = extended_euclidean_algorithm(&f, &s); - capital_by *= &h; - capital_cy *= &h; // 4. let l = (&y * (&b * (w1.mod_floor(&h)) + &c * (w2.mod_floor(&h)))).mod_floor(&h); - ( - g, - &b * (&m / &h) + &l * (&capital_by / &h), - b_divided_by_gcd, - ) + let capital_bx = &b * (&m / &h) + &l * &capital_by; + capital_by *= &h; + capital_cy *= &h; + (g, capital_bx, b_divided_by_gcd) }; // 5. (partial xgcd) - let mut bx = capital_bx.mod_floor(&capital_by); - let mut by = capital_by.clone(); - - let mut x = BigInt::one(); - let mut y = BigInt::zero(); - let mut z = 0u32; - - while by.abs() > *self.partial_gcd_limit() && !bx.is_zero() { - let (q, t) = by.div_rem(&bx); - by = bx; - bx = t; - swap(&mut x, &mut y); - x -= &q * &y; - z += 1; - } - - if z.is_odd() { - by = -by; - y = -y; - } + let (bx, x, by, y, iterated) = partial_xgcd( + capital_bx.mod_floor(&capital_by), + capital_by.clone(), + self.partial_gcd_limit(), + ); let u3: BigInt; let w3: BigInt; let v3: BigInt; - if z == 0 { + if !iterated { // 6. let q = &capital_cy * &bx; let cx = (&q - &m) / &capital_by; let dx = (&bx * &capital_dy - w2) / &capital_by; u3 = &by * &capital_cy; - w3 = &bx * &cx - &g * &dx; v3 = v2 - (&q << 1); + w3 = &bx * &cx - &g * &dx; } else { // 7. let cx = (&capital_cy * &bx - &m * &x) / &capital_by; @@ -207,18 +189,17 @@ impl QuadraticForm { }; u3 = &by * &cy - &g * &y * &dy; - w3 = &bx * &cx - &g * &x * &dx; v3 = &g * (&q3 + &q4) - &q1 - &q2; + w3 = &bx * &cx - &g * &x * &dx; } - let mut form = QuadraticForm { + QuadraticForm { a: u3, b: v3, c: w3, partial_gcd_limit: self.partial_gcd_limit.clone(), - }; - form.reduce(); - form + } + .into_reduced() } } @@ -241,61 +222,72 @@ impl Doubling for QuadraticForm { b_divided_by_gcd: capital_dy, } = extended_euclidean_algorithm(u, v); - let mut bx = (&y * w).mod_floor(&capital_by); - let mut by = capital_by.clone(); + let (bx, x, by, y, iterated) = partial_xgcd( + (&y * w).mod_floor(&capital_by), + capital_by.clone(), + self.partial_gcd_limit(), + ); - let mut x = BigInt::one(); - let mut y = BigInt::zero(); - let mut z = 0u32; + let mut u3 = &by * &by; + let mut w3 = &bx * &bx; + let mut v3 = &u3 + &w3 - &(&bx + &by).pow(2); - while by.abs() > *self.partial_gcd_limit() && !bx.is_zero() { - let (q, t) = by.div_rem(&bx); - by = bx; - bx = t; - swap(&mut x, &mut y); - x -= &q * &y; - z += 1; - } - - if z.is_odd() { - by = -by; - y = -y; - } - - let mut u3: BigInt; - let mut w3: BigInt; - let mut v3: BigInt; - - if z == 0 { + if !iterated { let dx = (&bx * &capital_dy - w) / &capital_by; - u3 = &by * &by; - w3 = &bx * &bx; - let s = &bx + &by; - v3 = v - &s * &s + &u3 + &w3; - w3 = &w3 - &g * &dx; + v3 += v; + w3 -= &g * &dx; } else { let dx = (&bx * &capital_dy - w * &x) / &capital_by; let q1 = &dx * &y; let mut dy = &q1 + &capital_dy; - v3 = &g * (&dy + &q1); - dy = &dy / &x; - u3 = &by * &by; - w3 = &bx * &bx; - v3 = &v3 - (&bx + &by).pow(2) + &u3 + &w3; - - u3 = &u3 - &g * &y * &dy; - w3 = &w3 - &g * &x * &dx; + v3 += &g * (&dy + &q1); + dy /= &x; + + u3 -= &g * &y * &dy; + w3 -= &g * &x * &dx; } - let mut form = QuadraticForm { + QuadraticForm { a: u3, b: v3, c: w3, partial_gcd_limit: self.partial_gcd_limit.clone(), - }; - form.reduce(); - form + } + .into_reduced() + } +} + +/// Compute the xgcd of bx and by with a partial limit. When by is below the limit, the computation +/// stops early and returns the result. The result is a tuple (bx, x, by, y, iterated) where bx and +/// by are the now reduced coefficients, x and y are the Bezout coefficients for bx and by respectively, +/// and iterated is true if the there were any iterations. +fn partial_xgcd( + mut bx: BigInt, + mut by: BigInt, + limit: &BigInt, +) -> (BigInt, BigInt, BigInt, BigInt, bool) { + let mut x = BigInt::one(); + let mut y = BigInt::zero(); + let mut iterated = false; + let mut odd = false; + + while by.abs() > *limit && !bx.is_zero() { + let (q, r) = by.div_rem(&bx); + by = bx; + bx = r; + swap(&mut x, &mut y); + x -= &q * &y; + + odd = !odd; + iterated = true; } + + if odd { + by = -by; + y = -y; + } + + (bx, x, by, y, iterated) } impl ParameterizedGroupElement for QuadraticForm { @@ -309,7 +301,7 @@ impl ParameterizedGroupElement for QuadraticForm { .expect("Doesn't fail") } - fn has_parameter(&self, discriminant: &Discriminant) -> bool { + fn is_in_group(&self, discriminant: &Discriminant) -> bool { discriminant.as_bigint().eq(&self.discriminant()) } } diff --git a/fastcrypto-vdf/src/class_group/reduction.rs b/fastcrypto-vdf/src/class_group/reduction.rs index 375fef3bc5..6a334192fc 100644 --- a/fastcrypto-vdf/src/class_group/reduction.rs +++ b/fastcrypto-vdf/src/class_group/reduction.rs @@ -53,6 +53,11 @@ impl QuadraticForm { self.b = cs.shl(1) - &self.b; } } + + pub(crate) fn into_reduced(mut self) -> QuadraticForm { + self.reduce(); + self + } } #[cfg(test)] diff --git a/fastcrypto-vdf/src/math/extended_gcd.rs b/fastcrypto-vdf/src/math/extended_gcd.rs index 39bb2bdfbc..6f0d5c2386 100644 --- a/fastcrypto-vdf/src/math/extended_gcd.rs +++ b/fastcrypto-vdf/src/math/extended_gcd.rs @@ -6,7 +6,7 @@ //! divided by the GCD since these are often used, for example in the NUCOMP and NUDPL algorithms, //! and come out for free while computing the Bezout coefficients. -use num_bigint::BigInt; +use num_bigint::{BigInt, Sign}; use num_integer::Integer; use num_traits::{One, Signed, Zero}; use std::mem; @@ -58,16 +58,8 @@ pub(crate) fn extended_euclidean_algorithm(a: &BigInt, b: &BigInt) -> EuclideanA } // The last coefficients are equal to +/- a / gcd(a,b) and b / gcd(a,b) respectively. - let a_divided_by_gcd = if a.sign() != s.0.sign() { - s.0.neg() - } else { - s.0 - }; - let b_divided_by_gcd = if b.sign() != t.0.sign() { - t.0.neg() - } else { - t.0 - }; + let a_divided_by_gcd = set_sign(s.0, a.sign()); + let b_divided_by_gcd = set_sign(t.0, b.sign()); if !r.1.is_negative() { EuclideanAlgorithmOutput { @@ -88,6 +80,11 @@ pub(crate) fn extended_euclidean_algorithm(a: &BigInt, b: &BigInt) -> EuclideanA } } +/// Return a number with the same magnitude as `value` but with the given sign. +fn set_sign(value: BigInt, sign: Sign) -> BigInt { + BigInt::from_biguint(sign, value.into_parts().1) +} + #[test] fn test_xgcd() { test_xgcd_single(BigInt::from(240), BigInt::from(46)); diff --git a/fastcrypto-vdf/src/math/parameterized_group.rs b/fastcrypto-vdf/src/math/parameterized_group.rs index c03e255c3e..87543394b3 100644 --- a/fastcrypto-vdf/src/math/parameterized_group.rs +++ b/fastcrypto-vdf/src/math/parameterized_group.rs @@ -28,5 +28,5 @@ pub trait ParameterizedGroupElement: fn zero(parameters: &Self::ParameterType) -> Self; /// Returns true if this is an element of the group defined by `parameter`. - fn has_parameter(&self, parameter: &Self::ParameterType) -> bool; + fn is_in_group(&self, parameter: &Self::ParameterType) -> bool; } diff --git a/fastcrypto-vdf/src/vdf/wesolowski/mod.rs b/fastcrypto-vdf/src/vdf/wesolowski/mod.rs index 65d7deb8af..05085b42dc 100644 --- a/fastcrypto-vdf/src/vdf/wesolowski/mod.rs +++ b/fastcrypto-vdf/src/vdf/wesolowski/mod.rs @@ -58,7 +58,7 @@ impl< type ProofType = G; fn evaluate(&self, input: &G) -> FastCryptoResult<(G, G)> { - if !input.has_parameter(&self.group_parameter) || self.iterations == 0 { + if !input.is_in_group(&self.group_parameter) || self.iterations == 0 { return Err(InvalidInput); } @@ -84,9 +84,9 @@ impl< } fn verify(&self, input: &G, output: &G, proof: &G) -> FastCryptoResult<()> { - if !input.has_parameter(&self.group_parameter) - || !output.has_parameter(&self.group_parameter) - || !proof.has_parameter(&self.group_parameter) + if !input.is_in_group(&self.group_parameter) + || !output.is_in_group(&self.group_parameter) + || !proof.is_in_group(&self.group_parameter) { return Err(InvalidInput); }