Skip to content

Commit

Permalink
Merge branch 'final_data_collection' of https://github.com/Cubevoid/a…
Browse files Browse the repository at this point in the history
…tari-obj-pred into final_data_collection
  • Loading branch information
quajak committed Apr 17, 2024
2 parents 436d964 + b4b32e6 commit 375d7dc
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/scripts/dataset_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from src.data_collection.data_loader import DataLoader


def calculate_iou(boxes1, boxes2):
def calculate_iou(boxes1: npt.NDArray, boxes2: npt.NDArray) -> npt.NDArray:
"""
Calculate IoU (Intersection over Union) between corresponding bounding boxes.
boxes1 and boxes2 should have the format (T, O, 4), where T is the number of timesteps,
O is the number of objects, and 4 represents (x2, y2, x1, y1).
"""
x1_1, y1_1, x2_1, y2_1 = np.split(boxes1, 4, axis=-1)
x1_2, y1_2, x2_2, y2_2 = np.split(boxes2, 4, axis=-1)
x1_1, y1_1, x2_1, y2_1 = np.split(boxes1, 4, axis=-1) # pylint: disable=unbalanced-tuple-unpacking
x1_2, y1_2, x2_2, y2_2 = np.split(boxes2, 4, axis=-1) # pylint: disable=unbalanced-tuple-unpacking
x1_1, x2_1 = np.minimum(x1_1, x2_1), np.maximum(x1_1, x2_1)
y1_1, y2_1 = np.minimum(y1_1, y2_1), np.maximum(y1_1, y2_1)
x1_2, x2_2 = np.minimum(x1_2, x2_2), np.maximum(x1_2, x2_2)
Expand Down Expand Up @@ -59,7 +59,7 @@ def main(game: str) -> None:
# Calculate IoU between FastSAM and SAM bboxes
iou = calculate_iou(objects_sam, objects_fastsam).squeeze(-1)
num_obj = np.maximum(obj_per_frame_sam, obj_per_frame_fastsam) # per frame
ious = [np.mean(iou[i, : num_obj[i]]) for i in range(num_frames)]
ious = [np.mean(iou[i, : num_obj[i]]) for i in range(min(num_frames, 50))]
# Make new plot of IoUs
plt.figure()
plt.plot(ious, label="Mean IoU")
Expand Down

0 comments on commit 375d7dc

Please sign in to comment.