Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/uint/boxed/mul.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! [`BoxedUint`] multiplication operations.

use crate::{uint::mul::mul_limbs, BoxedUint, CheckedMul, Limb, WideningMul, Wrapping, Zero};
use crate::{
uint::mul::{diagonal_mul_limbs, half_mul_limbs, mul_limbs},
BoxedUint, CheckedMul, Limb, WideningMul, Wrapping, Zero,
};
use core::ops::{Mul, MulAssign};
use subtle::{Choice, CtOption};

Expand All @@ -9,9 +12,9 @@ impl BoxedUint {
///
/// Returns a widened output with a limb count equal to the sums of the input limb counts.
pub fn mul(&self, rhs: &Self) -> Self {
let mut limbs = vec![Limb::ZERO; self.nlimbs() + rhs.nlimbs()];
mul_limbs(&self.limbs, &rhs.limbs, &mut limbs);
limbs.into()
let mut out = vec![Limb::ZERO; self.nlimbs() + rhs.nlimbs()];
mul_limbs(&self.limbs, &rhs.limbs, &mut out);
out.into()
}

/// Perform wrapping multiplication, wrapping to the width of `self`.
Expand All @@ -21,8 +24,11 @@ impl BoxedUint {

/// Multiply `self` by itself.
pub fn square(&self) -> Self {
// TODO(tarcieri): more optimized implementation (shared with `Uint`?)
self.mul(self)
let mut out = Self::from(vec![Limb::ZERO; self.nlimbs() * 2]);
half_mul_limbs(&self.limbs, &mut out.limbs);
out <<= 1;
diagonal_mul_limbs(&self.limbs, &mut out.limbs);
out
}
}

Expand Down
156 changes: 94 additions & 62 deletions src/uint/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use subtle::CtOption;
/// which will also be reused for `BoxedUint`.
// TODO(tarcieri): change this into a `const fn` when `const_mut_refs` is stable
macro_rules! impl_schoolbook_multiplication {
($lhs:expr, $rhs:expr, $lo:expr, $hi:expr) => {{
($lhs:expr, $rhs:expr, $lo:expr, $hi:expr) => {
let mut i = 0;
while i < $lhs.len() {
let mut j = 0;
Expand Down Expand Up @@ -47,7 +47,80 @@ macro_rules! impl_schoolbook_multiplication {
}
i += 1;
}
}};
};
}

/// Impl multiplication considering half of the grid, as used by the squaring algorithm.
///
/// Like [`impl_schoolbook_multiplication`], this is implemented as a macro to abstract over
/// `const fn` and boxed use cases.
macro_rules! impl_half_multiply {
($input:expr, $lo:expr, $hi:expr) => {
// Schoolbook multiplication, but only considering half of the multiplication grid
let mut i = 1;
while i < $input.len() {
let mut j = 0;
let mut carry = Limb::ZERO;

while j < i {
let k = i + j;

if k >= $input.len() {
let (n, c) = $hi[k - $input.len()].mac($input[i], $input[j], carry);
$hi[k - $input.len()] = n;
carry = c;
} else {
let (n, c) = $lo[k].mac($input[i], $input[j], carry);
$lo[k] = n;
carry = c;
}

j += 1;
}

if (2 * i) < $input.len() {
$lo[2 * i] = carry;
} else {
$hi[2 * i - $input.len()] = carry;
}

i += 1;
}
};
}

/// Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
///
/// This is used by the squaring algorithm, and impl'd as a macro to abstract over `const fn` and
/// boxed use cases. It's intended to be used in conjucntion with [`impl_half_multiply`].
macro_rules! impl_diagonal_multiply {
($input:expr, $lo:expr, $hi:expr) => {
let mut carry = Limb::ZERO;
let mut i = 0;
while i < $input.len() {
if (i * 2) < $input.len() {
let (n, c) = $lo[i * 2].mac($input[i], $input[i], carry);
$lo[i * 2] = n;
carry = c;
} else {
let (n, c) = $hi[i * 2 - $input.len()].mac($input[i], $input[i], carry);
$hi[i * 2 - $input.len()] = n;
carry = c;
}

if (i * 2 + 1) < $input.len() {
let n = $lo[i * 2 + 1].0 as WideWord + carry.0 as WideWord;
$lo[i * 2 + 1] = Limb(n as Word);
carry = Limb((n >> Word::BITS) as Word);
} else {
let n = $hi[i * 2 + 1 - $input.len()].0 as WideWord + carry.0 as WideWord;
$hi[i * 2 + 1 - $input.len()] = Limb(n as Word);
carry = Limb((n >> Word::BITS) as Word);
}

i += 1;
}
};
}

impl<const LIMBS: usize> Uint<LIMBS> {
Expand Down Expand Up @@ -101,69 +174,12 @@ impl<const LIMBS: usize> Uint<LIMBS> {
// by the original author Sam Kumar: https://github.com/RustCrypto/crypto-bigint/pull/133#discussion_r1056870411
let mut lo = Self::ZERO;
let mut hi = Self::ZERO;

// Schoolbook multiplication, but only considering half of the multiplication grid
let mut i = 1;
while i < LIMBS {
let mut j = 0;
let mut carry = Limb::ZERO;

while j < i {
let k = i + j;

if k >= LIMBS {
let (n, c) = hi.limbs[k - LIMBS].mac(self.limbs[i], self.limbs[j], carry);
hi.limbs[k - LIMBS] = n;
carry = c;
} else {
let (n, c) = lo.limbs[k].mac(self.limbs[i], self.limbs[j], carry);
lo.limbs[k] = n;
carry = c;
}

j += 1;
}

if (2 * i) < LIMBS {
lo.limbs[2 * i] = carry;
} else {
hi.limbs[2 * i - LIMBS] = carry;
}

i += 1;
}
impl_half_multiply!(self.limbs, lo.limbs, hi.limbs);

// Double the current result, this accounts for the other half of the multiplication grid.
// TODO: The top word is empty so we can also use a special purpose shl.
(lo, hi) = Self::shl_vartime_wide((lo, hi), 1);

// Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
let mut carry = Limb::ZERO;
let mut i = 0;
while i < LIMBS {
if (i * 2) < LIMBS {
let (n, c) = lo.limbs[i * 2].mac(self.limbs[i], self.limbs[i], carry);
lo.limbs[i * 2] = n;
carry = c;
} else {
let (n, c) = hi.limbs[i * 2 - LIMBS].mac(self.limbs[i], self.limbs[i], carry);
hi.limbs[i * 2 - LIMBS] = n;
carry = c;
}

if (i * 2 + 1) < LIMBS {
let n = lo.limbs[i * 2 + 1].0 as WideWord + carry.0 as WideWord;
lo.limbs[i * 2 + 1] = Limb(n as Word);
carry = Limb((n >> Word::BITS) as Word);
} else {
let n = hi.limbs[i * 2 + 1 - LIMBS].0 as WideWord + carry.0 as WideWord;
hi.limbs[i * 2 + 1 - LIMBS] = Limb(n as Word);
carry = Limb((n >> Word::BITS) as Word);
}

i += 1;
}

impl_diagonal_multiply!(self.limbs, lo.limbs, hi.limbs);
(lo, hi)
}
}
Expand Down Expand Up @@ -361,14 +377,30 @@ where
}
}

