Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: request ids on API related endpoints #12663

Merged
merged 2 commits into from
Jan 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@
}
},
"related": {
"get": {"description": "Get a list of all possible owners for a chart."}
"get": {
"description": "Get a list of all possible owners for a chart. "
"Use `owners` has the `column_name` parameter"
}
},
}

Expand Down
64 changes: 49 additions & 15 deletions superset/views/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"properties": {
"page_size": {"type": "integer"},
"page": {"type": "integer"},
"include_ids": {"type": "array", "items": {"type": "integer"}},
"filter": {"type": "string"},
},
}
Expand Down Expand Up @@ -213,7 +214,10 @@ def __init__(self) -> None:
super().__init__()

def add_apispec_components(self, api_spec: APISpec) -> None:

"""
Adds extra OpenApi schema spec components, these are declared
on the `openapi_spec_component_schemas` class property
"""
for schema in self.openapi_spec_component_schemas:
try:
api_spec.components.schema(
Expand Down Expand Up @@ -271,6 +275,40 @@ def _get_distinct_filter(self, column_name: str, value: str) -> Filters:
)
return filters

def _get_text_for_model(self, model: Model, column_name: str) -> str:
if column_name in self.text_field_rel_fields:
model_column_name = self.text_field_rel_fields.get(column_name)
if model_column_name:
return getattr(model, model_column_name)
return str(model)

def _get_result_from_rows(
self, datamodel: SQLAInterface, rows: List[Model], column_name: str
) -> List[Dict[str, Any]]:
return [
{
"value": datamodel.get_pk_value(row),
"text": self._get_text_for_model(row, column_name),
}
for row in rows
]

def _add_extra_ids_to_result(
self,
datamodel: SQLAInterface,
column_name: str,
ids: List[int],
result: List[Dict[str, Any]],
) -> None:
if ids:
# Filter out already present values on the result
values = [row["value"] for row in result]
ids = [id_ for id_ in ids if id_ not in values]
pk_col = datamodel.get_pk()
# Fetch requested values from ids
extra_rows = db.session.query(datamodel.obj).filter(pk_col.in_(ids)).all()
result += self._get_result_from_rows(datamodel, extra_rows, column_name)

def incr_stats(self, action: str, func_name: str) -> None:
"""
Proxy function for statsd.incr to impose a key structure for REST API's
Expand Down Expand Up @@ -424,18 +462,11 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:
500:
$ref: '#/components/responses/500'
"""

def get_text_for_model(model: Model) -> str:
if column_name in self.text_field_rel_fields:
model_column_name = self.text_field_rel_fields.get(column_name)
if model_column_name:
return getattr(model, model_column_name)
return str(model)

if column_name not in self.allowed_rel_fields:
self.incr_stats("error", self.related.__name__)
return self.response_404()
args = kwargs.get("rison", {})

# handle pagination
page, page_size = self._handle_page_args(args)
try:
Expand All @@ -452,15 +483,18 @@ def get_text_for_model(model: Model) -> str:
# handle filters
filters = self._get_related_filter(datamodel, column_name, args.get("filter"))
# Make the query
count, values = datamodel.query(
_, rows = datamodel.query(
filters, order_column, order_direction, page=page, page_size=page_size
)

# produce response
result = [
{"value": datamodel.get_pk_value(value), "text": get_text_for_model(value)}
for value in values
]
return self.response(200, count=count, result=result)
result = self._get_result_from_rows(datamodel, rows, column_name)

# If ids are specified make sure we fetch and include them on the response
ids = args.get("include_ids")
self._add_extra_ids_to_result(datamodel, column_name, ids, result)

return self.response(200, count=len(result), result=result)

@expose("/distinct/<column_name>", methods=["GET"])
@protect()
Expand Down
58 changes: 48 additions & 10 deletions tests/base_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,48 +184,86 @@ class ApiOwnersTestCaseMixin:

def test_get_related_owners(self):
"""
API: Test get related owners
API: Test get related owners
"""
self.login(username="admin")
uri = f"api/v1/{self.resource_name}/related/owners"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
users = db.session.query(security_manager.user_model).all()
expected_users = [str(user) for user in users]
self.assertEqual(response["count"], len(users))
assert response["count"] == len(users)
# This needs to be implemented like this, because ordering varies between
# postgres and mysql
response_users = [result["text"] for result in response["result"]]
for expected_user in expected_users:
self.assertIn(expected_user, response_users)
assert expected_user in response_users

def test_get_filter_related_owners(self):
"""
API: Test get filter related owners
API: Test get filter related owners
"""
self.login(username="admin")
argument = {"filter": "gamma"}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"

rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(3, response["count"])
assert 3 == response["count"]
sorted_results = sorted(response["result"], key=lambda value: value["text"])
expected_results = [
{"text": "gamma user", "value": 2},
{"text": "gamma2 user", "value": 3},
{"text": "gamma_sqllab user", "value": 4},
]
self.assertEqual(expected_results, sorted_results)
assert expected_results == sorted_results

def test_get_ids_related_owners(self):
"""
API: Test get filter related owners
"""
self.login(username="admin")
argument = {"filter": "gamma_sqllab", "include_ids": [2]}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"

rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert 2 == response["count"]
sorted_results = sorted(response["result"], key=lambda value: value["text"])
expected_results = [
{"text": "gamma user", "value": 2},
{"text": "gamma_sqllab user", "value": 4},
]
assert expected_results == sorted_results

def test_get_repeated_ids_related_owners(self):
"""
API: Test get filter related owners
"""
self.login(username="admin")
argument = {"filter": "gamma_sqllab", "include_ids": [2, 4]}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"

rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert 2 == response["count"]
sorted_results = sorted(response["result"], key=lambda value: value["text"])
expected_results = [
{"text": "gamma user", "value": 2},
{"text": "gamma_sqllab user", "value": 4},
]
assert expected_results == sorted_results

def test_get_related_fail(self):
"""
API: Test get related fail
API: Test get related fail
"""
self.login(username="admin")
uri = f"api/v1/{self.resource_name}/related/owner"

rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404