Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
czaloom committed May 23, 2024
1 parent 82dc08f commit 87b0130
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 89 deletions.
60 changes: 1 addition & 59 deletions api/valor_api/backend/query/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,6 @@
Polygon,
)

category_to_supported_operations = {
"nullable": {"isnull", "isnotnull"},
"equatable": {"eq", "ne"},
"quantifiable": {"gt", "ge", "lt", "le"},
"spatial": {"intersects", "inside", "outside"},
}


opstr_to_operator = {
"equal": operator.eq,
"notequal": operator.ne,
Expand All @@ -75,51 +67,6 @@
}


filterable_types_to_function_category = {
"bool": {"equatable"},
"string": {"equatable"},
"integer": {"equatable", "quantifiable"},
"float": {"equatable", "quantifiable"},
"datetime": {"equatable", "quantifiable"},
"date": {"equatable", "quantifiable"},
"time": {"equatable", "quantifiable"},
"duration": {"equatable", "quantifiable"},
"point": {"equatable", "spatial"},
"multipoint": {"spatial"},
"linestring": {"spatial"},
"multilinestring": {"spatial"},
"polygon": {"spatial"},
"box": {"spatial"},
"multipolygon": {"spatial"},
"raster": {"spatial"},
"tasktypeenum": {"equatable"},
"label": {"equatable"},
"embedding": {},
}


symbol_name_to_categories = {
"dataset.name": {"equatable"},
"dataset.metadata": {},
"model.name": {"equatable"},
"model.metadata": {},
"datum.uid": {"equatable"},
"datum.metadata": {},
"annotation.box": {"spatial", "nullable"},
"annotation.polygon": {"spatial", "nullable"},
"annotation.raster": {"spatial", "nullable"},
"annotation.embedding": {},
"annotation.metadata": {},
"annotation.labels": {"equatable"},
"label.key": {"equatable"},
"label.value": {"equatable"},
}


symbol_attribute_to_categories = {
"area": {"equatable", "quantifiable"},
}

symbol_name_to_row_id_value = {
"dataset.name": (Dataset.id, Dataset.name),
"dataset.metadata": (Dataset.id, Dataset.meta),
Expand All @@ -137,6 +84,7 @@
"label.value": (Label.id, Label.value),
}


symbol_supports_attribute = {
"area": {
"annotation.box": lambda x: ST_Area(x),
Expand All @@ -149,12 +97,6 @@
}
}

symbol_supports_key = {
"dataset.metadata",
"model.metadata",
"datum.metadata",
"annotation.metadata",
}

metadata_symbol_type_casting = {
"bool": lambda x: x.astext.cast(Boolean),
Expand Down
74 changes: 44 additions & 30 deletions api/valor_api/backend/query/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
Boolean,
Float,
Integer,
alias,
and_,
case,
cast,
func,
not_,
or_,
select,
case,
alias,
)
from sqlalchemy.sql.elements import BinaryExpression, ColumnElement

Expand All @@ -23,9 +23,11 @@
Model,
Prediction,
)
from valor_api.backend.query.filtering import _recursive_search_logic_tree, generate_logic
from valor_api.backend.query.filtering import (
_recursive_search_logic_tree,
generate_logic,
)
from valor_api.backend.query.mapping import map_arguments_to_tables
from valor_api.backend.query.solvers import solve_graph
from valor_api.backend.query.types import TableTypeAlias
from valor_api.schemas.filters import FilterType

Expand Down Expand Up @@ -60,57 +62,69 @@ def select_from(self, *args):
self._selected = map_arguments_to_tables(args)
return self

def filter(self, conditions: FilterType, pivot = Annotation):
def filter(
self,
conditions: FilterType | None,
pivot=Annotation,
is_groundtruth: bool = True,
):
gt_or_pd = GroundTruth if is_groundtruth else Prediction
tree, ctes = _recursive_search_logic_tree(conditions)

if not ctes or not tree:
raise ValueError

agg = (
select(
pivot.id.label("pivot_id"),
*[
case(
(row_id == cte.c.id, 1),
else_=0
).label(f"cte{idx}")
case((row_id == cte.c.id, 1), else_=0).label(f"cte{idx}")
for idx, (row_id, cte) in enumerate(ctes)
],
)
.select_from(Annotation)
.join(Datum, Datum.id == Annotation.datum_id)
.join(Dataset, Dataset.id == Datum.dataset_id)
)
agg = self.filter(conditions, pivot)

for row_id, cte in ctes:
if row_id == Label.id:
gt = alias(GroundTruth)
agg = agg.join(gt, gt.c.annotation_id == Annotation.id)
agg = agg.join(cte, cte.c.id == gt.c.label_id, isouter=True)
label_linker = alias(gt_or_pd)
agg = agg.join(
label_linker, label_linker.c.annotation_id == Annotation.id
)
agg = agg.join(
cte, cte.c.id == label_linker.c.label_id, isouter=True
)
else:
agg = agg.join(cte, cte.c.id == row_id, isouter=True)
agg = agg.cte()

q = (
select(*self._args)
.select_from(pivot)
)
q = select(*self._args).select_from(pivot)
if pivot is Annotation:
q = q.join(Datum, Datum.id == Annotation.datum_id)
q = q.join(Dataset, Dataset.id == Datum.dataset_id)
q = q.join(GroundTruth, GroundTruth.annotation_id == Annotation.id)
q = q.join(Label, Label.id == GroundTruth.label_id)
q = q.join(gt_or_pd, gt_or_pd.annotation_id == Annotation.id)
q = q.join(Label, Label.id == gt_or_pd.label_id)

elif pivot is Datum:
q = q.join(Annotation, Annotation.datum_id == Datum.id)
q = q.join(Dataset, Dataset.id == Datum.dataset_id)
q = q.join(GroundTruth, GroundTruth.annotation_id == Annotation.id)
q = q.join(Label, Label.id == GroundTruth.label_id)
q = q.join(gt_or_pd, gt_or_pd.annotation_id == Annotation.id)
q = q.join(Label, Label.id == gt_or_pd.label_id)

q = q.join(agg, agg.c.pivot_id == pivot.id)
return q.where(generate_logic(agg, tree))

def filter_annotations(self, conditions: FilterType):
return self.filter(conditions, pivot=Annotation)

def filter_datums(self, conditions: FilterType):
return self.filter(conditions, pivot=Datum)

def filter_groundtruths(
self, conditions: FilterType | None, pivot=Annotation
):
return self.filter(
conditions=conditions, pivot=pivot, is_groundtruth=False
)

def filter_predictions(
self, conditions: FilterType | None, pivot=Annotation
):
return self.filter(
conditions=conditions, pivot=pivot, is_groundtruth=False
)

0 comments on commit 87b0130

Please sign in to comment.