Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds into_pyarray to faer Mat #482

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -18,13 +18,17 @@ license = "BSD-2-Clause"
half = { version = "2.0", default-features = false, optional = true }
libc = "0.2"
nalgebra = { version = ">=0.30, <0.34", default-features = false, optional = true }
faer = { version = "0.21.9", optional = true }
num-complex = ">= 0.2, < 0.5"
num-integer = "0.1"
num-traits = "0.2"
ndarray = ">= 0.15, < 0.17"
pyo3 = { version = "0.24", default-features = false, features = ["macros"] }
rustc-hash = "2.0"

[features]
faer = ["dep:faer"]

[dev-dependencies]
pyo3 = { version = "0.24", default-features = false, features = ["auto-initialize"] }
nalgebra = { version = ">=0.30, <0.34", default-features = false, features = ["std"] }
28 changes: 27 additions & 1 deletion src/convert.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

use std::{mem, os::raw::c_int, ptr};

use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, OwnedRepr};
use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, Ix2, OwnedRepr};
use pyo3::{Bound, Python};

use crate::array::{PyArray, PyArrayMethods};
@@ -90,6 +90,32 @@ impl<T: Element> IntoPyArray for Vec<T> {
}
}

#[cfg(feature = "faer")]
impl<T: Element> IntoPyArray for faer::Mat<T> {
type Item = T;
type Dim = Ix2;

fn into_pyarray<'py>(mut self, py: Python<'py>) -> Bound<'py, PyArray<Self::Item, Self::Dim>> {
let dims = Dim([self.nrows(), self.ncols()]);
let rstride = self.row_stride();
let cstride = self.col_stride();
let strides = [
rstride * mem::size_of::<T>() as npy_intp,
cstride * mem::size_of::<T>() as npy_intp,
];
let data_ptr = self.as_ptr_mut();
unsafe {
PyArray::from_raw_parts(
py,
dims,
strides.as_ptr(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ngoldbaum does PyArray_NewFromDescr copy the strides array or just store the pointer? This looks like a potential use-after-free 🤔

(We already have the same pattern in the other into_pyarray functions in this file, which makes me think it's probably fine? Either that or there's a nasty bug in rust-numpy already...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an explicit copy if strides are passed in:

https://github.com/numpy/numpy/blob/a651643fc9b5699a6dc6f3c85ba499c1f52ce3aa/numpy/_core/src/multiarray/ctors.c#L801-L816

There's another code path that handles structured dtypes with subarrays but it looks like that copies the strides too.

There's a comment here worrying about unaligned input strides:

https://github.com/numpy/numpy/blob/a651643fc9b5699a6dc6f3c85ba499c1f52ce3aa/numpy/_core/src/multiarray/ctors.c#L911-L916

But also I have no idea why an unaligned stride array would ever be a problem.

data_ptr,
PySliceContainer::from(self),
)
}
}
}

impl<A, D> IntoPyArray for ArrayBase<OwnedRepr<A>, D>
where
A: Element,
29 changes: 29 additions & 0 deletions src/slice_container.rs
Original file line number Diff line number Diff line change
@@ -71,6 +71,35 @@ impl<T: Send + Sync> From<Vec<T>> for PySliceContainer {
}
}

#[cfg(feature = "faer")]
impl<T: Send + Sync> From<faer::Mat<T>> for PySliceContainer {
fn from(data: faer::Mat<T>) -> Self {
unsafe fn drop_faer_mat<T>(ptr: *mut u8, len_nrows: usize, cap_ncols: usize) {
let _ = faer::mat::MatMut::from_raw_parts_mut(
ptr as *mut T,
len_nrows,
cap_ncols,
1,
cap_ncols as isize,
);
}

let mut data = mem::ManuallyDrop::new(data);

let ptr = data.as_ptr_mut() as *mut u8;
let len = data.nrows();
let cap = data.ncols();
let drop = drop_faer_mat::<T>;

Self {
ptr,
len,
cap,
drop,
}
}
}

impl<A, D> From<ArrayBase<OwnedRepr<A>, D>> for PySliceContainer
where
A: Send + Sync,
25 changes: 25 additions & 0 deletions tests/to_py.rs
Original file line number Diff line number Diff line change
@@ -288,6 +288,31 @@ fn slice_container_type_confusion() {
});
}

#[cfg(feature = "faer")]
#[test]
fn faer_mat_to_numpy() {
let faer_mat: faer::Mat<f64> = faer::Scale(2.0) * faer::mat::Mat::<f64>::identity(2, 2);
let faer_mat_wide: faer::Mat<f64> = faer::mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
let faer_mat_tall: faer::Mat<f64> = faer_mat_wide.transpose().to_owned();
Python::with_gil(|py| {
let mat_pyarray = faer_mat.into_pyarray(py);
let mat_wide_pyarray = faer_mat_wide.into_pyarray(py);
let mat_tall_pyarray = faer_mat_tall.into_pyarray(py);
assert_eq!(
mat_pyarray.readonly().as_array(),
array![[2.0f64, 0.0f64], [0.0f64, 2.0f64]]
);
assert_eq!(
mat_wide_pyarray.readonly().as_array(),
array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]]
);
assert_eq!(
mat_tall_pyarray.readonly().as_array(),
array![[1.0f64, 4.0], [2.0, 5.0], [3.0, 6.0]]
);
});
}

#[cfg(feature = "nalgebra")]
#[test]
fn matrix_to_numpy() {
Loading
Oops, something went wrong.