diff --git a/examples/helloworld.rs b/examples/helloworld.rs index 190cb5162..6ed726019 100644 --- a/examples/helloworld.rs +++ b/examples/helloworld.rs @@ -71,6 +71,12 @@ fn main() { print(&row(&a, num_rows - 1).unwrap()); print(&col(&a, num_cols - 1).unwrap()); + println!("Set last row to 1's"); + let r_dims = Dim4::new(&[3, 1, 1, 1]); + let r_input: [f32; 3] = [1.0, 1.0, 1.0]; + let r = Array::new(r_dims, &r_input, Aftype::F32).unwrap(); + print(&set_row(&a, &r, num_rows - 1).unwrap()); + println!("Create 2-by-3 matrix from host data"); let d_dims = Dim4::new(&[2, 3, 1, 1]); let d_input: [i32; 6] = [1, 2, 3, 4, 5, 6]; diff --git a/src/index.rs b/src/index.rs index 003ca3689..93bdf82e3 100644 --- a/src/index.rs +++ b/src/index.rs @@ -114,26 +114,81 @@ pub fn index(input: &Array, seqs: &[Seq]) -> Result { } } +#[allow(dead_code)] pub fn row(input: &Array, row_num: u64) -> Result { - let dims_err = input.dims(); - let dims = match dims_err { - Ok(dim) => dim.clone(), - Err(e) => panic!("Error unwrapping dims in row(): {}", e), - }; - index(input, &[Seq::new(row_num as f64, row_num as f64, 1.0) - ,Seq::new(0.0, dims[1] as f64 - 1.0, 1.0)]) + , Seq::default()]) +} + +#[allow(dead_code)] +pub fn set_row(input: &Array, new_row: &Array, row_num: u64) -> Result { + assign_seq(input, &[Seq::new(row_num as f64, row_num as f64, 1.0), Seq::default()] + , new_row) +} + +#[allow(dead_code)] +pub fn rows(input: &Array, first: u64, last: u64) -> Result { + index(input, &[Seq::new(first as f64, last as f64, 1.0), Seq::default()]) } +#[allow(dead_code)] +pub fn set_rows(input: &Array, new_rows: &Array, first: u64, last: u64) -> Result { + assign_seq(input, &[Seq::new(first as f64, last as f64, 1.0), Seq::default()] + , new_rows) +} + +#[allow(dead_code)] pub fn col(input: &Array, col_num: u64) -> Result { - let dims_err = input.dims(); - let dims = match dims_err { - Ok(dim) => dim.clone(), - Err(e) => panic!("Error unwrapping dims in row(): {}", e), - }; + index(input, &[Seq::default() + , Seq::new(col_num as f64, col_num as f64, 1.0)]) +} + +#[allow(dead_code)] +pub fn set_col(input: &Array, new_col: &Array, col_num: u64) -> Result { + assign_seq(input, &[Seq::default(), Seq::new(col_num as f64, col_num as f64, 1.0)] + , new_col) +} - index(input, &[Seq::new(0.0, dims[0] as f64 - 1.0, 1.0) - ,Seq::new(col_num as f64, col_num as f64, 1.0)]) +#[allow(dead_code)] +pub fn cols(input: &Array, first: u64, last: u64) -> Result { + index(input, &[Seq::default() + , Seq::new(first as f64, last as f64, 1.0)]) +} + +#[allow(dead_code)] +pub fn set_cols(input: &Array, new_cols: &Array, first: u64, last: u64) -> Result { + assign_seq(input, &[Seq::default(), Seq::new(first as f64, last as f64, 1.0)] + , new_cols) +} + +#[allow(dead_code)] +pub fn slice(input: &Array, slice_num: u64) -> Result { + index(input, &[Seq::default() + , Seq::default() + , Seq::new(slice_num as f64, slice_num as f64, 1.0)]) +} + +#[allow(dead_code)] +pub fn set_slice(input: &Array, new_slice: &Array, slice_num: u64) -> Result { + assign_seq(input, &[Seq::default() + , Seq::default() + , Seq::new(slice_num as f64, slice_num as f64, 1.0)] + , new_slice) +} + +#[allow(dead_code)] +pub fn slices(input: &Array, first: u64, last: u64) -> Result { + index(input, &[Seq::default() + , Seq::default() + , Seq::new(first as f64, last as f64, 1.0)]) +} + +#[allow(dead_code)] +pub fn set_slices(input: &Array, new_slices: &Array, first: u64, last: u64) -> Result { + assign_seq(input, &[Seq::default() + , Seq::default() + , Seq::new(first as f64, last as f64, 1.0)] + , new_slices) } pub fn lookup(input: &Array, indices: &Array, seq_dim: i32) -> Result { @@ -148,11 +203,11 @@ pub fn lookup(input: &Array, indices: &Array, seq_dim: i32) -> Result Result { +pub fn assign_seq(lhs: &Array, seqs: &[Seq], rhs: &Array) -> Result { unsafe{ let mut temp: i64 = 0; let err_val = af_assign_seq(&mut temp as MutAfArray, lhs.get() as AfArray, - ndims as c_uint, seqs.as_ptr() as *const Seq, + seqs.len() as c_uint, seqs.as_ptr() as *const Seq, rhs.get() as AfArray); match err_val { 0 => Ok(Array::from(temp)), diff --git a/src/lib.rs b/src/lib.rs old mode 100644 new mode 100755 index c712cb484..a819f459b --- a/src/lib.rs +++ b/src/lib.rs @@ -44,7 +44,9 @@ mod defines; pub use dim4::Dim4; mod dim4; -pub use index::{Indexer, index, row, col, lookup, assign_seq, index_gen, assign_gen}; +pub use index::{Indexer, index, row, rows, col, cols, slice, slices + , set_row, set_rows, set_col, set_cols, set_slice, set_slices + , lookup, assign_seq, index_gen, assign_gen}; mod index; pub use seq::Seq;