Skip to content

Commit

Permalink
fix: revert previous commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Smirkey authored Apr 29, 2024
1 parent bcf3e1d commit f52253e
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 192 deletions.
30 changes: 13 additions & 17 deletions bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ mod utils;

use std::fmt::Debug;

use ndarray::Array1;
use num_traits::{Bounded, Float, Num, Signed, ToPrimitive};
use numpy::{PyArray1, PyArray2, PyArray3};
use powerboxesrs::{boxes, diou, giou, iou, nms, tiou};
Expand Down Expand Up @@ -124,7 +123,7 @@ fn _powerboxes(_py: Python, m: &PyModule) -> PyResult<()> {
#[pyfunction]
fn masks_to_boxes(_py: Python, masks: &PyArray3<bool>) -> PyResult<Py<PyArray2<usize>>> {
let masks = preprocess_array3(masks);
let boxes = boxes::masks_to_boxes(masks);
let boxes = boxes::masks_to_boxes(&masks);
let boxes_as_numpy = utils::array_to_numpy(_py, boxes).unwrap();
return Ok(boxes_as_numpy.to_owned());
}
Expand Down Expand Up @@ -180,7 +179,7 @@ fn diou_distance_generic<T>(
boxes2: &PyArray2<T>,
) -> PyResult<Py<PyArray2<f64>>>
where
T: Num + Float + numpy::Element,
T: Float + numpy::Element,
{
let boxes1 = preprocess_boxes(boxes1).unwrap();
let boxes2 = preprocess_boxes(boxes2).unwrap();
Expand Down Expand Up @@ -217,7 +216,7 @@ where
{
let boxes1 = preprocess_boxes(boxes1).unwrap();
let boxes2 = preprocess_boxes(boxes2).unwrap();
let iou = iou::iou_distance(boxes1, boxes2);
let iou = iou::iou_distance(&boxes1, &boxes2);
let iou_as_numpy = utils::array_to_numpy(_py, iou).unwrap();
return Ok(iou_as_numpy.to_owned());
}
Expand Down Expand Up @@ -807,7 +806,7 @@ where
))
}
};
let converted_boxes = boxes::box_convert(&boxes, in_fmt, out_fmt);
let converted_boxes = boxes::box_convert(&boxes, &in_fmt, &out_fmt);
let converted_boxes_as_numpy = utils::array_to_numpy(_py, converted_boxes).unwrap();
return Ok(converted_boxes_as_numpy.to_owned());
}
Expand Down Expand Up @@ -903,13 +902,12 @@ fn nms_generic<T>(
score_threshold: f64,
) -> PyResult<Py<PyArray1<usize>>>
where
T: numpy::Element + Num + PartialEq + PartialOrd + ToPrimitive + Copy,
T: Num + numpy::Element + PartialOrd + ToPrimitive + Copy,
{
let boxes = preprocess_boxes(boxes).unwrap();
let scores = preprocess_array1(scores);
let keep = nms::nms(&boxes, &scores, iou_threshold, score_threshold);
let keep_as_ndarray = Array1::from(keep);
let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap();
let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap();
return Ok(keep_as_numpy.to_owned());
}
#[pyfunction]
Expand Down Expand Up @@ -1066,23 +1064,21 @@ fn rtree_nms_generic<T>(
score_threshold: f64,
) -> PyResult<Py<PyArray1<usize>>>
where
T: numpy::Element
+ Num
+ Signed
+ Bounded
+ Debug
+ PartialEq
T: Num
+ numpy::Element
+ PartialOrd
+ ToPrimitive
+ Copy
+ Signed
+ Bounded
+ Debug
+ Sync
+ Send,
{
let boxes = preprocess_boxes(boxes).unwrap();
let scores = preprocess_array1(scores);
let keep = nms::rtree_nms(&boxes, &scores, iou_threshold, score_threshold);
let keep_as_ndarray = Array1::from(keep);
let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap();
let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap();
return Ok(keep_as_numpy.to_owned());
}
#[pyfunction]
Expand Down Expand Up @@ -1164,4 +1160,4 @@ fn rtree_nms_i16(
iou_threshold,
score_threshold,
)?);
}
}
72 changes: 27 additions & 45 deletions bindings/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,19 @@
use ndarray::{ArrayBase, Dim, OwnedRepr, ViewRepr};
use ndarray::{Array1, Array2, Array3, ArrayBase, OwnedRepr};
use num_traits::Num;
use numpy::{IntoPyArray, PyArray, PyArray1, PyArray2, PyArray3};
use pyo3::prelude::*;

