diff --git a/bindings/src/lib.rs b/bindings/src/lib.rs index 8e6c1ea..c4c92b9 100644 --- a/bindings/src/lib.rs +++ b/bindings/src/lib.rs @@ -2,6 +2,7 @@ mod utils; use std::fmt::Debug; +use ndarray::Array1; use num_traits::{Bounded, Float, Num, Signed, ToPrimitive}; use numpy::{PyArray1, PyArray2, PyArray3}; use powerboxesrs::{boxes, diou, giou, iou, nms, tiou}; @@ -123,7 +124,7 @@ fn _powerboxes(_py: Python, m: &PyModule) -> PyResult<()> { #[pyfunction] fn masks_to_boxes(_py: Python, masks: &PyArray3) -> PyResult>> { let masks = preprocess_array3(masks); - let boxes = boxes::masks_to_boxes(&masks); + let boxes = boxes::masks_to_boxes(masks); let boxes_as_numpy = utils::array_to_numpy(_py, boxes).unwrap(); return Ok(boxes_as_numpy.to_owned()); } @@ -179,7 +180,7 @@ fn diou_distance_generic( boxes2: &PyArray2, ) -> PyResult>> where - T: Float + numpy::Element, + T: Num + Float + numpy::Element, { let boxes1 = preprocess_boxes(boxes1).unwrap(); let boxes2 = preprocess_boxes(boxes2).unwrap(); @@ -216,7 +217,7 @@ where { let boxes1 = preprocess_boxes(boxes1).unwrap(); let boxes2 = preprocess_boxes(boxes2).unwrap(); - let iou = iou::iou_distance(&boxes1, &boxes2); + let iou = iou::iou_distance(boxes1.to_owned(), boxes2.to_owned()); let iou_as_numpy = utils::array_to_numpy(_py, iou).unwrap(); return Ok(iou_as_numpy.to_owned()); } @@ -806,7 +807,7 @@ where )) } }; - let converted_boxes = boxes::box_convert(&boxes, &in_fmt, &out_fmt); + let converted_boxes = boxes::box_convert(&boxes, in_fmt, out_fmt); let converted_boxes_as_numpy = utils::array_to_numpy(_py, converted_boxes).unwrap(); return Ok(converted_boxes_as_numpy.to_owned()); } @@ -902,12 +903,13 @@ fn nms_generic( score_threshold: f64, ) -> PyResult>> where - T: Num + numpy::Element + PartialOrd + ToPrimitive + Copy, + T: numpy::Element + Num + PartialEq + PartialOrd + ToPrimitive + Copy, { let boxes = preprocess_boxes(boxes).unwrap(); let scores = preprocess_array1(scores); let keep = nms::nms(&boxes, &scores, iou_threshold, score_threshold); - let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap(); + let keep_as_ndarray = Array1::from(keep); + let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap(); return Ok(keep_as_numpy.to_owned()); } #[pyfunction] @@ -1064,21 +1066,23 @@ fn rtree_nms_generic( score_threshold: f64, ) -> PyResult>> where - T: Num - + numpy::Element - + PartialOrd - + ToPrimitive - + Copy + T: numpy::Element + + Num + Signed + Bounded + Debug + + PartialEq + + PartialOrd + + ToPrimitive + + Copy + Sync + Send, { let boxes = preprocess_boxes(boxes).unwrap(); let scores = preprocess_array1(scores); let keep = nms::rtree_nms(&boxes, &scores, iou_threshold, score_threshold); - let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap(); + let keep_as_ndarray = Array1::from(keep); + let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap(); return Ok(keep_as_numpy.to_owned()); } #[pyfunction] diff --git a/powerboxesrs/src/iou.rs b/powerboxesrs/src/iou.rs index b467e5a..2f2f4c0 100644 --- a/powerboxesrs/src/iou.rs +++ b/powerboxesrs/src/iou.rs @@ -3,7 +3,7 @@ use crate::{ rotation::{intersection_area, minimal_bounding_rect, Rect}, utils, }; -use ndarray::{Array2, Zip}; +use ndarray::{Array2, ArrayView2, CowArray, Dim, Zip}; use num_traits::{Num, ToPrimitive}; use rstar::RTree; @@ -29,10 +29,13 @@ use rstar::RTree; /// let iou = iou_distance(&boxes1, &boxes2); /// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]); /// ``` -pub fn iou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 +pub fn iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2 where - N: Num + PartialOrd + ToPrimitive + Copy, + N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a, + BA: Into>>, { + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); @@ -95,10 +98,13 @@ where /// let iou = parallel_iou_distance(&boxes1, &boxes2); /// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]); /// ``` -pub fn parallel_iou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 +pub fn parallel_iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2 where - N: Num + PartialOrd + ToPrimitive + Copy + Clone + Sync + Send, + N: Num + PartialEq + PartialOrd + ToPrimitive + Send + Sync + Copy + 'a, + BA: Into>, { + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); @@ -151,7 +157,13 @@ where /// # Returns /// A 2D array containing the Rotated IoU distance matrix. The element at position (i, j) represents /// the Rotated IoU distance between the i-th box in `boxes1` and the j-th box in `boxes2`. -pub fn rotated_iou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 { +pub fn rotated_iou_distance<'a, BA>(boxes1: BA, boxes2: BA) -> Array2 +where + BA: Into>, +{ + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); + let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows();