Skip to content

Commit

Permalink
refactor: use rtree in rotated box iou (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
Smirkey committed Jan 15, 2024
1 parent e500cba commit 4dc5b2c
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 43 deletions.
71 changes: 63 additions & 8 deletions powerboxesrs/src/iou.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
};
use ndarray::{Array2, Zip};
use num_traits::{Num, ToPrimitive};
use rstar::RTree;

/// Calculates the intersection over union (IoU) distance between two sets of bounding boxes.
///
Expand Down Expand Up @@ -156,16 +157,70 @@ pub fn rotated_iou_distance(boxes1: &Array2<f64>, boxes2: &Array2<f64>) -> Array
.into_iter()
.map(|row| Rect::new(row[0], row[1], row[2], row[3], row[4]))
.collect();
let boxes1_bounding_rects: Vec<utils::Bbox<f64>> = boxes1_rects
.iter()
.enumerate()
.map(|(idx, rect)| {
let points = rect.points();
let (min_x, max_x) = points
.iter()
.map(|p| p.x)
.fold((f64::INFINITY, f64::NEG_INFINITY), |acc, x| {
(acc.0.min(x), acc.1.max(x))
});
let (min_y, max_y) = points
.iter()
.map(|p| p.y)
.fold((f64::INFINITY, f64::NEG_INFINITY), |acc, y| {
(acc.0.min(y), acc.1.max(y))
});
utils::Bbox {
index: idx,
x1: min_x,
y1: min_y,
x2: max_x,
y2: max_y,
}
})
.collect();
let boxes2_bounding_rects: Vec<utils::Bbox<f64>> = boxes2_rects
.iter()
.enumerate()
.map(|(idx, rect)| {
let points = rect.points();
let (min_x, max_x) = points
.iter()
.map(|p| p.x)
.fold((f64::INFINITY, f64::NEG_INFINITY), |acc, x| {
(acc.0.min(x), acc.1.max(x))
});
let (min_y, max_y) = points
.iter()
.map(|p| p.y)
.fold((f64::INFINITY, f64::NEG_INFINITY), |acc, y| {
(acc.0.min(y), acc.1.max(y))
});
utils::Bbox {
index: idx,
x1: min_x,
y1: min_y,
x2: max_x,
y2: max_y,
}
})
.collect();

for (i, rect1) in boxes1_rects.iter().enumerate() {
let area1 = areas1[i];
for (j, rect2) in boxes2_rects.iter().enumerate() {
let area2 = areas2[j];
let intersection = intersection_area(*rect1, *rect2);
let union = area1 + area2 - intersection + utils::EPS;
iou_matrix[[i, j]] = utils::ONE - intersection / union;
}
let box1_rtree: RTree<utils::Bbox<f64>> = RTree::bulk_load(boxes1_bounding_rects);
let box2_rtree: RTree<utils::Bbox<f64>> = RTree::bulk_load(boxes2_bounding_rects);

for (box1, box2) in box1_rtree.intersection_candidates_with_other_tree(&box2_rtree) {
let area1 = areas1[box1.index];
let area2 = areas2[box2.index];
let intersection = intersection_area(boxes1_rects[box1.index], boxes2_rects[box2.index]);
let union = area1 + area2 - intersection + utils::EPS;
iou_matrix[[box1.index, box2.index]] = utils::ONE - intersection / union;
}

return iou_matrix;
}

Expand Down
37 changes: 3 additions & 34 deletions powerboxesrs/src/nms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::cmp::Ordering;
use crate::{boxes, utils};
use ndarray::{Array1, Array2, Axis};
use num_traits::{Num, ToPrimitive};
use rstar::{RStarInsertionStrategy, RTree, RTreeNum, RTreeObject, RTreeParams, AABB};
use rstar::{RTree, RTreeNum, AABB};

/// Performs non-maximum suppression (NMS) on a set of bounding boxes using their scores and IoU.
/// # Arguments
Expand Down Expand Up @@ -94,37 +94,6 @@ where
return Array1::from(keep);
}

/// Axis-aligned box entry stored in the rstar R-tree used by this module.
///
/// The corners are handed to `AABB::from_corners` as `[x1, y1]` and `[x2, y2]`
/// by the `RTreeObject` impl for this type.
struct Bbox<T> {
// Index carried alongside the coordinates so a tree hit can be traced back
// to an input box (presumably the row index into `boxes` — confirm at the
// construction site).
index: usize,
// First corner, x coordinate.
x1: T,
// First corner, y coordinate.
y1: T,
// Second corner, x coordinate.
x2: T,
// Second corner, y coordinate.
y2: T,
}

// Makes `Bbox` insertable into an rstar tree: the envelope of an entry is the
// axis-aligned rectangle spanned by its two stored corners.
impl<T> RTreeObject for Bbox<T>
where
    T: RTreeNum + ToPrimitive + Sync + Send,
{
    type Envelope = AABB<[T; 2]>;

    // Returns the AABB built from the `(x1, y1)` and `(x2, y2)` corners.
    fn envelope(&self) -> Self::Envelope {
        let first_corner = [self.x1, self.y1];
        let second_corner = [self.x2, self.y2];
        AABB::from_corners(first_corner, second_corner)
    }
}