/// Converts a 2-dimensional Rust ndarray to a NumPy array.
///
/// # Arguments
///
/// * `py` - The Python interpreter context.
/// * `array` - The 2-dimensional Rust ndarray to convert.
///
/// # Returns
///
/// A reference to the converted NumPy array.
///
/// # Example
///
/// ```rust
/// let py = Python::acquire_gil().python();
/// let array_2d: Array2<f64> = Array2::ones((3, 3));
/// let numpy_array_2d = array2_to_numpy(py, array_2d).unwrap();
/// ```
pub fn array_to_numpy<T, D>(
pub fn array_to_numpy<T: numpy::Element, D: ndarray::Dimension>(
py: Python,
array: ArrayBase<OwnedRepr<T>, D>,
) -> PyResult<&PyArray<T, D>>
where
T: numpy::Element,
D: ndarray::Dimension,
{
let numpy_array = array.into_pyarray(py);

) -> PyResult<&PyArray<T, D>> {
let numpy_array: &PyArray<T, D> = array.into_pyarray(py);
return Ok(numpy_array);
}

pub fn preprocess_boxes<N>(
array: &PyArray2<N>,
) -> Result<ArrayBase<ViewRepr<&N>, Dim<[usize; 2]>>, PyErr>
pub fn preprocess_boxes<N>(array: &PyArray2<N>) -> Result<Array2<N>, PyErr>
where
N: numpy::Element,
N: Num + numpy::Element + Send,
{
let array = unsafe { array.as_array() };
let array_shape = array.shape();
Expand All @@ -57,14 +32,16 @@ where
}
}

let array = array
.to_owned()
.into_shape((array_shape[0], array_shape[1]))
.unwrap();
return Ok(array);
}

pub fn preprocess_rotated_boxes<'a, N>(
array: &PyArray2<N>,
) -> Result<ArrayBase<ViewRepr<&N>, Dim<[usize; 2]>>, PyErr>
pub fn preprocess_rotated_boxes<N>(array: &PyArray2<N>) -> Result<Array2<N>, PyErr>
where
N: Num + numpy::Element + Send + 'a,
N: Num + numpy::Element + Send,
{
let array = unsafe { array.as_array() };
let array_shape = array.shape();
Expand All @@ -83,37 +60,42 @@ where
}
}

let array = array
.to_owned()
.into_shape((array_shape[0], array_shape[1]))
.unwrap();
return Ok(array);
}

pub fn preprocess_array3<'a, N>(array: &PyArray3<N>) -> ArrayBase<ViewRepr<&N>, Dim<[usize; 3]>>
pub fn preprocess_array3<N>(array: &PyArray3<N>) -> Array3<N>
where
N: numpy::Element + 'a,
N: numpy::Element,
{
let array = unsafe { array.as_array() };
let array = unsafe { array.as_array().to_owned() };
return array;
}

pub fn preprocess_array1<'a, N>(array: &PyArray1<N>) -> ArrayBase<ViewRepr<&N>, Dim<[usize; 1]>>
pub fn preprocess_array1<N>(array: &PyArray1<N>) -> Array1<N>
where
N: numpy::Element + 'a,
N: numpy::Element,
{
let array: ArrayBase<ViewRepr<&N>, ndarray::prelude::Dim<[usize; 1]>> =
unsafe { array.as_array() };
let array = unsafe { array.as_array().to_owned() };
return array;
}

#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array1;
use ndarray::ArrayBase;

#[test]
fn test_array_to_numpy() {
let array = Array1::from(vec![1., 2., 3., 4.]);
let data = vec![1., 2., 3., 4.];
let array = ArrayBase::from_shape_vec((1, 4), data).unwrap();
Python::with_gil(|py| {
let result = array_to_numpy(py, array).unwrap();
assert_eq!(result.readonly().shape(), &[4]);
assert_eq!(result.readonly().shape(), &[1, 4]);
assert_eq!(result.readonly().shape(), &[1, 4]);
});
}

Expand Down
26 changes: 7 additions & 19 deletions powerboxesrs/src/iou.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
rotation::{intersection_area, minimal_bounding_rect, Rect},
utils,
};
use ndarray::{Array2, ArrayView2, Zip};
use ndarray::{Array2, Zip};
use num_traits::{Num, ToPrimitive};
use rstar::RTree;

Expand All @@ -29,13 +29,10 @@ use rstar::RTree;
/// let iou = iou_distance(&boxes1, &boxes2);
/// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]);
/// ```
pub fn iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
pub fn iou_distance<N>(boxes1: &Array2<N>, boxes2: &Array2<N>) -> Array2<f64>
where
N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a,
BA: Into<ArrayView2<'a, N>>,
N: Num + PartialOrd + ToPrimitive + Copy,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -98,13 +95,10 @@ where
/// let iou = parallel_iou_distance(&boxes1, &boxes2);
/// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]);
/// ```
pub fn parallel_iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
pub fn parallel_iou_distance<N>(boxes1: &Array2<N>, boxes2: &Array2<N>) -> Array2<f64>
where
N: Num + PartialEq + PartialOrd + ToPrimitive + Send + Sync + Copy + 'a,
BA: Into<ArrayView2<'a, N>>,
N: Num + PartialOrd + ToPrimitive + Copy + Clone + Sync + Send,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -157,13 +151,7 @@ where
/// # Returns
/// A 2D array containing the Rotated IoU distance matrix. The element at position (i, j) represents
/// the Rotated IoU distance between the i-th box in `boxes1` and the j-th box in `boxes2`.
pub fn rotated_iou_distance<'a, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
where
BA: Into<ArrayView2<'a, f64>>,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();

pub fn rotated_iou_distance(boxes1: &Array2<f64>, boxes2: &Array2<f64>) -> Array2<f64> {
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -288,4 +276,4 @@ mod tests {
let rotated_iou_distance_result = rotated_iou_distance(&boxes1, &boxes2);
assert_eq!(rotated_iou_distance_result, arr2(&[[0.8571428571428572]]));
}
}
}
Loading

0 comments on commit f52253e

Please sign in to comment.