Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/helloworld.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ fn main() {
print(&row(&a, num_rows - 1).unwrap());
print(&col(&a, num_cols - 1).unwrap());

println!("Set last row to 1's");
let r_dims = Dim4::new(&[3, 1, 1, 1]);
let r_input: [f32; 3] = [1.0, 1.0, 1.0];
let r = Array::new(r_dims, &r_input, Aftype::F32).unwrap();
print(&set_row(&a, &r, num_rows - 1).unwrap());

println!("Create 2-by-3 matrix from host data");
let d_dims = Dim4::new(&[2, 3, 1, 1]);
let d_input: [i32; 6] = [1, 2, 3, 4, 5, 6];
Expand Down
87 changes: 71 additions & 16 deletions src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,81 @@ pub fn index(input: &Array, seqs: &[Seq]) -> Result<Array, AfError> {
}
}

#[allow(dead_code)]
pub fn row(input: &Array, row_num: u64) -> Result<Array, AfError> {
let dims_err = input.dims();
let dims = match dims_err {
Ok(dim) => dim.clone(),
Err(e) => panic!("Error unwrapping dims in row(): {}", e),
};

index(input, &[Seq::new(row_num as f64, row_num as f64, 1.0)
,Seq::new(0.0, dims[1] as f64 - 1.0, 1.0)])
, Seq::default()])
}

#[allow(dead_code)]
pub fn set_row(input: &Array, new_row: &Array, row_num: u64) -> Result<Array, AfError> {
assign_seq(input, &[Seq::new(row_num as f64, row_num as f64, 1.0), Seq::default()]
, new_row)
}

#[allow(dead_code)]
pub fn rows(input: &Array, first: u64, last: u64) -> Result<Array, AfError> {
index(input, &[Seq::new(first as f64, last as f64, 1.0), Seq::default()])
}

#[allow(dead_code)]
pub fn set_rows(input: &Array, new_rows: &Array, first: u64, last: u64) -> Result<Array, AfError> {
assign_seq(input, &[Seq::new(first as f64, last as f64, 1.0), Seq::default()]
, new_rows)
}

#[allow(dead_code)]
pub fn col(input: &Array, col_num: u64) -> Result<Array, AfError> {
let dims_err = input.dims();
let dims = match dims_err {
Ok(dim) => dim.clone(),
Err(e) => panic!("Error unwrapping dims in row(): {}", e),
};
index(input, &[Seq::default()
, Seq::new(col_num as f64, col_num as f64, 1.0)])
}

#[allow(dead_code)]
pub fn set_col(input: &Array, new_col: &Array, col_num: u64) -> Result<Array, AfError> {
assign_seq(input, &[Seq::default(), Seq::new(col_num as f64, col_num as f64, 1.0)]
, new_col)
}

index(input, &[Seq::new(0.0, dims[0] as f64 - 1.0, 1.0)
,Seq::new(col_num as f64, col_num as f64, 1.0)])
#[allow(dead_code)]
pub fn cols(input: &Array, first: u64, last: u64) -> Result<Array, AfError> {
index(input, &[Seq::default()
, Seq::new(first as f64, last as f64, 1.0)])
}

#[allow(dead_code)]
pub fn set_cols(input: &Array, new_cols: &Array, first: u64, last: u64) -> Result<Array, AfError> {
assign_seq(input, &[Seq::default(), Seq::new(first as f64, last as f64, 1.0)]
, new_cols)
}

#[allow(dead_code)]
pub fn slice(input: &Array, slice_num: u64) -> Result<Array, AfError> {
index(input, &[Seq::default()
, Seq::default()
, Seq::new(slice_num as f64, slice_num as f64, 1.0)])
}

#[allow(dead_code)]
pub fn set_slice(input: &Array, new_slice: &Array, slice_num: u64) -> Result<Array, AfError> {
assign_seq(input, &[Seq::default()
, Seq::default()
, Seq::new(slice_num as f64, slice_num as f64, 1.0)]
, new_slice)
}

#[allow(dead_code)]
pub fn slices(input: &Array, first: u64, last: u64) -> Result<Array, AfError> {
index(input, &[Seq::default()
, Seq::default()
, Seq::new(first as f64, last as f64, 1.0)])
}

#[allow(dead_code)]
pub fn set_slices(input: &Array, new_slices: &Array, first: u64, last: u64) -> Result<Array, AfError> {
assign_seq(input, &[Seq::default()
, Seq::default()
, Seq::new(first as f64, last as f64, 1.0)]
, new_slices)
}

pub fn lookup(input: &Array, indices: &Array, seq_dim: i32) -> Result<Array, AfError> {
Expand All @@ -148,11 +203,11 @@ pub fn lookup(input: &Array, indices: &Array, seq_dim: i32) -> Result<Array, AfE
}
}

pub fn assign_seq(lhs: &Array, ndims: usize, seqs: &[Seq], rhs: &Array) -> Result<Array, AfError> {
pub fn assign_seq(lhs: &Array, seqs: &[Seq], rhs: &Array) -> Result<Array, AfError> {
unsafe{
let mut temp: i64 = 0;
let err_val = af_assign_seq(&mut temp as MutAfArray, lhs.get() as AfArray,
ndims as c_uint, seqs.as_ptr() as *const Seq,
seqs.len() as c_uint, seqs.as_ptr() as *const Seq,
rhs.get() as AfArray);
match err_val {
0 => Ok(Array::from(temp)),
Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ mod defines;
pub use dim4::Dim4;
mod dim4;

pub use index::{Indexer, index, row, col, lookup, assign_seq, index_gen, assign_gen};
pub use index::{Indexer, index, row, rows, col, cols, slice, slices
, set_row, set_rows, set_col, set_cols, set_slice, set_slices
, lookup, assign_seq, index_gen, assign_gen};
mod index;

pub use seq::Seq;
Expand Down