Skip to content

Commit

Permalink
feat: request ids on API related endpoints (#12663)
Browse files Browse the repository at this point in the history
* feat: request ids on API related endpoints

* rename ids to include_ids
  • Loading branch information
dpgaspar authored Jan 27, 2021
1 parent 11ca730 commit 365770e
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 26 deletions.
5 changes: 4 additions & 1 deletion superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@
}
},
"related": {
"get": {"description": "Get a list of all possible owners for a chart."}
"get": {
"description": "Get a list of all possible owners for a chart. "
"Use `owners` has the `column_name` parameter"
}
},
}

Expand Down
64 changes: 49 additions & 15 deletions superset/views/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"properties": {
"page_size": {"type": "integer"},
"page": {"type": "integer"},
"include_ids": {"type": "array", "items": {"type": "integer"}},
"filter": {"type": "string"},
},
}
Expand Down Expand Up @@ -213,7 +214,10 @@ def __init__(self) -> None:
super().__init__()

def add_apispec_components(self, api_spec: APISpec) -> None:

"""
Adds extra OpenApi schema spec components, these are declared
on the `openapi_spec_component_schemas` class property
"""
for schema in self.openapi_spec_component_schemas:
try:
api_spec.components.schema(
Expand Down Expand Up @@ -271,6 +275,40 @@ def _get_distinct_filter(self, column_name: str, value: str) -> Filters:
)
return filters

def _get_text_for_model(self, model: Model, column_name: str) -> str:
if column_name in self.text_field_rel_fields:
model_column_name = self.text_field_rel_fields.get(column_name)
if model_column_name:
return getattr(model, model_column_name)
return str(model)

def _get_result_from_rows(
self, datamodel: SQLAInterface, rows: List[Model], column_name: str
) -> List[Dict[str, Any]]:
return [
{
"value": datamodel.get_pk_value(row),
"text": self._get_text_for_model(row, column_name),
}
for row in rows
]

def _add_extra_ids_to_result(
self,
datamodel: SQLAInterface,
column_name: str,
ids: List[int],
result: List[Dict[str, Any]],
) -> None:
if ids:
# Filter out already present values on the result
values = [row["value"] for row in result]
ids = [id_ for id_ in ids if id_ not in values]
pk_col = datamodel.get_pk()
# Fetch requested values from ids
extra_rows = db.session.query(datamodel.obj).filter(pk_col.in_(ids)).all()
result += self._get_result_from_rows(datamodel, extra_rows, column_name)

def incr_stats(self, action: str, func_name: str) -> None:
"""
Proxy function for statsd.incr to impose a key structure for REST API's
Expand Down Expand Up @@ -424,18 +462,11 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:
500:
$ref: '#/components/responses/500'
"""

def get_text_for_model(model: Model) -> str:
if column_name in self.text_field_rel_fields:
model_column_name = self.text_field_rel_fields.get(column_name)
if model_column_name:
return getattr(model, model_column_name)
return str(model)

if column_name not in self.allowed_rel_fields:
self.incr_stats("error", self.related.__name__)
return self.response_404()
args = kwargs.get("rison", {})

# handle pagination
page, page_size = self._handle_page_args(args)
try:
Expand All @@ -452,15 +483,18 @@ def get_text_for_model(model: Model) -> str:
# handle filters
filters = self._get_related_filter(datamodel, column_name, args.get("filter"))
# Make the query
count, values = datamodel.query(
_, rows = datamodel.query(
filters, order_column, order_direction, page=page, page_size=page_size
)

# produce response
result = [
{"value": datamodel.get_pk_value(value), "text": get_text_for_model(value)}
for value in values
]
return self.response(200, count=count, result=result)
result = self._get_result_from_rows(datamodel, rows, column_name)

# If ids are specified make sure we fetch and include them on the response
ids = args.get("include_ids")
self._add_extra_ids_to_result(datamodel, column_name, ids, result)

return self.response(200, count=len(result), result=result)

@expose("/distinct/<column_name>", methods=["GET"])
@protect()
Expand Down
58 changes: 48 additions & 10 deletions tests/base_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,48 +184,86 @@ class ApiOwnersTestCaseMixin:

def test_get_related_owners(self):
"""
API: Test get related owners
API: Test get related owners
"""
self.login(username="admin")
uri = f"api/v1/{self.resource_name}/related/owners"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
users = db.session.query(security_manager.user_model).all()
expected_users = [str(user) for user in users]
self.assertEqual(response["count"], len(users))
assert response["count"] == len(users)
# This needs to be implemented like this, because ordering varies between
# postgres and mysql
response_users = [result["text"] for result in response["result"]]
for expected_user in expected_users:
self.assertIn(expected_user, response_users)
assert expected_user in response_users

def test_get_filter_related_owners(self):
"""
API: Test get filter related owners
API: Test get filter related owners
"""
self.login(username="admin")
argument = {"filter": "gamma"}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"

rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(3, response["count"])
assert 3 == response["count"]
sorted_results = sorted(response["result"], key=lambda value: value["text"])
expected_results = [
{"text": "gamma user", "value": 2},
{"text": "gamma2 user", "value": 3},
{"text": "gamma_sqllab user", "value": 4},
]
self.assertEqual(expected_results, sorted_results)
assert expected_results == sorted_results

def test_get_ids_related_owners(self):
"""
API: Test get filter related owners
"""
self.login(username="admin")
argument = {"filter": "gamma_sqllab", "include_ids": [2]}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"

rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert 2 == response["count"]
sorted_results = sorted(response["result"], key=lambda value: value["text"])
expected_results = [
{"text": "gamma user", "value": 2},
{"text": "gamma_sqllab user", "value": 4},
]
assert expected_results == sorted_results

def test_get_repeated_ids_related_owners(self):
"""
API: Test get filter related owners
"""
self.login(username="admin")
argument = {"filter": "gamma_sqllab", "include_ids": [2, 4]}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"

rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert 2 == response["count"]
sorted_results = sorted(response["result"], key=lambda value: value["text"])
expected_results = [
{"text": "gamma user", "value": 2},
{"text": "gamma_sqllab user", "value": 4},
]
assert expected_results == sorted_results

def test_get_related_fail(self):
"""
API: Test get related fail
API: Test get related fail
"""
self.login(username="admin")
uri = f"api/v1/{self.resource_name}/related/owner"

rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404

0 comments on commit 365770e

Please sign in to comment.