Skip to content

Commit

Permalink
chore: revert bad changes and apply nms logic to rtree_nms
Browse files Browse the repository at this point in the history
  • Loading branch information
Smirkey authored Apr 29, 2024
1 parent 4fbdf2a commit bcf3e1d
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 124 deletions.
30 changes: 17 additions & 13 deletions bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod utils;

use std::fmt::Debug;

use ndarray::Array1;
use num_traits::{Bounded, Float, Num, Signed, ToPrimitive};
use numpy::{PyArray1, PyArray2, PyArray3};
use powerboxesrs::{boxes, diou, giou, iou, nms, tiou};
Expand Down Expand Up @@ -123,7 +124,7 @@ fn _powerboxes(_py: Python, m: &PyModule) -> PyResult<()> {
#[pyfunction]
fn masks_to_boxes(_py: Python, masks: &PyArray3<bool>) -> PyResult<Py<PyArray2<usize>>> {
let masks = preprocess_array3(masks);
let boxes = boxes::masks_to_boxes(&masks);
let boxes = boxes::masks_to_boxes(masks);
let boxes_as_numpy = utils::array_to_numpy(_py, boxes).unwrap();
return Ok(boxes_as_numpy.to_owned());
}
Expand Down Expand Up @@ -179,7 +180,7 @@ fn diou_distance_generic<T>(
boxes2: &PyArray2<T>,
) -> PyResult<Py<PyArray2<f64>>>
where
T: Float + numpy::Element,
T: Num + Float + numpy::Element,
{
let boxes1 = preprocess_boxes(boxes1).unwrap();
let boxes2 = preprocess_boxes(boxes2).unwrap();
Expand Down Expand Up @@ -216,7 +217,7 @@ where
{
let boxes1 = preprocess_boxes(boxes1).unwrap();
let boxes2 = preprocess_boxes(boxes2).unwrap();
let iou = iou::iou_distance(&boxes1, &boxes2);
let iou = iou::iou_distance(boxes1, boxes2);
let iou_as_numpy = utils::array_to_numpy(_py, iou).unwrap();
return Ok(iou_as_numpy.to_owned());
}
Expand Down Expand Up @@ -806,7 +807,7 @@ where
))
}
};
let converted_boxes = boxes::box_convert(&boxes, &in_fmt, &out_fmt);
let converted_boxes = boxes::box_convert(&boxes, in_fmt, out_fmt);
let converted_boxes_as_numpy = utils::array_to_numpy(_py, converted_boxes).unwrap();
return Ok(converted_boxes_as_numpy.to_owned());
}
Expand Down Expand Up @@ -902,12 +903,13 @@ fn nms_generic<T>(
score_threshold: f64,
) -> PyResult<Py<PyArray1<usize>>>
where
T: Num + numpy::Element + PartialOrd + ToPrimitive + Copy,
T: numpy::Element + Num + PartialEq + PartialOrd + ToPrimitive + Copy,
{
let boxes = preprocess_boxes(boxes).unwrap();
let scores = preprocess_array1(scores);
let keep = nms::nms(&boxes, &scores, iou_threshold, score_threshold);
let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap();
let keep_as_ndarray = Array1::from(keep);
let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap();
return Ok(keep_as_numpy.to_owned());
}
#[pyfunction]
Expand Down Expand Up @@ -1064,21 +1066,23 @@ fn rtree_nms_generic<T>(
score_threshold: f64,
) -> PyResult<Py<PyArray1<usize>>>
where
T: Num
+ numpy::Element
+ PartialOrd
+ ToPrimitive
+ Copy
T: numpy::Element
+ Num
+ Signed
+ Bounded
+ Debug
+ PartialEq
+ PartialOrd
+ ToPrimitive
+ Copy
+ Sync
+ Send,
{
let boxes = preprocess_boxes(boxes).unwrap();
let scores = preprocess_array1(scores);
let keep = nms::rtree_nms(&boxes, &scores, iou_threshold, score_threshold);
let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap();
let keep_as_ndarray = Array1::from(keep);
let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap();
return Ok(keep_as_numpy.to_owned());
}
#[pyfunction]
Expand Down Expand Up @@ -1160,4 +1164,4 @@ fn rtree_nms_i16(
iou_threshold,
score_threshold,
)?);
}
}
72 changes: 45 additions & 27 deletions bindings/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,44 @@
use ndarray::{Array1, Array2, Array3, ArrayBase, OwnedRepr};
use ndarray::{ArrayBase, Dim, OwnedRepr, ViewRepr};
use num_traits::Num;
use numpy::{IntoPyArray, PyArray, PyArray1, PyArray2, PyArray3};
use pyo3::prelude::*;