/// Wrapper function used by `BoxedUint`
/// Wrapper function for schoolbook multiplication used by `BoxedUint`.
#[cfg(feature = "alloc")]
pub(crate) fn mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) {
debug_assert_eq!(lhs.len() + rhs.len(), out.len());
let (lo, hi) = out.split_at_mut(lhs.len());
impl_schoolbook_multiplication!(lhs, rhs, lo, hi);
}

/// Wrapper function for half multiplication used by `BoxedUint` to implement squarings.
#[cfg(feature = "alloc")]
pub(crate) fn half_mul_limbs(input: &[Limb], output: &mut [Limb]) {
debug_assert_eq!(input.len() * 2, output.len());
let (lo, hi) = output.split_at_mut(input.len());
impl_half_multiply!(input, lo, hi);
}

/// Wrapper function for diagonal multiplication used by `BoxedUint` to implement squarings.
#[cfg(feature = "alloc")]
pub(crate) fn diagonal_mul_limbs(input: &[Limb], output: &mut [Limb]) {
debug_assert_eq!(input.len() * 2, output.len());
let (lo, hi) = output.split_at_mut(input.len());
impl_diagonal_multiply!(input, lo, hi);
}

#[cfg(test)]
mod tests {
use crate::{CheckedMul, Zero, U128, U192, U256, U64};
Expand Down
10 changes: 9 additions & 1 deletion tests/boxed_uint_proptests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ proptest! {
}

#[test]
fn mul_wide(a in uint(), b in uint()) {
fn mul(a in uint(), b in uint()) {
let a_bi = to_biguint(&a);
let b_bi = to_biguint(&b);

Expand All @@ -157,6 +157,14 @@ proptest! {
prop_assert_eq!(expected, to_biguint(&actual));
}

#[test]
fn square(n in uint()) {
let n_bi = to_biguint(&n);
let expected = &n_bi * &n_bi;
let actual = n.square();
prop_assert_eq!(expected, to_biguint(&actual));
}

#[test]
fn rem_vartime((a, b) in uint_pair()) {
if bool::from(!b.is_zero()) {
Expand Down