Skip to content

Commit

Permalink
feat: use cache BR for bit_reverse
Browse files Browse the repository at this point in the history
feat: parallel bit_reverse
  • Loading branch information
SuccinctPaul committed Oct 24, 2023
1 parent a9441a1 commit ca7e18e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 19 deletions.
95 changes: 76 additions & 19 deletions starky/src/fft_p.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,67 @@ use crate::constant::{get_max_workers, MAX_OPS_PER_THREAD, MIN_OPS_PER_THREAD, S
use crate::fft_worker::{fft_block, interpolate_prepare_block};
use crate::helper::log2_any;
use crate::traits::FieldExtension;
use crate::utils::parallells::parallelize;
use core::cmp::min;
use lazy_static::lazy_static;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::Mutex;

lazy_static! {
static ref BR_CACHE: Mutex<HashMap<usize, Vec<usize>>> = Mutex::new(HashMap::new());
}
pub fn BR(x: usize, domain_pow: usize) -> usize {
assert!(domain_pow <= 32);
let mut x = x;
x = (x >> 16) | (x << 16);
x = ((x & 0xFF00FF00) >> 8) | ((x & 0x00FF00FF) << 8);
x = ((x & 0xF0F0F0F0) >> 4) | ((x & 0x0F0F0F0F) << 4);
x = ((x & 0xCCCCCCCC) >> 2) | ((x & 0x33333333) << 2);
(((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (32 - domain_pow)
let cal = |x: usize, domain_pow: usize| -> usize {
let mut x = x;
x = (x >> 16) | (x << 16);
x = ((x & 0xFF00FF00) >> 8) | ((x & 0x00FF00FF) << 8);
x = ((x & 0xF0F0F0F0) >> 4) | ((x & 0x0F0F0F0F) << 4);
x = ((x & 0xCCCCCCCC) >> 2) | ((x & 0x33333333) << 2);
(((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (32 - domain_pow)
};

// get cache by domain_pow
let mut map = BR_CACHE.lock().unwrap();
let mut cache = if map.contains_key(&domain_pow) {
map.remove(&domain_pow).unwrap() // get and remove the old values.
} else {
vec![]
};
// check if need append more to cache
let cache_len = cache.len();
let n = 1 << domain_pow;
if cache_len <= n || cache_len < x {
let end = if n >= x { n } else { x };
// todo parallel
for i in cache_len..=end {
let a = cal(i, domain_pow);
cache.push(a);
}
}
let res = cache[x];
// update map with cache
map.insert(domain_pow, cache);
res
}
fn BRs(start: usize, end: usize, domain_pow: usize) -> Vec<usize> {
assert!(end > start);
// 1. obtain a useless one to precompute the cache.
// to make sure the cache existed and its len >= end.
BR(end, domain_pow);

// 2. get cache by domain_pow
let map = BR_CACHE.lock().unwrap();
let cache = if map.contains_key(&domain_pow) {
map.get(&domain_pow).unwrap()
} else {
// double check
BR(end, domain_pow);
map.get(&domain_pow).unwrap()
};

(start..end).map(|i| cache[i]).collect()
}

pub fn transpose<F: FieldExtension>(
Expand Down Expand Up @@ -44,11 +94,14 @@ pub fn bit_reverse<F: FieldExtension>(
nbits: usize,
) {
let n = 1 << nbits;
for i in 0..n {
let ri = BR(i, nbits);
for k in 0..n_pols {
buffdst[i * n_pols + k] = buffsrc[ri * n_pols + k];
}
let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache.

let len = n * n_pols;
assert_eq!(len, buffdst.len());
for j in 0..len {
let i = j / n_pols;
let k = j % n_pols;
buffdst[j] = buffsrc[ris[i] * n_pols + k];
}
}

Expand All @@ -59,9 +112,10 @@ pub fn interpolate_bit_reverse<F: FieldExtension>(
nbits: usize,
) {
let n = 1 << nbits;
let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache.

for i in 0..n {
let ri = BR(i, nbits);
let rii = (n - ri) % n;
let rii = (n - ris[i]) % n;
for k in 0..n_pols {
buffdst[i * n_pols + k] = buffsrc[rii * n_pols + k];
}
Expand All @@ -76,12 +130,15 @@ pub fn inv_bit_reverse<F: FieldExtension>(
) {
let n = 1 << nbits;
let n_inv = F::inv(&F::from(n));
for i in 0..n {
let ri = BR(i, nbits);
let rii = (n - ri) % n;
for p in 0..n_pols {
buffdst[i * n_pols + p] = buffsrc[rii * n_pols + p] * n_inv;
}
let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache.

let len = n * n_pols;
assert_eq!(len, buffdst.len());
for j in 0..len {
let i = j / n_pols;
let k = j % n_pols;
let rii = (n - ris[i]) % n;
buffdst[j] = buffsrc[rii * n_pols + k] * n_inv;
}
}

Expand Down
2 changes: 2 additions & 0 deletions starky/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![allow(clippy::needless_range_loop)]
#![allow(dead_code)]

pub mod errors;
pub mod polsarray;
Expand Down Expand Up @@ -31,6 +32,7 @@ pub mod poseidon_bls12381_opt;

pub mod merklehash;
pub mod merklehash_bls12381;

pub mod merklehash_bn128;

mod digest;
Expand Down

0 comments on commit ca7e18e

Please sign in to comment.