Skip to content

Commit

Permalink
feat: rotated IoU distance (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
Smirkey committed Jan 8, 2024
1 parent efec8a4 commit 1c86706
Show file tree
Hide file tree
Showing 12 changed files with 264 additions and 15 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,22 @@ Some benchmarks of powerboxes against various open source alternatives, not all
Benchmarks can be found in this google colab [notebook](https://colab.research.google.com/drive/1Z8auT4GZFbwaNs9hZfnB0kvYBbX-MOgS?usp=sharing)

### Box area
Here it's torchvision vs powerboxes
Here it's torchvision vs powerboxes vs numpy

![Box area](./images/box_area.png)

### Box convert
Here it's torchvision vs powerboxes

![Box convert](./images/box_area.png)
![Box convert](./images/box_convert.png)

### Box IoU matrix
Torchvision vs powerboxes vs shapely
Torchvision vs numpy vs powerboxes

![Box IoU](./images/box_iou.png)

### NMS
Torchvision vs powerboxes vs lsnms
Torchvision vs powerboxes vs lsnms vs numpy

#### Large image (10000x10000 pixels)

Expand Down
12 changes: 6 additions & 6 deletions bindings/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,23 @@ intersection = pb.iou_distance(box, box)

Some benchmarks of powerboxes against various open source alternatives, not all functions are benchmarked. Notice that we use log scales, **all differences are major** !

### Box area, (I suspect torchvision to use multiple cores)
Here it's torchvision vs powerboxes
### Box area
Here it's torchvision vs powerboxes vs numpy

![Box area](../images/box_area.png)

### Box convert,(I suspect torchvision to use multiple cores)
### Box convert
Here it's torchvision vs powerboxes

![Box convert](../images/box_area.png)
![Box convert](../images/box_convert.png)

### Box IoU matrix
Torchvision vs shapely vs shapely
Torchvision vs numpy vs powerboxes

![Box IoU](../images/box_iou.png)

### NMS
Torchvision vs powerboxes vs lsnms
Torchvision vs powerboxes vs lsnms vs numpy

#### Large image (10000x10000 pixels)

Expand Down
17 changes: 16 additions & 1 deletion bindings/python/powerboxes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from ._nms import _dtype_to_func_nms, _dtype_to_func_rtree_nms
from ._powerboxes import masks_to_boxes as _masks_to_boxes
from ._powerboxes import rotated_iou_distance as _rotated_iou_distance
from ._tiou import _dtype_to_func_tiou_distance

_BOXES_NOT_SAME_TYPE = "boxes1 and boxes2 must have the same dtype"
Expand Down Expand Up @@ -202,6 +203,19 @@ def tiou_distance(
raise ValueError(_BOXES_NOT_SAME_TYPE)


def rotated_iou_distance(
boxes1: npt.NDArray[T], boxes2: npt.NDArray[T]
) -> npt.NDArray[np.float64]:
if not isinstance(boxes1, np.ndarray) or not isinstance(boxes2, np.ndarray):
raise TypeError(_BOXES_NOT_NP_ARRAY)
if boxes1.dtype == boxes2.dtype == np.dtype("float64"):
return _rotated_iou_distance(boxes1, boxes2)
else:
raise TypeError(
f"Boxes dtype: {boxes1.dtype}, {boxes2.dtype} not in float64 dtype"
)


def remove_small_boxes(boxes: npt.NDArray[T], min_size) -> npt.NDArray[T]:
"""Remove boxes with area less than min_area.
Expand Down Expand Up @@ -365,7 +379,8 @@ def rtree_nms(
"masks_to_boxes",
"supported_dtypes",
"nms",
"tiou",
"tiou_distance",
"rotated_iou_distance"
"rtree_nms",
"__version__",
]
19 changes: 18 additions & 1 deletion bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use num_traits::{Bounded, Num, Signed, ToPrimitive};
use numpy::{PyArray1, PyArray2, PyArray3};
use powerboxesrs::{boxes, giou, iou, nms, tiou};
use pyo3::prelude::*;
use utils::{preprocess_array1, preprocess_array3, preprocess_boxes};
use utils::{preprocess_array1, preprocess_array3, preprocess_boxes, preprocess_rotated_boxes};

#[pymodule]
fn _powerboxes(_py: Python, m: &PyModule) -> PyResult<()> {
Expand Down Expand Up @@ -108,6 +108,8 @@ fn _powerboxes(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(rtree_nms_i16, m)?)?;
// Masks to boxes
m.add_function(wrap_pyfunction!(masks_to_boxes, m)?)?;
// Rotated IoU
m.add_function(wrap_pyfunction!(rotated_iou_distance, m)?)?;
Ok(())
}
// Masks to boxes
Expand All @@ -119,6 +121,21 @@ fn masks_to_boxes(_py: Python, masks: &PyArray3<bool>) -> PyResult<Py<PyArray2<u
return Ok(boxes_as_numpy.to_owned());
}

// Rotated box IoU

#[pyfunction]
fn rotated_iou_distance(
_py: Python,
boxes1: &PyArray2<f64>,
boxes2: &PyArray2<f64>,
) -> PyResult<Py<PyArray2<f64>>> {
let boxes1 = preprocess_rotated_boxes(boxes1).unwrap();
let boxes2 = preprocess_rotated_boxes(boxes2).unwrap();
let iou = iou::rotated_iou_distance(&boxes1, &boxes2);
let iou_as_numpy = utils::array_to_numpy(_py, iou).unwrap();
return Ok(iou_as_numpy.to_owned());
}

// IoU
fn iou_distance_generic<T>(
_py: Python,
Expand Down
30 changes: 29 additions & 1 deletion bindings/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ where

if array_shape[1] != 4 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Arrays must have shape (N, 4)",
"Arrays must have at least shape (N, 4)",
));
} else {
let num_boxes = array_shape[0];
Expand All @@ -39,6 +39,34 @@ where
return Ok(array);
}

pub fn preprocess_rotated_boxes<N>(array: &PyArray2<N>) -> Result<Array2<N>, PyErr>
where
N: Num + numpy::Element + Send,
{
let array = unsafe { array.as_array() };
let array_shape = array.shape();

if array_shape[1] != 5 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Arrays must have at least shape (N, 5)",
));
} else {
let num_boxes = array_shape[0];

if num_boxes == 0 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Arrays must have shape (N, 5) with N > 0",
));
}
}

let array = array
.to_owned()
.into_shape((array_shape[0], array_shape[1]))
.unwrap();
return Ok(array);
}

pub fn preprocess_array3<N>(array: &PyArray3<N>) -> Array3<N>
where
N: numpy::Element,
Expand Down
36 changes: 36 additions & 0 deletions bindings/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
parallel_giou_distance,
parallel_iou_distance,
remove_small_boxes,
rotated_iou_distance,
rtree_nms,
supported_dtypes,
tiou_distance,
Expand Down Expand Up @@ -265,3 +266,38 @@ def test_rtree_nms_bad_dtype():
scores = np.random.random((100,))
with pytest.raises(TypeError):
rtree_nms(boxes1.astype(unsuported_dtype_example), scores, 0.5, 0.5)


@pytest.mark.parametrize("dtype", ["float64"])
def test_rotated_iou_distance(dtype):
boxes1 = np.random.random((100, 5))
boxes2 = np.random.random((100, 5))
rotated_iou_distance(
boxes1.astype(dtype),
boxes2.astype(dtype),
)


def test_rotated_iou_distance_bad_inputs():
with pytest.raises(TypeError, match=_BOXES_NOT_NP_ARRAY):
rotated_iou_distance("foo", "bar")
with pytest.raises(Exception):
try:
rotated_iou_distance(np.random.random((100, 4)), np.random.random((100, 4)))
except: # noqa: E722
raise RuntimeError()
with pytest.raises(RuntimeError):
try:
rotated_iou_distance(np.random.random((0, 4)), np.random.random((100, 4)))
except: # noqa: E722
raise RuntimeError()


def test_rotated_iou_distance_dtype():
boxes1 = np.random.random((100, 5))
boxes2 = np.random.random((100, 5))
with pytest.raises(TypeError):
rotated_iou_distance(
boxes1.astype(unsuported_dtype_example),
boxes2.astype(unsuported_dtype_example),
)
9 changes: 9 additions & 0 deletions bindings/tests/test_speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
parallel_giou_distance,
parallel_iou_distance,
remove_small_boxes,
rotated_iou_distance,
rtree_nms,
supported_dtypes,
tiou_distance,
Expand All @@ -29,6 +30,14 @@ def generate_boxes(request):
return np.concatenate([topleft, topleft + wh], axis=1).astype(np.float64)


@pytest.mark.benchmark(group="tiou_distance")
@pytest.mark.parametrize("dtype", ["float64"])
def test_rotated_iou_distance(benchmark, dtype):
boxes1 = np.random.random((100, 5)).astype(dtype)
boxes2 = np.random.random((100, 5)).astype(dtype)
benchmark(rotated_iou_distance, boxes1, boxes2)


@pytest.mark.benchmark(group="tiou_distance")
@pytest.mark.parametrize("dtype", supported_dtypes)
def test_tiou_distance(benchmark, dtype):
Expand Down
Binary file modified images/box_nms_large_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/box_nms_normal_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
108 changes: 106 additions & 2 deletions powerboxesrs/src/iou.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{boxes, utils};
use ndarray::{Array2, Zip};
use crate::{boxes, rotation::cxcywha_to_points, utils};
use ndarray::{Array2, Axis, Zip};
use num_traits::{Num, ToPrimitive};
use rstar::{Envelope, RStarInsertionStrategy, RTree, RTreeNum, RTreeObject, RTreeParams, AABB};

/// Calculates the intersection over union (IoU) distance between two sets of bounding boxes.
///
Expand Down Expand Up @@ -165,6 +166,101 @@ where
return iou_matrix;
}

// Struct we use to represent a bbox object in rstar R-tree
struct OrientedBbox<T> {
index: usize,
x1: T,
y1: T,
x2: T,
y2: T,
x3: T,
y3: T,
x4: T,
y4: T,
}

// Implement RTreeObject for Bbox
impl<T> RTreeObject for OrientedBbox<T>
where
T: RTreeNum + ToPrimitive + Sync + Send,
{
type Envelope = AABB<[T; 2]>;

fn envelope(&self) -> Self::Envelope {
AABB::from_points([
&[self.x1, self.y1],
&[self.x2, self.y2],
&[self.x3, self.y3],
&[self.x4, self.y4],
])
}
}

impl<T> RTreeParams for OrientedBbox<T>
where
T: RTreeNum + ToPrimitive + Sync + Send,
{
const MIN_SIZE: usize = 16;
const MAX_SIZE: usize = 256;
const REINSERTION_COUNT: usize = 5;
type DefaultInsertionStrategy = RStarInsertionStrategy;
}

pub fn rotated_iou_distance(boxes1: &Array2<f64>, boxes2: &Array2<f64>) -> Array2<f64> {
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

let mut iou_matrix = Array2::<f64>::ones((num_boxes1, num_boxes2));
let points_boxes_1: Vec<OrientedBbox<f64>> = boxes1
.axis_iter(Axis(0))
.enumerate()
.map(|(i, row)| {
let (p1, p2, p3, p4) = cxcywha_to_points(row[0], row[1], row[2], row[3], row[4]);
OrientedBbox {
index: i,
x1: p1.x,
y1: p1.y,
x2: p2.x,
y2: p2.y,
x3: p3.x,
y3: p3.y,
x4: p4.x,
y4: p4.y,
}
})
.collect();
let points_boxes_2: Vec<OrientedBbox<f64>> = boxes2
.axis_iter(Axis(0))
.enumerate()
.map(|(i, row)| {
let (p1, p2, p3, p4) = cxcywha_to_points(row[0], row[1], row[2], row[3], row[4]);
OrientedBbox {
index: i,
x1: p1.x,
y1: p1.y,
x2: p2.x,
y2: p2.y,
x3: p3.x,
y3: p3.y,
x4: p4.x,
y4: p4.y,
}
})
.collect();
let rtree_boxes_1: RTree<OrientedBbox<f64>> = RTree::bulk_load(points_boxes_1);
let rtree_boxes_2: RTree<OrientedBbox<f64>> = RTree::bulk_load(points_boxes_2);

for (box1, box2) in rtree_boxes_1.intersection_candidates_with_other_tree(&rtree_boxes_2) {
let box1_envelope = box1.envelope();
let box2_envelope = box2.envelope();
let intersection = box1_envelope.intersection_area(&box2_envelope);
let iou = intersection
/ (box1_envelope.area() + box2_envelope.area() - intersection + utils::EPS);
iou_matrix[[box1.index, box2.index]] = utils::ONE - iou;
}
return iou_matrix;
}

#[cfg(test)]
mod tests {
use ndarray::arr2;
Expand Down Expand Up @@ -231,4 +327,12 @@ mod tests {
assert_eq!(parallel_iou_distance_result, arr2(&[[1.0]]));
assert_eq!(1. - iou_distance_result, iou_result);
}

#[test]
fn test_rotated_iou_disstance() {
let boxes1 = arr2(&[[5.0, 5.0, 2.0, 2.0, 0.0]]);
let boxes2 = arr2(&[[4.0, 4.0, 2.0, 2.0, 0.0]]);
let rotated_iou_distance_result = rotated_iou_distance(&boxes1, &boxes2);
assert_eq!(rotated_iou_distance_result, arr2(&[[0.8571428571428572]]));
}
}
1 change: 1 addition & 0 deletions powerboxesrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,6 @@ pub mod boxes;
pub mod giou;
pub mod iou;
pub mod nms;
mod rotation;
pub mod tiou;
mod utils;
Loading

0 comments on commit 1c86706

Please sign in to comment.