Skip to content

Commit

Permalink
Rename allowed_filter_attrs to allowed_sort_attrs (#38626)
Browse files Browse the repository at this point in the history
These are the attrs we allow sorting on, not filtering on.
  • Loading branch information
jedcunningham committed Mar 30, 2024
1 parent 0723a8f commit e700f41
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ def get_connections(
) -> APIResponse:
"""Get all connection entries."""
to_replace = {"connection_id": "conn_id"}
allowed_filter_attrs = ["connection_id", "conn_type", "description", "host", "port", "id"]
allowed_sort_attrs = ["connection_id", "conn_type", "description", "host", "port", "id"]

total_entries = session.execute(select(func.count(Connection.id))).scalar_one()
query = select(Connection)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)
connections = session.scalars(query.offset(offset).limit(limit)).all()
return connection_collection_schema.dump(
ConnectionCollection(connections=connections, total_entries=total_entries)
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _fetch_dag_runs(

total_entries = get_query_count(query, session=session)
to_replace = {"dag_run_id": "run_id"}
allowed_filter_attrs = [
allowed_sort_attrs = [
"id",
"state",
"dag_id",
Expand All @@ -184,7 +184,7 @@ def _fetch_dag_runs(
"external_trigger",
"conf",
]
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)
return session.scalars(query.offset(offset).limit(limit)).all(), total_entries


Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/dag_warning_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_dag_warnings(
:param dag_id: the dag_id to optionally filter by
:param warning_type: the warning type to optionally filter by
"""
allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"]
allowed_sort_attrs = ["dag_id", "warning_type", "message", "timestamp"]
query = select(DagWarningModel)
if dag_id:
query = query.where(DagWarningModel.dag_id == dag_id)
Expand All @@ -65,7 +65,7 @@ def get_dag_warnings(
if warning_type:
query = query.where(DagWarningModel.warning_type == warning_type)
total_entries = get_query_count(query, session=session)
query = apply_sorting(query=query, order_by=order_by, allowed_attrs=allowed_filter_attrs)
query = apply_sorting(query=query, order_by=order_by, allowed_attrs=allowed_sort_attrs)
dag_warnings = session.scalars(query.offset(offset).limit(limit)).all()
return dag_warning_collection_schema.dump(
DagWarningCollection(dag_warnings=dag_warnings, total_entries=total_entries)
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def get_import_errors(
) -> APIResponse:
"""Get all import errors."""
to_replace = {"import_error_id": "id"}
allowed_filter_attrs = ["import_error_id", "timestamp", "filename"]
allowed_sort_attrs = ["import_error_id", "timestamp", "filename"]
count_query = select(func.count(ImportErrorModel.id))
query = select(ImportErrorModel)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")

Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/pool_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ def get_pools(
) -> APIResponse:
"""Get all pools."""
to_replace = {"name": "pool"}
allowed_filter_attrs = ["name", "slots", "id"]
allowed_sort_attrs = ["name", "slots", "id"]
total_entries = session.scalars(func.count(Pool.id)).one()
query = select(Pool)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)
pools = session.scalars(query.offset(offset).limit(limit)).all()
return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=total_entries))

Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/variable_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def get_variables(
"""Get all variable values."""
total_entries = session.execute(select(func.count(Variable.id))).scalar()
to_replace = {"value": "val"}
allowed_filter_attrs = ["value", "key", "id"]
allowed_sort_attrs = ["value", "key", "id"]
query = select(Variable)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)
variables = session.scalars(query.offset(offset).limit(limit)).all()
return variable_collection_schema.dump(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def get_roles(*, order_by: str = "name", limit: int, offset: int | None = None)
to_replace = {"role_id": "id"}
order_param = order_by.strip("-")
order_param = to_replace.get(order_param, order_param)
allowed_filter_attrs = ["role_id", "name"]
if order_by not in allowed_filter_attrs:
allowed_sort_attrs = ["role_id", "name"]
if order_by not in allowed_sort_attrs:
raise BadRequest(
detail=f"Ordering with '{order_by}' is disallowed or "
f"the attribute does not exist on the model"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_users(*, limit: int, order_by: str = "id", offset: str | None = None) ->
to_replace = {"user_id": "id"}
order_param = order_by.strip("-")
order_param = to_replace.get(order_param, order_param)
allowed_filter_attrs = [
allowed_sort_attrs = [
"id",
"first_name",
"last_name",
Expand All @@ -74,7 +74,7 @@ def get_users(*, limit: int, order_by: str = "id", offset: str | None = None) ->
"is_active",
"role",
]
if order_by not in allowed_filter_attrs:
if order_by not in allowed_sort_attrs:
raise BadRequest(
detail=f"Ordering with '{order_by}' is disallowed or "
f"the attribute does not exist on the model"
Expand Down

0 comments on commit e700f41

Please sign in to comment.