diff --git a/api/valor_api/backend/query/filtering.py b/api/valor_api/backend/query/filtering.py index 42b618cdb..419f682bf 100644 --- a/api/valor_api/backend/query/filtering.py +++ b/api/valor_api/backend/query/filtering.py @@ -51,14 +51,6 @@ Polygon, ) -category_to_supported_operations = { - "nullable": {"isnull", "isnotnull"}, - "equatable": {"eq", "ne"}, - "quantifiable": {"gt", "ge", "lt", "le"}, - "spatial": {"intersects", "inside", "outside"}, -} - - opstr_to_operator = { "equal": operator.eq, "notequal": operator.ne, @@ -75,51 +67,6 @@ } -filterable_types_to_function_category = { - "bool": {"equatable"}, - "string": {"equatable"}, - "integer": {"equatable", "quantifiable"}, - "float": {"equatable", "quantifiable"}, - "datetime": {"equatable", "quantifiable"}, - "date": {"equatable", "quantifiable"}, - "time": {"equatable", "quantifiable"}, - "duration": {"equatable", "quantifiable"}, - "point": {"equatable", "spatial"}, - "multipoint": {"spatial"}, - "linestring": {"spatial"}, - "multilinestring": {"spatial"}, - "polygon": {"spatial"}, - "box": {"spatial"}, - "multipolygon": {"spatial"}, - "raster": {"spatial"}, - "tasktypeenum": {"equatable"}, - "label": {"equatable"}, - "embedding": {}, -} - - -symbol_name_to_categories = { - "dataset.name": {"equatable"}, - "dataset.metadata": {}, - "model.name": {"equatable"}, - "model.metadata": {}, - "datum.uid": {"equatable"}, - "datum.metadata": {}, - "annotation.box": {"spatial", "nullable"}, - "annotation.polygon": {"spatial", "nullable"}, - "annotation.raster": {"spatial", "nullable"}, - "annotation.embedding": {}, - "annotation.metadata": {}, - "annotation.labels": {"equatable"}, - "label.key": {"equatable"}, - "label.value": {"equatable"}, -} - - -symbol_attribute_to_categories = { - "area": {"equatable", "quantifiable"}, -} - symbol_name_to_row_id_value = { "dataset.name": (Dataset.id, Dataset.name), "dataset.metadata": (Dataset.id, Dataset.meta), @@ -137,6 +84,7 @@ "label.value": (Label.id, Label.value), } + symbol_supports_attribute = { "area": { "annotation.box": lambda x: ST_Area(x), @@ -149,12 +97,6 @@ } } -symbol_supports_key = { - "dataset.metadata", - "model.metadata", - "datum.metadata", - "annotation.metadata", -} metadata_symbol_type_casting = { "bool": lambda x: x.astext.cast(Boolean), diff --git a/api/valor_api/backend/query/ops.py b/api/valor_api/backend/query/ops.py index 6a6f37f68..75077ed55 100644 --- a/api/valor_api/backend/query/ops.py +++ b/api/valor_api/backend/query/ops.py @@ -3,14 +3,14 @@ Boolean, Float, Integer, + alias, and_, + case, cast, func, not_, or_, select, - case, - alias, ) from sqlalchemy.sql.elements import BinaryExpression, ColumnElement @@ -23,9 +23,11 @@ Model, Prediction, ) -from valor_api.backend.query.filtering import _recursive_search_logic_tree, generate_logic +from valor_api.backend.query.filtering import ( + _recursive_search_logic_tree, + generate_logic, +) from valor_api.backend.query.mapping import map_arguments_to_tables -from valor_api.backend.query.solvers import solve_graph from valor_api.backend.query.types import TableTypeAlias from valor_api.schemas.filters import FilterType @@ -60,20 +62,21 @@ def select_from(self, *args): self._selected = map_arguments_to_tables(args) return self - def filter(self, conditions: FilterType, pivot = Annotation): + def filter( + self, + conditions: FilterType | None, + pivot=Annotation, + is_groundtruth: bool = True, + ): + gt_or_pd = GroundTruth if is_groundtruth else Prediction tree, ctes = _recursive_search_logic_tree(conditions) - if not ctes or not tree: raise ValueError - agg = ( select( pivot.id.label("pivot_id"), *[ - case( - (row_id == cte.c.id, 1), - else_=0 - ).label(f"cte{idx}") + case((row_id == cte.c.id, 1), else_=0).label(f"cte{idx}") for idx, (row_id, cte) in enumerate(ctes) ], ) @@ -81,36 +84,47 @@ def filter(self, conditions: FilterType, pivot = Annotation): .join(Datum, Datum.id == Annotation.datum_id) .join(Dataset, Dataset.id == Datum.dataset_id) ) + agg = self.filter(conditions, pivot) + for row_id, cte in ctes: if row_id == Label.id: - gt = alias(GroundTruth) - agg = agg.join(gt, gt.c.annotation_id == Annotation.id) - agg = agg.join(cte, cte.c.id == gt.c.label_id, isouter=True) + label_linker = alias(gt_or_pd) + agg = agg.join( + label_linker, label_linker.c.annotation_id == Annotation.id + ) + agg = agg.join( + cte, cte.c.id == label_linker.c.label_id, isouter=True + ) else: agg = agg.join(cte, cte.c.id == row_id, isouter=True) agg = agg.cte() - q = ( - select(*self._args) - .select_from(pivot) - ) + q = select(*self._args).select_from(pivot) if pivot is Annotation: q = q.join(Datum, Datum.id == Annotation.datum_id) q = q.join(Dataset, Dataset.id == Datum.dataset_id) - q = q.join(GroundTruth, GroundTruth.annotation_id == Annotation.id) - q = q.join(Label, Label.id == GroundTruth.label_id) - + q = q.join(gt_or_pd, gt_or_pd.annotation_id == Annotation.id) + q = q.join(Label, Label.id == gt_or_pd.label_id) + elif pivot is Datum: q = q.join(Annotation, Annotation.datum_id == Datum.id) q = q.join(Dataset, Dataset.id == Datum.dataset_id) - q = q.join(GroundTruth, GroundTruth.annotation_id == Annotation.id) - q = q.join(Label, Label.id == GroundTruth.label_id) - + q = q.join(gt_or_pd, gt_or_pd.annotation_id == Annotation.id) + q = q.join(Label, Label.id == gt_or_pd.label_id) + q = q.join(agg, agg.c.pivot_id == pivot.id) return q.where(generate_logic(agg, tree)) - - def filter_annotations(self, conditions: FilterType): - return self.filter(conditions, pivot=Annotation) - - def filter_datums(self, conditions: FilterType): - return self.filter(conditions, pivot=Datum) + + def filter_groundtruths( + self, conditions: FilterType | None, pivot=Annotation + ): + return self.filter( + conditions=conditions, pivot=pivot, is_groundtruth=False + ) + + def filter_predictions( + self, conditions: FilterType | None, pivot=Annotation + ): + return self.filter( + conditions=conditions, pivot=pivot, is_groundtruth=False + )