Skip to content

Commit

Permalink
fix: revert previous commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Smirkey committed Apr 29, 2024
1 parent f64fa96 commit 4fbdf2a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 36 deletions.
30 changes: 13 additions & 17 deletions bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ mod utils;

use std::fmt::Debug;

use ndarray::Array1;
use num_traits::{Bounded, Float, Num, Signed, ToPrimitive};
use numpy::{PyArray1, PyArray2, PyArray3};
use powerboxesrs::{boxes, diou, giou, iou, nms, tiou};
Expand Down Expand Up @@ -124,7 +123,7 @@ fn _powerboxes(_py: Python, m: &PyModule) -> PyResult<()> {
#[pyfunction]
fn masks_to_boxes(_py: Python, masks: &PyArray3<bool>) -> PyResult<Py<PyArray2<usize>>> {
let masks = preprocess_array3(masks);
let boxes = boxes::masks_to_boxes(masks);
let boxes = boxes::masks_to_boxes(&masks);
let boxes_as_numpy = utils::array_to_numpy(_py, boxes).unwrap();
return Ok(boxes_as_numpy.to_owned());
}
Expand Down Expand Up @@ -180,7 +179,7 @@ fn diou_distance_generic<T>(
boxes2: &PyArray2<T>,
) -> PyResult<Py<PyArray2<f64>>>
where
T: Num + Float + numpy::Element,
T: Float + numpy::Element,
{
let boxes1 = preprocess_boxes(boxes1).unwrap();
let boxes2 = preprocess_boxes(boxes2).unwrap();
Expand Down Expand Up @@ -217,7 +216,7 @@ where
{
let boxes1 = preprocess_boxes(boxes1).unwrap();
let boxes2 = preprocess_boxes(boxes2).unwrap();
let iou = iou::iou_distance(boxes1.to_owned(), boxes2.to_owned());
let iou = iou::iou_distance(&boxes1, &boxes2);
let iou_as_numpy = utils::array_to_numpy(_py, iou).unwrap();
return Ok(iou_as_numpy.to_owned());
}
Expand Down Expand Up @@ -807,7 +806,7 @@ where
))
}
};
let converted_boxes = boxes::box_convert(&boxes, in_fmt, out_fmt);
let converted_boxes = boxes::box_convert(&boxes, &in_fmt, &out_fmt);
let converted_boxes_as_numpy = utils::array_to_numpy(_py, converted_boxes).unwrap();
return Ok(converted_boxes_as_numpy.to_owned());
}
Expand Down Expand Up @@ -903,13 +902,12 @@ fn nms_generic<T>(
score_threshold: f64,
) -> PyResult<Py<PyArray1<usize>>>
where
T: numpy::Element + Num + PartialEq + PartialOrd + ToPrimitive + Copy,
T: Num + numpy::Element + PartialOrd + ToPrimitive + Copy,
{
let boxes = preprocess_boxes(boxes).unwrap();
let scores = preprocess_array1(scores);
let keep = nms::nms(&boxes, &scores, iou_threshold, score_threshold);
let keep_as_ndarray = Array1::from(keep);
let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap();
let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap();
return Ok(keep_as_numpy.to_owned());
}
#[pyfunction]
Expand Down Expand Up @@ -1066,23 +1064,21 @@ fn rtree_nms_generic<T>(
score_threshold: f64,
) -> PyResult<Py<PyArray1<usize>>>
where
T: numpy::Element
+ Num
+ Signed
+ Bounded
+ Debug
+ PartialEq
T: Num
+ numpy::Element
+ PartialOrd
+ ToPrimitive
+ Copy
+ Signed
+ Bounded
+ Debug
+ Sync
+ Send,
{
let boxes = preprocess_boxes(boxes).unwrap();
let scores = preprocess_array1(scores);
let keep = nms::rtree_nms(&boxes, &scores, iou_threshold, score_threshold);
let keep_as_ndarray = Array1::from(keep);
let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap();
let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap();
return Ok(keep_as_numpy.to_owned());
}
#[pyfunction]
Expand Down Expand Up @@ -1164,4 +1160,4 @@ fn rtree_nms_i16(
iou_threshold,
score_threshold,
)?);
}
}
26 changes: 7 additions & 19 deletions powerboxesrs/src/iou.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
rotation::{intersection_area, minimal_bounding_rect, Rect},
utils,
};
use ndarray::{Array2, ArrayView2, CowArray, Dim, Zip};
use ndarray::{Array2, Zip};
use num_traits::{Num, ToPrimitive};
use rstar::RTree;

Expand All @@ -29,13 +29,10 @@ use rstar::RTree;
/// let iou = iou_distance(&boxes1, &boxes2);
/// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]);
/// ```
pub fn iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
pub fn iou_distance<N>(boxes1: &Array2<N>, boxes2: &Array2<N>) -> Array2<f64>
where
N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a,
BA: Into<CowArray<'a, N, Dim<[usize; 2]>>>,
N: Num + PartialOrd + ToPrimitive + Copy,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -98,13 +95,10 @@ where
/// let iou = parallel_iou_distance(&boxes1, &boxes2);
/// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]);
/// ```
pub fn parallel_iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
pub fn parallel_iou_distance<N>(boxes1: &Array2<N>, boxes2: &Array2<N>) -> Array2<f64>
where
N: Num + PartialEq + PartialOrd + ToPrimitive + Send + Sync + Copy + 'a,
BA: Into<ArrayView2<'a, N>>,
N: Num + PartialOrd + ToPrimitive + Copy + Clone + Sync + Send,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -157,13 +151,7 @@ where
/// # Returns
/// A 2D array containing the Rotated IoU distance matrix. The element at position (i, j) represents
/// the Rotated IoU distance between the i-th box in `boxes1` and the j-th box in `boxes2`.
pub fn rotated_iou_distance<'a, BA>(boxes1: BA, boxes2: BA) -> Array2<f64>
where
BA: Into<ArrayView2<'a, f64>>,
{
let boxes1 = boxes1.into();
let boxes2 = boxes2.into();

pub fn rotated_iou_distance(boxes1: &Array2<f64>, boxes2: &Array2<f64>) -> Array2<f64> {
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

Expand Down Expand Up @@ -288,4 +276,4 @@ mod tests {
let rotated_iou_distance_result = rotated_iou_distance(&boxes1, &boxes2);
assert_eq!(rotated_iou_distance_result, arr2(&[[0.8571428571428572]]));
}
}
}

0 comments on commit 4fbdf2a

Please sign in to comment.