Skip to content

Commit

Permalink
feat: rotated tiou distance (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
Smirkey authored Jan 21, 2024
1 parent e190351 commit e443fea
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 3 deletions.
36 changes: 35 additions & 1 deletion bindings/python/powerboxes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ._powerboxes import masks_to_boxes as _masks_to_boxes
from ._powerboxes import rotated_giou_distance as _rotated_giou_distance
from ._powerboxes import rotated_iou_distance as _rotated_iou_distance
from ._powerboxes import rotated_tiou_distance as _rotated_tiou_distance
from ._tiou import _dtype_to_func_tiou_distance

_BOXES_NOT_SAME_TYPE = "boxes1 and boxes2 must have the same dtype"
Expand Down Expand Up @@ -235,7 +236,7 @@ def rotated_iou_distance(


def rotated_giou_distance(
boxes1: npt.NDArray[T], boxes2: npt.NDArray[T]
boxes1: npt.NDArray[np.float64], boxes2: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
"""Compute the pairwise giou distance between rotated boxes
Expand Down Expand Up @@ -264,6 +265,38 @@ def rotated_giou_distance(
)


def rotated_tiou_distance(
boxes1: npt.NDArray[np.float64], boxes2: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
"""Compute pairwise box tiou (tracking iou) distances.
see https://arxiv.org/pdf/2310.05171.pdf for tiou definition
Boxes should be in (cx, cy, w, h, a) format
where cx and cy are center coordinates, w and h
width and height and a, the angle in degrees
Args:
boxes1: 2d array of boxes in cxywha format
boxes2: 2d array of boxes in cxywha format
Raises:
TypeError: if boxes1 or boxes2 are not numpy arrays
ValueError: if boxes1 and boxes2 have different dtypes
Returns:
np.ndarray: 2d matrix of pairwise distances
"""
if not isinstance(boxes1, np.ndarray) or not isinstance(boxes2, np.ndarray):
raise TypeError(_BOXES_NOT_NP_ARRAY)
if boxes1.dtype == boxes2.dtype == np.dtype("float64"):
return _rotated_tiou_distance(boxes1, boxes2)
else:
raise TypeError(
f"Boxes dtype: {boxes1.dtype}, {boxes2.dtype} not in float64 dtype"
)


def remove_small_boxes(boxes: npt.NDArray[T], min_size) -> npt.NDArray[T]:
"""Remove boxes with area less than min_area.
Expand Down Expand Up @@ -430,6 +463,7 @@ def rtree_nms(
"tiou_distance",
"rotated_iou_distance",
"rotated_giou_distance",
"rotated_tiou_distance",
"rtree_nms",
"__version__",
]
17 changes: 17 additions & 0 deletions bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ fn _powerboxes(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(rotated_iou_distance, m)?)?;
// Rotated GIoU
m.add_function(wrap_pyfunction!(rotated_giou_distance, m)?)?;
// Rotated TIoU
m.add_function(wrap_pyfunction!(rotated_tiou_distance, m)?)?;
Ok(())
}
// Masks to boxes
Expand Down Expand Up @@ -153,6 +155,21 @@ fn rotated_giou_distance(
return Ok(iou_as_numpy.to_owned());
}

// Rotated box TIoU

#[pyfunction]
fn rotated_tiou_distance(
_py: Python,
boxes1: &PyArray2<f64>,
boxes2: &PyArray2<f64>,
) -> PyResult<Py<PyArray2<f64>>> {
let boxes1 = preprocess_rotated_boxes(boxes1).unwrap();
let boxes2 = preprocess_rotated_boxes(boxes2).unwrap();
let iou = tiou::rotated_tiou_distance(&boxes1, &boxes2);
let iou_as_numpy = utils::array_to_numpy(_py, iou).unwrap();
return Ok(iou_as_numpy.to_owned());
}

// IoU
fn iou_distance_generic<T>(
_py: Python,
Expand Down
38 changes: 38 additions & 0 deletions bindings/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
remove_small_boxes,
rotated_giou_distance,
rotated_iou_distance,
rotated_tiou_distance,
rtree_nms,
supported_dtypes,
tiou_distance,
Expand Down Expand Up @@ -339,3 +340,40 @@ def test_rotated_giou_distance_dtype():
boxes1.astype(unsuported_dtype_example),
boxes2.astype(unsuported_dtype_example),
)


@pytest.mark.parametrize("dtype", ["float64"])
def test_rotated_tiou_distance(dtype):
boxes1 = np.random.random((100, 5))
boxes2 = np.random.random((100, 5))
rotated_tiou_distance(
boxes1.astype(dtype),
boxes2.astype(dtype),
)


def test_rotated_tiou_distance_bad_inputs():
with pytest.raises(TypeError, match=_BOXES_NOT_NP_ARRAY):
rotated_tiou_distance("foo", "bar")
with pytest.raises(Exception):
try:
rotated_tiou_distance(
np.random.random((100, 4)), np.random.random((100, 4))
)
except: # noqa: E722
raise RuntimeError()
with pytest.raises(RuntimeError):
try:
rotated_tiou_distance(np.random.random((0, 4)), np.random.random((100, 4)))
except: # noqa: E722
raise RuntimeError()


def test_rotated_tiou_distance_dtype():
boxes1 = np.random.random((100, 5))
boxes2 = np.random.random((100, 5))
with pytest.raises(TypeError):
rotated_tiou_distance(
boxes1.astype(unsuported_dtype_example),
boxes2.astype(unsuported_dtype_example),
)
9 changes: 9 additions & 0 deletions bindings/tests/test_speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
remove_small_boxes,
rotated_giou_distance,
rotated_iou_distance,
rotated_tiou_distance,
rtree_nms,
supported_dtypes,
tiou_distance,
Expand All @@ -31,6 +32,14 @@ def generate_boxes(request):
return np.concatenate([topleft, topleft + wh], axis=1).astype(np.float64)


@pytest.mark.benchmark(group="rotated_tiou_distance")
@pytest.mark.parametrize("dtype", ["float64"])
def test_rotated_tiou_distance(benchmark, dtype):
boxes1 = np.random.random((100, 5)).astype(dtype)
boxes2 = np.random.random((100, 5)).astype(dtype)
benchmark(rotated_tiou_distance, boxes1, boxes2)


@pytest.mark.benchmark(group="rotated_iou_distance")
@pytest.mark.parametrize("dtype", ["float64"])
def test_rotated_iou_distance(benchmark, dtype):
Expand Down
88 changes: 86 additions & 2 deletions powerboxesrs/src/tiou.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use ndarray::Array2;
use num_traits::{Num, ToPrimitive};

use crate::{boxes, utils};
use crate::{
boxes::{self, rotated_box_areas},
rotation::{minimal_bounding_rect, Rect},
utils,
};
/// Computes the Tracking Intersection over Union (TIOU) distance between two sets of bounding boxes.
/// see https://arxiv.org/pdf/2310.05171.pdf
/// # Arguments
Expand Down Expand Up @@ -67,18 +71,98 @@ where
tiou_matrix
}

/// Calculates the rotated tracking IoU (Tiou) distance between two sets of rotated bounding boxes.
///
/// Given two sets of rotated bounding boxes represented by `boxes1` and `boxes2`, this function
/// computes the rotated Tiou distance matrix between them. The rotated Tiou distance is a measure
/// of dissimilarity between two rotated bounding boxes, taking into account both their overlap
/// and the encompassing area.
///
/// # Arguments
///
/// * `boxes1` - A reference to a 2D array (Array2) containing the parameters of the first set of rotated bounding boxes.
/// Each row of `boxes1` represents a rotated bounding box with parameters [center_x, center_y, width, height, angle in degrees].
///
/// * `boxes2` - A reference to a 2D array (Array2) containing the parameters of the second set of rotated bounding boxes.
/// Each row of `boxes2` represents a rotated bounding box with parameters [center_x, center_y, width, height, angle in degrees].
///
/// # Returns
///
/// A 2D array (Array2) representing the rotated Tiou distance matrix between the input sets of rotated bounding boxes.
/// The element at position (i, j) in the matrix represents the rotated Giou distance between the i-th box in `boxes1` and
/// the j-th box in `boxes2`.
///
pub fn rotated_tiou_distance(boxes1: &Array2<f64>, boxes2: &Array2<f64>) -> Array2<f64> {
let num_boxes1 = boxes1.nrows();
let num_boxes2 = boxes2.nrows();

let mut iou_matrix = Array2::<f64>::ones((num_boxes1, num_boxes2));
let areas1 = rotated_box_areas(&boxes1);
let areas2 = rotated_box_areas(&boxes2);

let boxes1_rects: Vec<(f64, f64, f64, f64)> = boxes1
.rows()
.into_iter()
.map(|row| {
minimal_bounding_rect(&Rect::new(row[0], row[1], row[2], row[3], row[4]).points())
})
.collect();
let boxes2_rects: Vec<(f64, f64, f64, f64)> = boxes2
.rows()
.into_iter()
.map(|row| {
minimal_bounding_rect(&Rect::new(row[0], row[1], row[2], row[3], row[4]).points())
})
.collect();

for (i, r1) in boxes1_rects.iter().enumerate() {
let area1 = areas1[i];
let (x1_r1, y1_r1, x2_r1, y2_r1) = r1;

for (j, r2) in boxes2_rects.iter().enumerate() {
let area2 = areas2[j];
let (x1_r2, y1_r2, x2_r2, y2_r2) = r2;

// Calculate the enclosing box (C) coordinates
let c_x1 = utils::min(*x1_r1, *x1_r2);
let c_y1 = utils::min(*y1_r1, *y1_r2);
let c_x2 = utils::max(*x2_r1, *x2_r2);
let c_y2 = utils::max(*y2_r1, *y2_r2);
// Calculate the area of the enclosing box (C)
let c_area = (c_x2 - c_x1) * (c_y2 - c_y1);
let c_area = c_area.to_f64().unwrap();
iou_matrix[[i, j]] = utils::ONE - utils::min(area1 / c_area, area2 / c_area)
}
}
return iou_matrix;
}

#[cfg(test)]
mod tests {
use ndarray::arr2;

use super::*;

#[test]
fn test_giou() {
fn test_tiou() {
let boxes1 = arr2(&[[0.0, 0.0, 3.0, 3.0], [1.0, 1.0, 4.0, 4.0]]);
let boxes2 = arr2(&[[2.0, 2.0, 5.0, 5.0], [3.0, 3.0, 6.0, 6.0]]);

let tiou_matrix = tiou_distance(&boxes1, &boxes2);
assert_eq!(tiou_matrix, arr2(&[[0.64, 0.75], [0.4375, 0.64]]));
}
#[test]
fn test_rotated_tiou() {
let boxes1 = arr2(&[[0.0, 0.0, 3.0, 3.0, 20.0], [1.0, 1.0, 4.0, 4.0, 19.0]]);
let boxes2 = arr2(&[[2.0, 2.0, 5.0, 5.0, 0.0], [3.0, 3.0, 6.0, 6.0, 20.0]]);

let tiou_matrix = rotated_tiou_distance(&boxes1, &boxes2);
assert_eq!(
tiou_matrix,
arr2(&[
[0.7818149787949012, 0.8829233169330242],
[0.561738213456193, 0.7725560385451797]
])
);
}
}

0 comments on commit e443fea

Please sign in to comment.