Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metrics #887

Merged
merged 32 commits into from Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
46f224f
insert torchvision dependency and write tests for cifar10
BaruchG Aug 8, 2022
01d6f1e
removed print
BaruchG Aug 8, 2022
45c9a9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 8, 2022
be1fa15
Merge remote-tracking branch 'upstream/master'
BaruchG Aug 19, 2022
86482ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2022
adc57b9
cleanup failed merge
Aug 23, 2022
a2082e3
Merge branch 'master' into BaruchG/master
Aug 23, 2022
f8ba0e6
merge master
Aug 25, 2022
7230608
Merge remote-tracking branch 'upstream/master'
BaruchG Sep 6, 2022
36d57f8
Merge remote-tracking branch 'upstream/master'
BaruchG Sep 9, 2022
88eb870
Merge remote-tracking branch 'upstream/master'
BaruchG Sep 16, 2022
b1995a7
Revert "insert torchvision dependency and write tests for cifar10"
BaruchG Sep 16, 2022
2533023
Merge remote-tracking branch 'upstream/master'
BaruchG Sep 20, 2022
a9a5ab8
ensured tests are present for object detection and removed under review
BaruchG Sep 20, 2022
6ae60a8
cifar10 revert
BaruchG Sep 20, 2022
bfdae01
revert cifar10
BaruchG Sep 20, 2022
c6b4752
revert cifar10
BaruchG Sep 20, 2022
92d3297
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2022
88f0ff5
Merge branch 'master' into metrics
BaruchG Sep 21, 2022
54e1110
Merge branch 'master' into metrics
BaruchG Sep 21, 2022
f9954f4
renamed variables to conform to specs
BaruchG Sep 21, 2022
f5c5af2
Merge branch 'metrics' of https://github.com/BaruchG/lightning-bolts …
BaruchG Sep 21, 2022
d38a965
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2022
6a0eacf
added newline
BaruchG Sep 21, 2022
44dedfb
added newline
BaruchG Sep 21, 2022
0d4cc23
upgraded to assert_close and modified tolerance
BaruchG Sep 21, 2022
c32979a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2022
519c1fe
modified formatting of docstring
BaruchG Sep 22, 2022
d902fa1
Merge branch 'metrics' of https://github.com/BaruchG/lightning-bolts …
BaruchG Sep 22, 2022
42cc317
Merge branch 'Lightning-AI:master' into metrics
BaruchG Sep 22, 2022
d320008
Merge branch 'master' into metrics
mergify[bot] Sep 23, 2022
63de9e8
Merge branch 'master' into metrics
mergify[bot] Sep 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 5 additions & 9 deletions pl_bolts/metrics/object_detection.py
@@ -1,10 +1,7 @@
import torch
from torch import Tensor

from pl_bolts.utils.stability import under_review


@under_review()
def iou(preds: Tensor, target: Tensor) -> Tensor:
"""Calculates the intersection over union.

Expand Down Expand Up @@ -33,11 +30,10 @@ def iou(preds: Tensor, target: Tensor) -> Tensor:
pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
union = pred_area[:, None] + target_area - intersection
iou = torch.true_divide(intersection, union)
return iou
iou_value = torch.true_divide(intersection, union)
return iou_value


@under_review()
def giou(preds: Tensor, target: Tensor) -> Tensor:
"""Calculates the generalized intersection over union.

Expand Down Expand Up @@ -74,6 +70,6 @@ def giou(preds: Tensor, target: Tensor) -> Tensor:
C_x_max = torch.max(preds[:, None, 2], target[:, 2])
C_y_max = torch.max(preds[:, None, 3], target[:, 3])
C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0)
iou = torch.true_divide(intersection, union)
giou = iou - torch.true_divide((C_area - union), C_area)
return giou
iou_value = torch.true_divide(intersection, union)
giou_value = iou_value - torch.true_divide((C_area - union), C_area)
return giou_value
12 changes: 6 additions & 6 deletions tests/metrics/test_object_detection.py
Expand Up @@ -11,7 +11,7 @@
[(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([[1.0]]))],
)
def test_iou_complete_overlap(preds, target, expected_iou):
torch.testing.assert_allclose(iou(preds, target), expected_iou)
torch.testing.assert_close(iou(preds, target), expected_iou)


@pytest.mark.parametrize(
Expand All @@ -22,7 +22,7 @@ def test_iou_complete_overlap(preds, target, expected_iou):
],
)
def test_iou_no_overlap(preds, target, expected_iou):
torch.testing.assert_allclose(iou(preds, target), expected_iou)
torch.testing.assert_close(iou(preds, target), expected_iou)


@pytest.mark.parametrize(
Expand All @@ -36,15 +36,15 @@ def test_iou_no_overlap(preds, target, expected_iou):
],
)
def test_iou_multi(preds, target, expected_iou):
torch.testing.assert_allclose(iou(preds, target), expected_iou)
torch.testing.assert_close(iou(preds, target), expected_iou)


@pytest.mark.parametrize(
"preds, target, expected_giou",
[(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([[1.0]]))],
)
def test_complete_overlap(preds, target, expected_giou):
torch.testing.assert_allclose(giou(preds, target), expected_giou)
torch.testing.assert_close(giou(preds, target), expected_giou)


@pytest.mark.parametrize(
Expand All @@ -55,7 +55,7 @@ def test_complete_overlap(preds, target, expected_giou):
],
)
def test_no_overlap(preds, target, expected_giou):
torch.testing.assert_allclose(giou(preds, target), expected_giou)
torch.testing.assert_close(giou(preds, target), expected_giou)


@pytest.mark.parametrize(
Expand All @@ -69,4 +69,4 @@ def test_no_overlap(preds, target, expected_giou):
],
)
def test_giou_multi(preds, target, expected_giou):
torch.testing.assert_allclose(giou(preds, target), expected_giou)
torch.testing.assert_close(giou(preds, target), expected_giou, atol=0.0001, rtol=0.0001)