Skip to content

Commit

Permalink
impl mul by i32
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenfeizhang committed Sep 12, 2024
1 parent 9c00ff9 commit 909e430
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 10 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ ark-ff = { version = "0.4" }
bytes = "1.6.0"
chrono = "0.4.38"
clap = { version = "4.1", features = ["derive"] }
cfg-if = "1.0"
criterion = { version = "0.5", features = ["html_reports"] }
env_logger = "0.11.3"
halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [
Expand All @@ -58,6 +59,7 @@ halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-feat
itertools = "0.13"
log = "0.4"
rand = "0.8.5"
raw-cpuid = "11.1.0"
rayon = "1.10"
sha2 = "0.10.8"
tiny-keccak = { version = "2.0.2", features = [ "sha3" ] }
Expand Down
7 changes: 3 additions & 4 deletions arith/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@ edition = "2021"

[dependencies]
ark-std.workspace = true
cfg-if.workspace = true
ethnum.workspace = true
halo2curves.workspace = true
log.workspace = true
rand.workspace = true
raw-cpuid.workspace = true
sha2.workspace = true
thiserror.workspace = true
ethnum.workspace = true

raw-cpuid = "11.1.0"
cfg-if = "1.0"

[dev-dependencies]
tynm.workspace = true
Expand Down
40 changes: 34 additions & 6 deletions arith/benches/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use arith::{Field, GF2_128x8, GF2x8, M31Ext3, M31Ext3x16, M31x16, GF2, GF2_128, M31};
#[cfg(target_arch = "x86_64")]
use arith::{GF2_128x8_256, M31x16_256};
use ark_std::rand::RngCore;
use ark_std::test_rng;
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use halo2curves::bn256::Fr;
Expand Down Expand Up @@ -173,13 +174,40 @@ pub(crate) fn bench_field<F: Field>(c: &mut Criterion) {
);
}

/// Benchmarks 100 chained `mul_by_i32` calls on a random element of `F`.
///
/// Both the field element and the `i32` multiplier are produced in the
/// `iter_batched` setup closure, so only the multiplications themselves are
/// inside the timed routine.
///
/// The last component of the benchmark label is `F::SIZE * 8 / F::FIELD_SIZE`
/// (presumably the number of packed elements per value -- confirm against the
/// `Field` trait).
fn bench_mul_i32<F: Field>(c: &mut Criterion) {
    let mut rng = test_rng();

    c.bench_function(
        &format!(
            "mul-i32<{}> 100x times {}x",
            type_name::<F>(),
            F::SIZE * 8 / F::FIELD_SIZE
        ),
        |bencher| {
            bencher.iter_batched(
                // Setup (untimed): a fresh element plus a random multiplier.
                // The `as i32` cast deliberately reinterprets the random bits so
                // that negative multipliers are exercised as well.
                || (random_element::<F>(), rng.next_u32() as i32),
                // Routine (timed): feed each product back in so the calls form a
                // dependency chain that cannot be collapsed into one.
                |(mut x, scalar)| {
                    for _ in 0..100 {
                        x = x.mul_by_i32(scalar);
                    }
                    x
                },
                BatchSize::SmallInput,
            )
        },
    );
}

fn criterion_benchmark(c: &mut Criterion) {
bench_field::<M31>(c);
bench_field::<M31x16>(c);
#[cfg(target_arch = "x86_64")]
bench_field::<M31x16_256>(c);
bench_field::<M31Ext3>(c);
bench_field::<M31Ext3x16>(c);
bench_mul_i32::<Fr>(c);
// bench_field::<M31>(c);
// bench_field::<M31x16>(c);
// #[cfg(target_arch = "x86_64")]
// bench_field::<M31x16_256>(c);
// bench_field::<M31Ext3>(c);
// bench_field::<M31Ext3x16>(c);
bench_field::<Fr>(c);
bench_field::<GF2>(c);
bench_field::<GF2x8>(c);
Expand Down
5 changes: 5 additions & 0 deletions arith/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ pub trait Field:
let t = self.mul_by_3();
t + t
}

/// Multiplies `self` by the signed 32-bit integer `_b`.
///
/// This default body provides no implementation; a field type has to
/// override the method before it can be called on that type.
///
/// # Panics
///
/// Always panics (via `unimplemented!`) unless overridden.
#[inline(always)]
fn mul_by_i32(&self, _b: i32) -> Self {
unimplemented!("not supported for this field")
}
}

pub trait FieldForECC: Field + Hash + Eq + PartialOrd + Ord {
Expand Down
69 changes: 69 additions & 0 deletions arith/src/field/bn254.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::io::{Read, Write};
use std::mem::transmute;

use halo2curves::ff::{Field as Halo2Field, FromUniformBytes};
use halo2curves::{bn256::Fr, ff::PrimeField};
Expand Down Expand Up @@ -108,6 +109,74 @@ impl Field for Fr {
.unwrap(),
)
}

