Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute classification metrics locally via numpy #651

Draft
wants to merge 30 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d7c8edf
working through integration tests
ntlind Jul 3, 2024
abf6075
Merge branch 'main' into compute_local_metrics
ntlind Jul 5, 2024
18528d6
pass functional tests after merge conflicts
ntlind Jul 5, 2024
0334741
pass non-ROC integration tests
ntlind Jul 5, 2024
56767df
add ROCAUC metric
ntlind Jul 6, 2024
d0f0df0
pass a few more integration tests
ntlind Jul 8, 2024
f1d4b5a
pass integration tests
ntlind Jul 8, 2024
b0238c3
Merge branch 'main' into compute_local_metrics
ntlind Jul 9, 2024
a1422ce
minor cleanup
ntlind Jul 9, 2024
17dbb00
update npm test
ntlind Jul 9, 2024
2e39ade
up benchmarks
ntlind Jul 9, 2024
1f4c354
update benchmarks
ntlind Jul 9, 2024
026094a
pass first set of AR and AP tests
ntlind Jul 10, 2024
8e9828a
add curves
ntlind Jul 11, 2024
3869b27
fix OD iou calculation for rasters
ntlind Jul 11, 2024
9cd1654
pass functional tests
ntlind Jul 11, 2024
1fa240c
fix groundtruths with no predictions
ntlind Jul 12, 2024
3e1f438
Merge branch 'main' into compute_local_metrics
ntlind Jul 12, 2024
62a660a
fix gts in PR output
ntlind Jul 12, 2024
39466db
pass more integration tests, edge cases still outstanding
ntlind Jul 12, 2024
b75087b
pass integration tests
ntlind Jul 16, 2024
dd19fff
small benchmarking script changes
ntlind Jul 16, 2024
87930e2
clean up and add aggregate OD functions
ntlind Jul 16, 2024
479dab0
remove deletion
ntlind Jul 16, 2024
66b47b2
adjust benchmarks
ntlind Jul 16, 2024
6dc83bc
update od benchmarks
ntlind Jul 17, 2024
e5e27a7
save progress on detailed pr curves; doesn't pass tests
ntlind Jul 19, 2024
ee31b93
finish detailed PR curves for OD
ntlind Jul 24, 2024
172a2bc
refactor classification
ntlind Jul 24, 2024
3189d7d
refactor classification and pass all tests
ntlind Jul 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,23 @@ repos:
args: [--line-length=79]

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.350
rev: v1.1.370
hooks:
- id: pyright
additional_dependencies:
[
"requests",
"Pillow >= 9.1.0",
"numpy",
"pandas >= 2.2.2",
"pytest",
"python-dotenv",
"SQLAlchemy>=2.0",
"fastapi[all]>=0.100.0",
"importlib_metadata; python_version < '3.8'",
"pydantic-settings",
"tqdm",
"pandas",
"pandas >= 2.2.2",
"packaging",
"PyJWT[crypto]",
"structlog",
Expand Down
1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
"SQLAlchemy>=2.0",
"Pillow >= 9.1.0",
"numpy",
"pandas>=2.2.2",
"python-dotenv",
"pydantic-settings",
"structlog",
Expand Down
8 changes: 5 additions & 3 deletions api/tests/functional-tests/crud/test_create_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,9 +1121,11 @@ def method_to_test(
dataset_names=["test_dataset"],
model_names=["test_model"],
filters=schemas.Filter(
annotations=schemas.LogicalFunction.and_(*conditions)
if conditions
else None,
annotations=(
schemas.LogicalFunction.and_(*conditions)
if conditions
else None
),
labels=schemas.Condition(
lhs=schemas.Symbol(name=schemas.SupportedSymbol.LABEL_KEY),
rhs=schemas.Value.infer(label_key),
Expand Down
3 changes: 3 additions & 0 deletions api/tests/unit-tests/backend/metrics/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def test__calculate_101_pt_interp():
# make sure we get back 0 if we don't pass any precisions
assert _calculate_101_pt_interp([], []) == 0

# get back -1 if all recalls and precisions are -1
assert _calculate_101_pt_interp([-1, -1], [-1, -1]) == -1


def test__compute_mean_detection_metrics_from_aps():
# make sure we get back 0 if we don't pass any precisions
Expand Down
Loading
Loading