Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/final_data_collection' into fina…
Browse files Browse the repository at this point in the history
…l_data_collection
  • Loading branch information
Ben Edidin committed Apr 17, 2024
2 parents 37c24f8 + 375d7dc commit 0523043
Show file tree
Hide file tree
Showing 22 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions configs/training/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ lr: 0.001
batch_size: 32
time_steps: 10
num_objects: 8
name: debug-training
name: final-training
game: Pong
model: SAM
model: FastSAM-x
num_iterations: 1000
ground_truth_masks: false

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added models/trained/Pong/1713395308_feat_extract.pth
Binary file not shown.
Binary file not shown.
Binary file added models/trained/Pong/1713395447_feat_extract.pth
Binary file not shown.
8 changes: 4 additions & 4 deletions src/scripts/dataset_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from src.data_collection.data_loader import DataLoader


def calculate_iou(boxes1, boxes2):
def calculate_iou(boxes1: npt.NDArray, boxes2: npt.NDArray) -> npt.NDArray:
"""
Calculate IoU (Intersection over Union) between corresponding bounding boxes.
boxes1 and boxes2 should have the format (T, O, 4), where T is the number of timesteps,
O is the number of objects, and 4 represents (x2, y2, x1, y1).
"""
x1_1, y1_1, x2_1, y2_1 = np.split(boxes1, 4, axis=-1)
x1_2, y1_2, x2_2, y2_2 = np.split(boxes2, 4, axis=-1)
x1_1, y1_1, x2_1, y2_1 = np.split(boxes1, 4, axis=-1) # pylint: disable=unbalanced-tuple-unpacking
x1_2, y1_2, x2_2, y2_2 = np.split(boxes2, 4, axis=-1) # pylint: disable=unbalanced-tuple-unpacking
x1_1, x2_1 = np.minimum(x1_1, x2_1), np.maximum(x1_1, x2_1)
y1_1, y2_1 = np.minimum(y1_1, y2_1), np.maximum(y1_1, y2_1)
x1_2, x2_2 = np.minimum(x1_2, x2_2), np.maximum(x1_2, x2_2)
Expand Down Expand Up @@ -59,7 +59,7 @@ def main(game: str) -> None:
# Calculate IoU between FastSAM and SAM bboxes
iou = calculate_iou(objects_sam, objects_fastsam).squeeze(-1)
num_obj = np.maximum(obj_per_frame_sam, obj_per_frame_fastsam) # per frame
ious = [np.mean(iou[i, : num_obj[i]]) for i in range(num_frames)]
ious = [np.mean(iou[i, : num_obj[i]]) for i in range(min(num_frames, 50))]
# Make new plot of IoUs
plt.figure()
plt.plot(ious, label="Mean IoU")
Expand Down

0 comments on commit 0523043

Please sign in to comment.