/// Multiplies `self` by a small signed integer without converting the integer
/// to Montgomery form first.
///
/// With `R = 2^256` and `p` the BN254 scalar-field modulus, an element `x` is
/// stored as `x_mont = x * R mod p`. For a plain integer `k`:
///
/// ```text
/// (x * k)_mont = (x * k) * R mod p = (x_mont * k) mod p
/// ```
///
/// so one 256x64-bit integer product followed by a reduction **modulo `p`**
/// yields the Montgomery form of the product. Truncating the product to
/// 256 bits (i.e. reducing modulo `2^256`) is *not* sufficient: it produces
/// limbs that are not the canonical representation of the result.
///
/// The reduction uses a single quotient estimate taken from the top limbs and
/// one conditional subtraction. NOTE(review): whether this beats a full
/// Montgomery multiplication should be confirmed with the `mul-i32` benchmark.
#[inline]
fn mul_by_i32(&self, rhs: i32) -> Self {
    /// BN254 scalar-field modulus `p`, little-endian 64-bit limbs.
    const MODULUS: [u64; 4] = [
        0x43e1_f593_f000_0001,
        0x2833_e848_79b9_7091,
        0xb850_45b6_8181_585d,
        0x3064_4e72_e131_a029,
    ];

    /// Full-width product of a 256-bit value and a 64-bit word (320 bits).
    fn mul_by_word(limbs: &[u64; 4], word: u64) -> [u64; 5] {
        let mut out = [0u64; 5];
        let mut carry = 0u64;
        for (dst, &limb) in out[..4].iter_mut().zip(limbs) {
            // (2^64 - 1)^2 + (2^64 - 1) < 2^128, so this cannot overflow.
            let wide = u128::from(limb) * u128::from(word) + u128::from(carry);
            *dst = wide as u64; // low 64 bits
            carry = (wide >> 64) as u64; // high 64 bits
        }
        out[4] = carry;
        out
    }

    /// `minuend - subtrahend` modulo `2^256`, plus the final borrow flag.
    fn sub_limbs(minuend: &[u64; 4], subtrahend: &[u64; 4]) -> ([u64; 4], bool) {
        let mut out = [0u64; 4];
        let mut borrow = false;
        for ((dst, &m), &s) in out.iter_mut().zip(minuend).zip(subtrahend) {
            let (partial, borrow_a) = m.overflowing_sub(s);
            let (diff, borrow_b) = partial.overflowing_sub(u64::from(borrow));
            *dst = diff;
            borrow = borrow_a || borrow_b;
        }
        (out, borrow)
    }

    // `unsigned_abs` is well defined for `i32::MIN`, unlike `-rhs`.
    let magnitude = u64::from(rhs.unsigned_abs());

    // SAFETY: relies on `Fr` being laid out as four little-endian `u64`
    // Montgomery limbs holding a value below `p` -- the same assumption the
    // previous implementation made; `transmute` checks the size at compile time.
    let limbs = unsafe { transmute::<Fr, [u64; 4]>(*self) };

    // product = x_mont * |rhs| < p * 2^31 < 2^285.
    let product = mul_by_word(&limbs, magnitude);

    // Estimate q = floor(product / p) from the top limbs. Dividing by
    // (top limb of p) + 1 can only under-estimate, and with the bounds above
    // the estimate is short by at most one, so 0 <= product - q * p < 2p.
    let top = (u128::from(product[4]) << 64) | u128::from(product[3]);
    let quotient = (top / (u128::from(MODULUS[3]) + 1)) as u64; // < 2^32 here
    let multiple = mul_by_word(&MODULUS, quotient);

    // The true difference fits in 256 bits, so subtracting the low four limbs
    // (and ignoring the matching top limbs / final borrow) is exact.
    let product_low = [product[0], product[1], product[2], product[3]];
    let multiple_low = [multiple[0], multiple[1], multiple[2], multiple[3]];
    let (remainder, _) = sub_limbs(&product_low, &multiple_low);

    // One conditional subtraction brings the value into [0, p).
    let (reduced, underflow) = sub_limbs(&remainder, &MODULUS);
    let canonical = if underflow { remainder } else { reduced };

    // SAFETY: same layout assumption as above; `canonical` is below `p`.
    let result = unsafe { transmute::<[u64; 4], Fr>(canonical) };

    if rhs < 0 {
        -result
    } else {
        result
    }
}
}

impl FieldForECC for Fr {
Expand Down
29 changes: 29 additions & 0 deletions arith/src/tests/bn254.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use ark_std::test_rng;
use halo2curves::bn256::Fr;
use halo2curves::ff::Field as Halo2Field;
use rand::RngCore;

use super::field::{random_field_tests, random_inversion_tests};
use crate::Field;

#[test]
fn test_field() {
Expand All @@ -10,3 +13,29 @@ fn test_field() {
let mut rng = test_rng();
random_inversion_tests::<Fr, _>(&mut rng, "bn254::Fr".to_string());
}

/// Checks `Fr::mul_by_i32` against ordinary field multiplication.
///
/// Covers the boundary multipliers (zero, +/-1, `i32::MAX`, `i32::MIN`) as well
/// as a batch of random ones, each paired with a fresh random field element.
#[test]
fn test_mul_by_i32() {
    let mut rng = test_rng();

    let mut scalars = vec![0, 1, -1, 2, -2, i32::MAX, i32::MIN];
    // Reinterpreting the random bits as `i32` yields both signs.
    scalars.extend((0..100).map(|_| rng.next_u32() as i32));

    for b in scalars {
        let a = Fr::random(&mut rng);

        // `unsigned_abs` avoids the overflow `-b` would hit for `i32::MIN`.
        let magnitude = Fr::from(u64::from(b.unsigned_abs()));
        let b_fr = if b < 0 { -magnitude } else { magnitude };

        assert_eq!(
            a.mul_by_i32(b),
            a * b_fr,
            "mul_by_i32 disagrees with field multiplication for b = {b}"
        );
    }
}

0 comments on commit 909e430

Please sign in to comment.