Skip to content

Commit c15abc1

Browse files
committed
Merge pull request #30 from arrayfire/bugfix/allow_batch_operations
allow batch operands on binary functions
2 parents 8cbe77e + ce828cf commit c15abc1

File tree

4 files changed

+11
-11
lines changed

4 files changed

+11
-11
lines changed

examples/helloworld.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ fn main() {
1818

1919
println!("Element-wise arithmetic");
2020
let b = sin(&a)
21-
.and_then(|x| add(&x, &1.5))
21+
.and_then(|x| add(&x, &1.5, false))
2222
.unwrap();
2323

2424
let b2 = sin(&a).
2525
and_then(|x| {
2626
cos(&a)
27-
.and_then(|y| add(&x, &y))
27+
.and_then(|y| add(&x, &y, false))
2828
})
2929
.unwrap();
3030

examples/histogram.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ fn main() {
3333

3434
let disp_img = man.dims()
3535
.and_then(|x| constant(255 as f32, x))
36-
.and_then(|x| div(&man, &x))
36+
.and_then(|x| div(&man, &x, false))
3737
.unwrap();
3838

3939
loop {

examples/pi.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ fn main() {
1818
let start = PreciseTime::now();
1919

2020
for bench_iter in 0..100 {
21-
let pi_val = add(&mul(x, x).unwrap(), &mul(y, y).unwrap())
21+
let pi_val = add(&mul(x, x, false).unwrap(), &mul(y, y, false).unwrap(), false)
2222
.and_then( |z| sqrt(&z) )
23-
.and_then( |z| le(&z, &constant(1, dims).unwrap()) )
23+
.and_then( |z| le(&z, &constant(1, dims).unwrap(), false) )
2424
.and_then( |z| sum_all(&z) )
2525
.map( |z| z.0 * 4.0/(samples as f64) )
2626
.unwrap();

src/arith/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,32 +223,32 @@ impl Convertable for Array {
223223

224224
macro_rules! overloaded_binary_func {
225225
($fn_name: ident, $help_name: ident, $ffi_name: ident) => (
226-
fn $help_name(lhs: &Array, rhs: &Array) -> Result<Array, AfError> {
226+
fn $help_name(lhs: &Array, rhs: &Array, batch: bool) -> Result<Array, AfError> {
227227
unsafe {
228228
let mut temp: i64 = 0;
229229
let err_val = $ffi_name(&mut temp as MutAfArray,
230230
lhs.get() as AfArray, rhs.get() as AfArray,
231-
0);
231+
batch as c_int);
232232
match err_val {
233233
0 => Ok(Array::from(temp)),
234234
_ => Err(AfError::from(err_val)),
235235
}
236236
}
237237
}
238238

239-
pub fn $fn_name<T: Convertable, U: Convertable> (arg1: &T, arg2: &U) -> Result<Array, AfError> {
239+
pub fn $fn_name<T: Convertable, U: Convertable> (arg1: &T, arg2: &U, batch: bool) -> Result<Array, AfError> {
240240
let lhs = arg1.convert();
241241
let rhs = arg2.convert();
242242
match (lhs.is_scalar().unwrap(), rhs.is_scalar().unwrap()) {
243243
( true, false) => {
244244
let l = tile(&lhs, rhs.dims().unwrap()).unwrap();
245-
$help_name(&l, &rhs)
245+
$help_name(&l, &rhs, batch)
246246
},
247247
(false, true) => {
248248
let r = tile(&rhs, lhs.dims().unwrap()).unwrap();
249-
$help_name(&lhs, &r)
249+
$help_name(&lhs, &r, batch)
250250
},
251-
_ => $help_name(&lhs, &rhs),
251+
_ => $help_name(&lhs, &rhs, batch),
252252
}
253253
}
254254
)

0 commit comments

Comments
 (0)