Skip to content

Commit

Permalink
test: use cow array in iou
Browse files Browse the repository at this point in the history
  • Loading branch information
Smirkey committed Apr 29, 2024
1 parent 08c63a5 commit f64fa96
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 18 deletions.
28 changes: 16 additions & 12 deletions bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod utils;

use std::fmt::Debug;

use ndarray::Array1;
use num_traits::{Bounded, Float, Num, Signed, ToPrimitive};
use numpy::{PyArray1, PyArray2, PyArray3};
use powerboxesrs::{boxes, diou, giou, iou, nms, tiou};
Expand Down Expand Up @@ -123,7 +124,7 @@ fn _powerboxes(_py: Python, m: &PyModule) -> PyResult<()> {
#[pyfunction]
fn masks_to_boxes(_py: Python, masks: &PyArray3<bool>) -> PyResult<Py<PyArray2<usize>>> {
let masks = preprocess_array3(masks);
let boxes = boxes::masks_to_boxes(&masks);
let boxes = boxes::masks_to_boxes(masks);
let boxes_as_numpy = utils::array_to_numpy(_py, boxes).unwrap();
return Ok(boxes_as_numpy.to_owned());
}
Expand Down Expand Up @@ -179,7 +180,7 @@ fn diou_distance_generic<T>(
boxes2: &PyArray2<T>,
) -> PyResult<Py<PyArray2<f64>>>
where
T: Float + numpy::Element,
T: Num + Float + numpy::Element,
{
let boxes1 = preprocess_boxes(boxes1).unwrap();
let boxes2 = preprocess_boxes(boxes2).unwrap();
Expand Down Expand Up @@ -216,7 +217,7 @@ where
{
let boxes1 = preprocess_boxes(boxes1).unwrap();
let boxes2 = preprocess_boxes(boxes2).unwrap();
let iou = iou::iou_distance(&boxes1, &boxes2);
let iou = iou::iou_distance(boxes1.to_owned(), boxes2.to_owned());
let iou_as_numpy = utils::array_to_numpy(_py, iou).unwrap();
return Ok(iou_as_numpy.to_owned());
}
Expand Down Expand Up @@ -806,7 +807,7 @@ where
))
}
};
let converted_boxes = boxes::box_convert(&boxes, &in_fmt, &out_fmt);
let converted_boxes = boxes::box_convert(&boxes, in_fmt, out_fmt);
let converted_boxes_as_numpy = utils::array_to_numpy(_py, converted_boxes).unwrap();
return Ok(converted_boxes_as_numpy.to_owned());
}
Expand Down Expand Up @@ -902,12 +903,13 @@ fn nms_generic<T>(
score_threshold: f64,
) -> PyResult<Py<PyArray1<usize>>>
where
T: Num + numpy::Element + PartialOrd + ToPrimitive + Copy,
T: numpy::Element + Num + PartialEq + PartialOrd + ToPrimitive + Copy,
{
let boxes = preprocess_boxes(boxes).unwrap();
let scores = preprocess_array1(scores);
let keep = nms::nms(&boxes, &scores, iou_threshold, score_threshold);
let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap();
let keep_as_ndarray = Array1::from(keep);
let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap();
return Ok(keep_as_numpy.to_owned());
}
#[pyfunction]
Expand Down Expand Up @@ -1064,21 +1066,23 @@ fn rtree_nms_generic<T>(
score_threshold: f64,
) -> PyResult<Py<PyArray1<usize>>>
where
T: Num
+ numpy::Element
+ PartialOrd
+ ToPrimitive
+ Copy
T: numpy::Element
+ Num
+ Signed
+ Bounded
+ Debug
+ PartialEq
+ PartialOrd
+ ToPrimitive
+ Copy
+ Sync
+ Send,
{
let boxes = preprocess_boxes(boxes).unwrap();
let scores = preprocess_array1(scores);
let keep = nms::rtree_nms(&boxes, &scores, iou_threshold, score_threshold);
let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap();
let keep_as_ndarray = Array1::from(keep);
let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap();
return Ok(keep_as_numpy.to_owned());
}
#[pyfunction]
Expand Down
24 changes: 18 additions & 6 deletions powerboxesrs/src/iou.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
rotation::{intersection_area, minimal_bounding_rect, Rect},
utils,
};
use ndarray::{Array2, Zip};
use ndarray::{Array2, ArrayView2, CowArray, Dim, Zip};
use num_traits::{Num, ToPrimitive};
use rstar::RTree;

Expand All @@ -29,10 +29,13 @@ use rstar::RTree;
/// let iou = iou_distance(&boxes1, &boxes2);
/// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]);
/// ```
pub fn iou_distance<N>(boxes1: &Array2<N>, boxes2: &Array2<N>) -> Array2<f64>
pub fn iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
where
N: Num + PartialOrd + ToPrimitive + Copy,
N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a,
BA: Into<CowArray<'a, N, Dim<[usize; 2]>>>,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -95,10 +98,13 @@ where
/// let iou = parallel_iou_distance(&boxes1, &boxes2);
/// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]);
/// ```
pub fn parallel_iou_distance<N>(boxes1: &Array2<N>, boxes2: &Array2<N>) -> Array2<f64>
pub fn parallel_iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
where
N: Num + PartialOrd + ToPrimitive + Copy + Clone + Sync + Send,
N: Num + PartialEq + PartialOrd + ToPrimitive + Send + Sync + Copy + 'a,
BA: Into<ArrayView2<'a, N>>,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -151,7 +157,13 @@ where
/// # Returns
/// A 2D array containing the Rotated IoU distance matrix. The element at position (i, j) represents
/// the Rotated IoU distance between the i-th box in `boxes1` and the j-th box in `boxes2`.
pub fn rotated_iou_distance(boxes1: &Array2<f64>, boxes2: &Array2<f64>) -> Array2<f64> {
pub fn rotated_iou_distance<'a, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
where
BA: Into<ArrayView2<'a, f64>>,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();

let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down

0 comments on commit f64fa96

Please sign in to comment.