// Custom rstar tuning parameters attached to the `Bbox` type.
// NOTE(review): the tree in this file is declared as `RTree<Bbox<N>>` with a
// single type argument, so these parameters are presumably not picked up
// unless `Bbox` is also supplied as the tree's params type — confirm against
// the rstar `RTreeParams` documentation.
impl<T> RTreeParams for Bbox<T>
where
T: RTreeNum + ToPrimitive + Sync + Send,
{
// Lower bound on node size (see rstar `RTreeParams::MIN_SIZE`).
const MIN_SIZE: usize = 16;
// Upper bound on node size (see rstar `RTreeParams::MAX_SIZE`).
const MAX_SIZE: usize = 256;
// Reinsertion count used by the insertion strategy
// (see rstar `RTreeParams::REINSERTION_COUNT`).
const REINSERTION_COUNT: usize = 5;
// Insertion strategy used when no other strategy is specified.
type DefaultInsertionStrategy = RStarInsertionStrategy;
}

/// Performs non-maximum suppression (NMS) on a set of bounding boxes using their scores and IoU.
/// This function internally uses an RTree to speed up the computation. It is recommended to use this function
/// when the number of boxes is large.
Expand Down Expand Up @@ -181,12 +150,12 @@ where
let mut suppress = Array1::from_elem(scores.len(), false);
// build rtree

let rtree: RTree<Bbox<N>> = RTree::bulk_load(
let rtree: RTree<utils::Bbox<N>> = RTree::bulk_load(
order
.iter()
.map(|&idx| {
let box_ = boxes.row(idx);
Bbox {
utils::Bbox {
x1: box_[0],
y1: box_[1],
x2: box_[2],
Expand Down
34 changes: 33 additions & 1 deletion powerboxesrs/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use num_traits::Num;
use num_traits::{Num, ToPrimitive};
use rstar::{RStarInsertionStrategy, RTreeNum, RTreeObject, RTreeParams, AABB};

pub const EPS: f64 = 1e-16;
pub const ONE: f64 = 1.0;
Expand Down Expand Up @@ -26,6 +27,37 @@ where
}
}

/// Axis-aligned bounding box entry stored in an rstar R-tree.
///
/// The `RTreeObject` impl turns the two corners into an `AABB`, and `index`
/// lets a tree hit be mapped back to the box it was built from (callers index
/// their own per-box arrays with it).
///
/// The common traits are derived so that entries can be logged, compared in
/// tests and copied cheaply; each derive only applies when `T` itself
/// implements the corresponding trait.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bbox<T> {
    /// Position of the originating box in the caller's collection.
    pub index: usize,
    /// x coordinate of the first corner (callers in this crate pass the minimum x).
    pub x1: T,
    /// y coordinate of the first corner (callers in this crate pass the minimum y).
    pub y1: T,
    /// x coordinate of the second corner (callers in this crate pass the maximum x).
    pub x2: T,
    /// y coordinate of the second corner (callers in this crate pass the maximum y).
    pub y2: T,
}

// Makes `Bbox` insertable into an rstar tree: the envelope of an entry is the
// axis-aligned rectangle spanned by its two stored corners.
impl<T> RTreeObject for Bbox<T>
where
    T: RTreeNum + ToPrimitive + Sync + Send,
{
    type Envelope = AABB<[T; 2]>;

    // Returns the AABB built from the `(x1, y1)` and `(x2, y2)` corners.
    fn envelope(&self) -> Self::Envelope {
        let first_corner = [self.x1, self.y1];
        let second_corner = [self.x2, self.y2];
        AABB::from_corners(first_corner, second_corner)
    }
}

// Custom rstar tuning parameters attached to the `Bbox` type.
// NOTE(review): the visible call sites declare `RTree<utils::Bbox<_>>` with a
// single type argument, so these parameters are presumably not picked up
// unless `Bbox` is also supplied as the tree's params type — confirm against
// the rstar `RTreeParams` documentation.
impl<T> RTreeParams for Bbox<T>
where
T: RTreeNum + ToPrimitive + Sync + Send,
{
// Lower bound on node size (see rstar `RTreeParams::MIN_SIZE`).
const MIN_SIZE: usize = 16;
// Upper bound on node size (see rstar `RTreeParams::MAX_SIZE`).
const MAX_SIZE: usize = 256;
// Reinsertion count used by the insertion strategy
// (see rstar `RTreeParams::REINSERTION_COUNT`).
const REINSERTION_COUNT: usize = 5;
// Insertion strategy used when no other strategy is specified.
type DefaultInsertionStrategy = RStarInsertionStrategy;
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 4dc5b2c

Please sign in to comment.