diff --git a/examples/helloworld.rs b/examples/helloworld.rs index e6dff241a..7d342a450 100644 --- a/examples/helloworld.rs +++ b/examples/helloworld.rs @@ -18,13 +18,13 @@ fn main() { println!("Element-wise arithmetic"); let b = sin(&a) - .and_then(|x| add(&x, &1.5)) + .and_then(|x| add(&x, &1.5, false)) .unwrap(); let b2 = sin(&a). and_then(|x| { cos(&a) - .and_then(|y| add(&x, &y)) + .and_then(|y| add(&x, &y, false)) }) .unwrap(); diff --git a/examples/histogram.rs b/examples/histogram.rs index 0984740d5..5a32b6820 100644 --- a/examples/histogram.rs +++ b/examples/histogram.rs @@ -33,7 +33,7 @@ fn main() { let disp_img = man.dims() .and_then(|x| constant(255 as f32, x)) - .and_then(|x| div(&man, &x)) + .and_then(|x| div(&man, &x, false)) .unwrap(); loop { diff --git a/examples/pi.rs b/examples/pi.rs index 1fb4ef520..9ec60d61d 100644 --- a/examples/pi.rs +++ b/examples/pi.rs @@ -18,9 +18,9 @@ fn main() { let start = PreciseTime::now(); for bench_iter in 0..100 { - let pi_val = add(&mul(x, x).unwrap(), &mul(y, y).unwrap()) + let pi_val = add(&mul(x, x, false).unwrap(), &mul(y, y, false).unwrap(), false) .and_then( |z| sqrt(&z) ) - .and_then( |z| le(&z, &constant(1, dims).unwrap()) ) + .and_then( |z| le(&z, &constant(1, dims).unwrap(), false) ) .and_then( |z| sum_all(&z) ) .map( |z| z.0 * 4.0/(samples as f64) ) .unwrap(); diff --git a/src/arith/mod.rs b/src/arith/mod.rs index 5c085d6a5..a4409efa6 100644 --- a/src/arith/mod.rs +++ b/src/arith/mod.rs @@ -223,12 +223,12 @@ impl Convertable for Array { macro_rules! overloaded_binary_func { ($fn_name: ident, $help_name: ident, $ffi_name: ident) => ( - fn $help_name(lhs: &Array, rhs: &Array) -> Result { + fn $help_name(lhs: &Array, rhs: &Array, batch: bool) -> Result { unsafe { let mut temp: i64 = 0; let err_val = $ffi_name(&mut temp as MutAfArray, lhs.get() as AfArray, rhs.get() as AfArray, - 0); + batch as c_int); match err_val { 0 => Ok(Array::from(temp)), _ => Err(AfError::from(err_val)), @@ -236,19 +236,19 @@ macro_rules! overloaded_binary_func { } } - pub fn $fn_name (arg1: &T, arg2: &U) -> Result { + pub fn $fn_name (arg1: &T, arg2: &U, batch: bool) -> Result { let lhs = arg1.convert(); let rhs = arg2.convert(); match (lhs.is_scalar().unwrap(), rhs.is_scalar().unwrap()) { ( true, false) => { let l = tile(&lhs, rhs.dims().unwrap()).unwrap(); - $help_name(&l, &rhs) + $help_name(&l, &rhs, batch) }, (false, true) => { let r = tile(&rhs, lhs.dims().unwrap()).unwrap(); - $help_name(&lhs, &r) + $help_name(&lhs, &r, batch) }, - _ => $help_name(&lhs, &rhs), + _ => $help_name(&lhs, &rhs, batch), } } )