pub fn array_to_numpy<T: numpy::Element, D: ndarray::Dimension>(
/// Converts a 2-dimensional Rust ndarray to a NumPy array.
///
/// # Arguments
///
/// * `py` - The Python interpreter context.
/// * `array` - The 2-dimensional Rust ndarray to convert.
///
/// # Returns
///
/// A reference to the converted NumPy array.
///
/// # Example
///
/// ```rust
/// let py = Python::acquire_gil().python();
/// let array_2d: Array2<f64> = Array2::ones((3, 3));
/// let numpy_array_2d = array2_to_numpy(py, array_2d).unwrap();
/// ```
pub fn array_to_numpy<T, D>(
py: Python,
array: ArrayBase<OwnedRepr<T>, D>,
) -> PyResult<&PyArray<T, D>> {
let numpy_array: &PyArray<T, D> = array.into_pyarray(py);
) -> PyResult<&PyArray<T, D>>
where
T: numpy::Element,
D: ndarray::Dimension,
{
let numpy_array = array.into_pyarray(py);

return Ok(numpy_array);
}

pub fn preprocess_boxes<N>(array: &PyArray2<N>) -> Result<Array2<N>, PyErr>
pub fn preprocess_boxes<N>(
array: &PyArray2<N>,
) -> Result<ArrayBase<ViewRepr<&N>, Dim<[usize; 2]>>, PyErr>
where
N: Num + numpy::Element + Send,
N: numpy::Element,
{
let array = unsafe { array.as_array() };
let array_shape = array.shape();
Expand All @@ -32,16 +57,14 @@ where
}
}

let array = array
.to_owned()
.into_shape((array_shape[0], array_shape[1]))
.unwrap();
return Ok(array);
}

pub fn preprocess_rotated_boxes<N>(array: &PyArray2<N>) -> Result<Array2<N>, PyErr>
pub fn preprocess_rotated_boxes<'a, N>(
array: &PyArray2<N>,
) -> Result<ArrayBase<ViewRepr<&N>, Dim<[usize; 2]>>, PyErr>
where
N: Num + numpy::Element + Send,
N: Num + numpy::Element + Send + 'a,
{
let array = unsafe { array.as_array() };
let array_shape = array.shape();
Expand All @@ -60,42 +83,37 @@ where
}
}

let array = array
.to_owned()
.into_shape((array_shape[0], array_shape[1]))
.unwrap();
return Ok(array);
}

pub fn preprocess_array3<N>(array: &PyArray3<N>) -> Array3<N>
pub fn preprocess_array3<'a, N>(array: &PyArray3<N>) -> ArrayBase<ViewRepr<&N>, Dim<[usize; 3]>>
where
N: numpy::Element,
N: numpy::Element + 'a,
{
let array = unsafe { array.as_array().to_owned() };
let array = unsafe { array.as_array() };
return array;
}

pub fn preprocess_array1<N>(array: &PyArray1<N>) -> Array1<N>
pub fn preprocess_array1<'a, N>(array: &PyArray1<N>) -> ArrayBase<ViewRepr<&N>, Dim<[usize; 1]>>
where
N: numpy::Element,
N: numpy::Element + 'a,
{
let array = unsafe { array.as_array().to_owned() };
let array: ArrayBase<ViewRepr<&N>, ndarray::prelude::Dim<[usize; 1]>> =
unsafe { array.as_array() };
return array;
}

#[cfg(test)]
mod tests {
use super::*;
use ndarray::ArrayBase;
use ndarray::Array1;

#[test]
fn test_array_to_numpy() {
let data = vec![1., 2., 3., 4.];
let array = ArrayBase::from_shape_vec((1, 4), data).unwrap();
let array = Array1::from(vec![1., 2., 3., 4.]);
Python::with_gil(|py| {
let result = array_to_numpy(py, array).unwrap();
assert_eq!(result.readonly().shape(), &[1, 4]);
assert_eq!(result.readonly().shape(), &[1, 4]);
assert_eq!(result.readonly().shape(), &[4]);
});
}

Expand Down
26 changes: 19 additions & 7 deletions powerboxesrs/src/iou.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
rotation::{intersection_area, minimal_bounding_rect, Rect},
utils,
};
use ndarray::{Array2, Zip};
use ndarray::{Array2, ArrayView2, Zip};
use num_traits::{Num, ToPrimitive};
use rstar::RTree;

Expand All @@ -29,10 +29,13 @@ use rstar::RTree;
/// let iou = iou_distance(&boxes1, &boxes2);
/// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]);
/// ```
pub fn iou_distance<N>(boxes1: &Array2<N>, boxes2: &Array2<N>) -> Array2<f64>
pub fn iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
where
N: Num + PartialOrd + ToPrimitive + Copy,
N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a,
BA: Into<ArrayView2<'a, N>>,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -95,10 +98,13 @@ where
/// let iou = parallel_iou_distance(&boxes1, &boxes2);
/// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]);
/// ```
pub fn parallel_iou_distance<N>(boxes1: &Array2<N>, boxes2: &Array2<N>) -> Array2<f64>
pub fn parallel_iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
where
N: Num + PartialOrd + ToPrimitive + Copy + Clone + Sync + Send,
N: Num + PartialEq + PartialOrd + ToPrimitive + Send + Sync + Copy + 'a,
BA: Into<ArrayView2<'a, N>>,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -151,7 +157,13 @@ where
/// # Returns
/// A 2D array containing the Rotated IoU distance matrix. The element at position (i, j) represents
/// the Rotated IoU distance between the i-th box in `boxes1` and the j-th box in `boxes2`.
pub fn rotated_iou_distance(boxes1: &Array2<f64>, boxes2: &Array2<f64>) -> Array2<f64> {
pub fn rotated_iou_distance<'a, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
where
BA: Into<ArrayView2<'a, f64>>,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();

let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -276,4 +288,4 @@ mod tests {
let rotated_iou_distance_result = rotated_iou_distance(&boxes1, &boxes2);
assert_eq!(rotated_iou_distance_result, arr2(&[[0.8571428571428572]]));
}
}
}
Loading

0 comments on commit bcf3e1d

Please sign in to comment.