Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring class group code #795

Merged
merged 3 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions fastcrypto-cli/src/vdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ fn execute(cmd: Command) -> Result<String, Error> {
"Invalid input point or discriminant.",
)
})?;
if !g.has_parameter(&DISCRIMINANT_3072) {
if !g.is_in_group(&DISCRIMINANT_3072) {
return Err(Error::new(
ErrorKind::InvalidInput,
"Input point does not match discriminant.",
Expand Down Expand Up @@ -144,7 +144,7 @@ fn execute(cmd: Command) -> Result<String, Error> {
Error::new(ErrorKind::InvalidInput, "Invalid output hex string.")
})?)
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Invalid input."))?;
if !input.has_parameter(&DISCRIMINANT_3072) {
if !input.is_in_group(&DISCRIMINANT_3072) {
return Err(Error::new(
ErrorKind::InvalidInput,
"Input has wrong discriminant.",
Expand All @@ -156,7 +156,7 @@ fn execute(cmd: Command) -> Result<String, Error> {
Error::new(ErrorKind::InvalidInput, "Invalid output hex string.")
})?)
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Invalid output."))?;
if !output.has_parameter(&DISCRIMINANT_3072) {
if !output.is_in_group(&DISCRIMINANT_3072) {
return Err(Error::new(
ErrorKind::InvalidInput,
"Output has wrong discriminant.",
Expand All @@ -168,7 +168,7 @@ fn execute(cmd: Command) -> Result<String, Error> {
Error::new(ErrorKind::InvalidInput, "Invalid proof hex string.")
})?)
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Invalid proof."))?;
if !proof.has_parameter(&DISCRIMINANT_3072) {
if !proof.is_in_group(&DISCRIMINANT_3072) {
return Err(Error::new(
ErrorKind::InvalidInput,
"Proof has wrong discriminant.",
Expand Down
6 changes: 3 additions & 3 deletions fastcrypto-vdf/src/class_group/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ mod tests {
for _ in 0..10 {
let qf = QuadraticForm::hash_to_group(&seed, &discriminant, 1).unwrap();
assert!(qf.is_reduced_assuming_normal());
assert!(qf.has_parameter(&discriminant));
assert!(qf.is_in_group(&discriminant));
seed[0] += 1;
}

for _ in 0..10 {
let qf = QuadraticForm::hash_to_group(&seed, &discriminant, 4).unwrap();
assert!(qf.is_reduced_assuming_normal());
assert!(qf.has_parameter(&discriminant));
assert!(qf.is_in_group(&discriminant));
seed[0] += 1;
}
}
Expand All @@ -201,7 +201,7 @@ mod tests {
fn qf_from_seed_sanity_tests() {
let discriminant = Discriminant::from_seed(b"discriminant seed", 800).unwrap();
let base_qf = QuadraticForm::hash_to_group(b"qf seed", &discriminant, 6).unwrap();
assert!(base_qf.has_parameter(&discriminant));
assert!(base_qf.is_in_group(&discriminant));

// Same seed, same discriminant, same k
let other_qf = QuadraticForm::hash_to_group(b"qf seed", &discriminant, 6).unwrap();
Expand Down
144 changes: 68 additions & 76 deletions fastcrypto-vdf/src/class_group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,52 +145,34 @@ impl QuadraticForm {
a_divided_by_gcd: h,
b_divided_by_gcd,
} = extended_euclidean_algorithm(&f, &s);
capital_by *= &h;
capital_cy *= &h;

// 4.
let l = (&y * (&b * (w1.mod_floor(&h)) + &c * (w2.mod_floor(&h)))).mod_floor(&h);
(
g,
&b * (&m / &h) + &l * (&capital_by / &h),
b_divided_by_gcd,
)
let capital_bx = &b * (&m / &h) + &l * &capital_by;
capital_by *= &h;
capital_cy *= &h;
(g, capital_bx, b_divided_by_gcd)
};

// 5. (partial xgcd)
let mut bx = capital_bx.mod_floor(&capital_by);
let mut by = capital_by.clone();

let mut x = BigInt::one();
let mut y = BigInt::zero();
let mut z = 0u32;

while by.abs() > *self.partial_gcd_limit() && !bx.is_zero() {
let (q, t) = by.div_rem(&bx);
by = bx;
bx = t;
swap(&mut x, &mut y);
x -= &q * &y;
z += 1;
}

if z.is_odd() {
by = -by;
y = -y;
}
let (bx, x, by, y, iterated) = partial_xgcd(
capital_bx.mod_floor(&capital_by),
capital_by.clone(),
self.partial_gcd_limit(),
);

let u3: BigInt;
let w3: BigInt;
let v3: BigInt;

if z == 0 {
if !iterated {
// 6.
let q = &capital_cy * &bx;
let cx = (&q - &m) / &capital_by;
let dx = (&bx * &capital_dy - w2) / &capital_by;
u3 = &by * &capital_cy;
w3 = &bx * &cx - &g * &dx;
v3 = v2 - (&q << 1);
w3 = &bx * &cx - &g * &dx;
} else {
// 7.
let cx = (&capital_cy * &bx - &m * &x) / &capital_by;
Expand All @@ -207,18 +189,17 @@ impl QuadraticForm {
};

u3 = &by * &cy - &g * &y * &dy;
w3 = &bx * &cx - &g * &x * &dx;
v3 = &g * (&q3 + &q4) - &q1 - &q2;
w3 = &bx * &cx - &g * &x * &dx;
}

let mut form = QuadraticForm {
QuadraticForm {
a: u3,
b: v3,
c: w3,
partial_gcd_limit: self.partial_gcd_limit.clone(),
};
form.reduce();
form
}
.into_reduced()
}
}

Expand All @@ -241,61 +222,72 @@ impl Doubling for QuadraticForm {
b_divided_by_gcd: capital_dy,
} = extended_euclidean_algorithm(u, v);

let mut bx = (&y * w).mod_floor(&capital_by);
let mut by = capital_by.clone();
let (bx, x, by, y, iterated) = partial_xgcd(
(&y * w).mod_floor(&capital_by),
capital_by.clone(),
self.partial_gcd_limit(),
);

let mut x = BigInt::one();
let mut y = BigInt::zero();
let mut z = 0u32;
let mut u3 = &by * &by;
let mut w3 = &bx * &bx;
let mut v3 = &u3 + &w3 - &(&bx + &by).pow(2);

while by.abs() > *self.partial_gcd_limit() && !bx.is_zero() {
let (q, t) = by.div_rem(&bx);
by = bx;
bx = t;
swap(&mut x, &mut y);
x -= &q * &y;
z += 1;
}

if z.is_odd() {
by = -by;
y = -y;
}

let mut u3: BigInt;
let mut w3: BigInt;
let mut v3: BigInt;

if z == 0 {
if !iterated {
let dx = (&bx * &capital_dy - w) / &capital_by;
u3 = &by * &by;
w3 = &bx * &bx;
let s = &bx + &by;
v3 = v - &s * &s + &u3 + &w3;
w3 = &w3 - &g * &dx;
v3 += v;
w3 -= &g * &dx;
} else {
let dx = (&bx * &capital_dy - w * &x) / &capital_by;
let q1 = &dx * &y;
let mut dy = &q1 + &capital_dy;
v3 = &g * (&dy + &q1);
dy = &dy / &x;
u3 = &by * &by;
w3 = &bx * &bx;
v3 = &v3 - (&bx + &by).pow(2) + &u3 + &w3;

u3 = &u3 - &g * &y * &dy;
w3 = &w3 - &g * &x * &dx;
v3 += &g * (&dy + &q1);
dy /= &x;

u3 -= &g * &y * &dy;
w3 -= &g * &x * &dx;
}

let mut form = QuadraticForm {
QuadraticForm {
a: u3,
b: v3,
c: w3,
partial_gcd_limit: self.partial_gcd_limit.clone(),
};
form.reduce();
form
}
.into_reduced()
}
}

/// Compute the xgcd of bx and by with a partial limit. When by is below the limit, the computation
/// stops early and returns the result. The result is a tuple (bx, x, by, y, iterated) where bx and
/// by are the now reduced coefficients, x and y are the Bezout coefficients for bx and by respectively,
/// and iterated is true if the there were any iterations.
fn partial_xgcd(
mut bx: BigInt,
mut by: BigInt,
limit: &BigInt,
) -> (BigInt, BigInt, BigInt, BigInt, bool) {
let mut x = BigInt::one();
let mut y = BigInt::zero();
let mut iterated = false;
let mut odd = false;

while by.abs() > *limit && !bx.is_zero() {
let (q, r) = by.div_rem(&bx);
by = bx;
bx = r;
swap(&mut x, &mut y);
x -= &q * &y;

odd = !odd;
iterated = true;
}

if odd {
by = -by;
y = -y;
}

(bx, x, by, y, iterated)
}

impl ParameterizedGroupElement for QuadraticForm {
Expand All @@ -309,7 +301,7 @@ impl ParameterizedGroupElement for QuadraticForm {
.expect("Doesn't fail")
}

fn has_parameter(&self, discriminant: &Discriminant) -> bool {
fn is_in_group(&self, discriminant: &Discriminant) -> bool {
discriminant.as_bigint().eq(&self.discriminant())
}
}
Expand Down
5 changes: 5 additions & 0 deletions fastcrypto-vdf/src/class_group/reduction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ impl QuadraticForm {
self.b = cs.shl(1) - &self.b;
}
}

pub(crate) fn into_reduced(mut self) -> QuadraticForm {
self.reduce();
self
}
}

#[cfg(test)]
Expand Down
19 changes: 8 additions & 11 deletions fastcrypto-vdf/src/math/extended_gcd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! divided by the GCD since these are often used, for example in the NUCOMP and NUDPL algorithms,
//! and come out for free while computing the Bezout coefficients.

use num_bigint::BigInt;
use num_bigint::{BigInt, Sign};
use num_integer::Integer;
use num_traits::{One, Signed, Zero};
use std::mem;
Expand Down Expand Up @@ -58,16 +58,8 @@ pub(crate) fn extended_euclidean_algorithm(a: &BigInt, b: &BigInt) -> EuclideanA
}

// The last coefficients are equal to +/- a / gcd(a,b) and b / gcd(a,b) respectively.
let a_divided_by_gcd = if a.sign() != s.0.sign() {
s.0.neg()
} else {
s.0
};
let b_divided_by_gcd = if b.sign() != t.0.sign() {
t.0.neg()
} else {
t.0
};
let a_divided_by_gcd = set_sign(s.0, a.sign());
let b_divided_by_gcd = set_sign(t.0, b.sign());

if !r.1.is_negative() {
EuclideanAlgorithmOutput {
Expand All @@ -88,6 +80,11 @@ pub(crate) fn extended_euclidean_algorithm(a: &BigInt, b: &BigInt) -> EuclideanA
}
}

/// Return a number with the same magnitude as `value` but with the given sign.
fn set_sign(value: BigInt, sign: Sign) -> BigInt {
BigInt::from_biguint(sign, value.into_parts().1)
}

#[test]
fn test_xgcd() {
test_xgcd_single(BigInt::from(240), BigInt::from(46));
Expand Down
2 changes: 1 addition & 1 deletion fastcrypto-vdf/src/math/parameterized_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ pub trait ParameterizedGroupElement:
fn zero(parameters: &Self::ParameterType) -> Self;

/// Returns true if this is an element of the group defined by `parameter`.
fn has_parameter(&self, parameter: &Self::ParameterType) -> bool;
fn is_in_group(&self, parameter: &Self::ParameterType) -> bool;
}
8 changes: 4 additions & 4 deletions fastcrypto-vdf/src/vdf/wesolowski/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl<
type ProofType = G;

fn evaluate(&self, input: &G) -> FastCryptoResult<(G, G)> {
if !input.has_parameter(&self.group_parameter) || self.iterations == 0 {
if !input.is_in_group(&self.group_parameter) || self.iterations == 0 {
return Err(InvalidInput);
}

Expand All @@ -84,9 +84,9 @@ impl<
}

fn verify(&self, input: &G, output: &G, proof: &G) -> FastCryptoResult<()> {
if !input.has_parameter(&self.group_parameter)
|| !output.has_parameter(&self.group_parameter)
|| !proof.has_parameter(&self.group_parameter)
if !input.is_in_group(&self.group_parameter)
|| !output.is_in_group(&self.group_parameter)
|| !proof.is_in_group(&self.group_parameter)
{
return Err(InvalidInput);
}
Expand Down
Loading