diff --git a/build.conf b/build.conf index 68d9fa4d7..56c5e6491 100644 --- a/build.conf +++ b/build.conf @@ -1,5 +1,5 @@ { - "use_backend": "cuda", + "use_backend": "cpu", "use_lib": false, "lib_dir": "/usr/local/lib", @@ -7,7 +7,7 @@ "build_type": "Release", "build_threads": "4", - "build_cuda": "ON", + "build_cuda": "OFF", "build_opencl": "ON", "build_cpu": "ON", "build_examples": "OFF", @@ -28,7 +28,7 @@ "glew_dir": "E:\\Libraries\\GLEW", "glfw_dir": "E:\\Libraries\\glfw3", "boost_dir": "E:\\Libraries\\boost_1_56_0", - + "cuda_sdk": "/usr/local/cuda", "opencl_sdk": "/usr", "sdk_lib_dir": "lib" diff --git a/examples/helloworld.rs b/examples/helloworld.rs index fbd831225..a5dfb6719 100644 --- a/examples/helloworld.rs +++ b/examples/helloworld.rs @@ -3,6 +3,7 @@ extern crate arrayfire as af; use af::Dim4; use af::Array; +#[allow(unused_must_use)] fn main() { af::set_device(0); af::info(); @@ -14,10 +15,9 @@ fn main() { af::print(&a); println!("Element-wise arithmetic"); - let sin_res = af::sin(&a).unwrap(); - let cos_res = af::cos(&a).unwrap(); - let b = &sin_res + 1.5; - let b2 = &sin_res + &cos_res; + let b = af::add(af::sin(&a), 1.5).unwrap(); + let b2 = af::add(af::sin(&a), af::cos(&a)).unwrap(); + let b3 = ! &a; println!("sin(a) + 1.5 => "); af::print(&b); println!("sin(a) + cos(a) => "); af::print(&b2); diff --git a/src/arith/mod.rs b/src/arith/mod.rs index 55133fab0..2a655077a 100644 --- a/src/arith/mod.rs +++ b/src/arith/mod.rs @@ -1,10 +1,11 @@ extern crate libc; extern crate num; +use dim4::Dim4; use array::Array; use defines::AfError; use self::libc::{c_int}; -use data::constant; +use data::{constant, tile}; use self::num::Complex; type MutAfArray = *mut self::libc::c_longlong; @@ -182,32 +183,100 @@ macro_rules! binary_func { ) } -binary_func!(add, af_add); -binary_func!(sub, af_sub); -binary_func!(mul, af_mul); -binary_func!(div, af_div); -binary_func!(rem, af_rem); binary_func!(bitand, af_bitand); binary_func!(bitor, af_bitor); binary_func!(bitxor, af_bitxor); -binary_func!(shiftl, af_bitshiftl); -binary_func!(shiftr, af_bitshiftr); -binary_func!(lt, af_lt); -binary_func!(gt, af_gt); -binary_func!(le, af_le); -binary_func!(ge, af_ge); -binary_func!(eq, af_eq); binary_func!(neq, af_neq); binary_func!(and, af_and); binary_func!(or, af_or); binary_func!(minof, af_minof); binary_func!(maxof, af_maxof); -binary_func!(modulo, af_mod); binary_func!(hypot, af_hypot); -binary_func!(atan2, af_atan2); -binary_func!(cplx2, af_cplx2); -binary_func!(root, af_root); -binary_func!(pow, af_pow); + +pub trait Convertable { + fn convert(&self) -> Array; +} + +macro_rules! convertable_type_def { + ($rust_type: ty) => ( + impl Convertable for $rust_type { + fn convert(&self) -> Array { + constant(*self, Dim4::new(&[1,1,1,1])).unwrap() + } + } + ) +} + +convertable_type_def!(f64); +convertable_type_def!(f32); +convertable_type_def!(i32); +convertable_type_def!(u32); +convertable_type_def!(u8); + +impl Convertable for Array { + fn convert(&self) -> Array { + self.clone() + } +} + +impl Convertable for Result { + fn convert(&self) -> Array { + self.clone().unwrap() + } +} + +macro_rules! overloaded_binary_func { + ($fn_name: ident, $help_name: ident, $ffi_name: ident) => ( + fn $help_name(lhs: &Array, rhs: &Array) -> Result { + unsafe { + let mut temp: i64 = 0; + let err_val = $ffi_name(&mut temp as MutAfArray, + lhs.get() as AfArray, rhs.get() as AfArray, + 0); + match err_val { + 0 => Ok(Array::from(temp)), + _ => Err(AfError::from(err_val)), + } + } + } + + pub fn $fn_name (arg1: T, arg2: U) -> Result { + let lhs = arg1.convert(); + let rhs = arg2.convert(); + match (lhs.is_scalar().unwrap(), rhs.is_scalar().unwrap()) { + ( true, false) => { + let l = tile(&lhs, rhs.dims().unwrap()).unwrap(); + $help_name(&l, &rhs) + }, + (false, true) => { + let r = tile(&rhs, lhs.dims().unwrap()).unwrap(); + $help_name(&lhs, &r) + }, + _ => $help_name(&lhs, &rhs), + } + } + ) +} + +// thanks to Umar Arshad for the idea on how to +// implement overloaded function +overloaded_binary_func!(add, add_helper, af_add); +overloaded_binary_func!(sub, sub_helper, af_sub); +overloaded_binary_func!(mul, mul_helper, af_mul); +overloaded_binary_func!(div, div_helper, af_div); +overloaded_binary_func!(rem, rem_helper, af_rem); +overloaded_binary_func!(shiftl, shiftl_helper, af_bitshiftl); +overloaded_binary_func!(shiftr, shiftr_helper, af_bitshiftr); +overloaded_binary_func!(lt, lt_helper, af_lt); +overloaded_binary_func!(gt, gt_helper, af_gt); +overloaded_binary_func!(le, le_helper, af_le); +overloaded_binary_func!(ge, ge_helper, af_ge); +overloaded_binary_func!(eq, eq_helper, af_eq); +overloaded_binary_func!(modulo, modulo_helper, af_mod); +overloaded_binary_func!(atan2, atan2_helper, af_atan2); +overloaded_binary_func!(cplx2, cplx2_helper, af_cplx2); +overloaded_binary_func!(root, root_helper, af_root); +overloaded_binary_func!(pow, pow_helper, af_pow); macro_rules! arith_scalar_func { ($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => ( diff --git a/src/array.rs b/src/array.rs index 68d320446..18ead3671 100644 --- a/src/array.rs +++ b/src/array.rs @@ -57,6 +57,8 @@ extern { fn af_retain_array(out: MutAfArray, arr: AfArray) -> c_int; + fn af_copy_array(out: MutAfArray, arr: AfArray) -> c_int; + fn af_release_array(arr: AfArray) -> c_int; fn af_print_array(arr: AfArray) -> c_int; @@ -171,6 +173,17 @@ impl Array { } } + pub fn copy(&self) -> Result { + unsafe { + let mut temp: i64 = 0; + let err_val = af_copy_array(&mut temp as MutAfArray, self.handle as AfArray); + match err_val { + 0 => Ok(Array::from(temp)), + _ => Err(AfError::from(err_val)), + } + } + } + is_func!(is_empty, af_is_empty); is_func!(is_scalar, af_is_scalar); is_func!(is_row, af_is_row); diff --git a/src/data/mod.rs b/src/data/mod.rs index d4efc4b4e..ae93b92b8 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -133,12 +133,16 @@ impl ConstGenerator for Complex { #[allow(unused_mut)] impl ConstGenerator for bool { - fn generate(&self, dims: Dim4) -> Array { + fn generate(&self, dims: Dim4) -> Result { unsafe { let mut temp: i64 = 0; - af_constant(&mut temp as MutAfArray, *self as c_int as c_double, - dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT, 4); - Array::from(temp) + let err_val = af_constant(&mut temp as MutAfArray, *self as c_int as c_double, + dims.ndims() as c_uint, + dims.get().as_ptr() as *const DimT, 4); + match err_val { + 0 => Ok(Array::from(temp)), + _ => Err(AfError::from(err_val)), + } } } } diff --git a/src/dim4.rs b/src/dim4.rs index 5abe3c848..e63455c3d 100644 --- a/src/dim4.rs +++ b/src/dim4.rs @@ -39,7 +39,7 @@ impl Dim4 { let nelems = self.elements(); match nelems { 0 => 0, - 1 => 0, + 1 => 1, _ => { if self.dims[3] != 1 { 4 } else if self.dims[2] != 1 { 3 }