Skip to content

Commit

Permalink
REST API: add separate user projection in query help (#2793)
Browse files Browse the repository at this point in the history
Automatically projection user email for nodes relied on hybrid
properties in SqlAlchemy, which have been removed. The alternative is to
explicitly join on the User entity and project separately.
  • Loading branch information
Snehal Kumbhar authored and sphuber committed May 3, 2019
1 parent 948fa0c commit dadc5d1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion aiida/restapi/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid-
results[query_type]["filename"])
return response

results = results[query_type]["data"]
results = results[query_type]

headers = self.utils.build_headers(url=request.url, total_count=total_count)

Expand Down
15 changes: 13 additions & 2 deletions aiida/restapi/translator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class BaseTranslator(object):
_result_type = __label__

_default = _default_projections = ["**"]
_default_user_projections = None

_schema_projections = {"column_order": [], "additional_info": {}}

Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(self, Class=None, **kwargs):

self._default = Class._default # pylint: disable=protected-access
self._default_projections = Class._default_projections # pylint: disable=protected-access
self._default_user_projections = Class._default_user_projections # pylint: disable=protected-access
self._schema_projections = Class._schema_projections # pylint: disable=protected-access
self._is_qb_initialized = Class._is_qb_initialized # pylint: disable=protected-access
self._is_id_query = Class._is_id_query # pylint: disable=protected-access
Expand All @@ -92,7 +94,7 @@ def __init__(self, Class=None, **kwargs):
# basic query_help object
self._query_help = {
"path": [{
"entity_type": self._qb_type,
"cls": self._aiida_class,
"tag": self.__label__
}],
"filters": {},
Expand Down Expand Up @@ -281,6 +283,11 @@ def set_projections(self, projections):
if projections:
for project_key, project_list in projections.items():
self._query_help["project"][project_key] = project_list
if self._default_user_projections:
from aiida.orm import User
self._query_help["path"].insert(0, {"cls": User, "tag": "user"})
self._query_help["path"][1]["with_user"] = "user"
self._query_help["project"]["user"] = self._default_user_projections
else:
raise InputValidationError("Pass data in dictionary format where "
"keys are the tag names given in the "
Expand Down Expand Up @@ -462,7 +469,11 @@ def get_formatted_result(self, label):

results = []
if self._total_count > 0:
results = [res[label] for res in self.qbobj.dict()]
for res in self.qbobj.dict():
tmp = res[label]
if self._default_user_projections:
tmp["user_email"] = res["user"]["email"]
results.append(tmp)

# TODO think how to make it less hardcoded
if self._result_type == 'with_outgoing':
Expand Down
12 changes: 6 additions & 6 deletions aiida/restapi/translator/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def __init__(self, Class=None, **kwargs):
super(NodeTranslator, self).__init__(Class=Class, **kwargs)

self._default_projections = [
"id", "label", "node_type", "ctime", "mtime", "uuid", "user_id", "user_email", "attributes", "extras"
"id", "label", "node_type", "ctime", "mtime", "uuid", "user_id", "attributes", "extras"
]
self._default_user_projections = ["email"]

## node schema
# All the values from column_order must present in additional info dict
Expand Down Expand Up @@ -258,8 +259,7 @@ def _get_content(self):
if not self._is_qb_initialized:
raise InvalidOperation("query builder object has not been initialized.")

## Count the total number of rows returned by the query (if not
# already done)
# Count the total number of rows returned by the query (if not already done)
if self._total_count is None:
self.count()

Expand All @@ -268,7 +268,7 @@ def _get_content(self):
return {}

# otherwise ...
node = self.qbobj.first()[0]
node = self.qbobj.first()[1]

# content/attributes
if self._content_type == "attributes":
Expand Down Expand Up @@ -715,12 +715,12 @@ def get_node_shape(ntype):
# count total no of nodes
builder = QueryBuilder()
builder.append(Node, tag="main", project=['id'], filters=self._id_filter)
builder.append(Node, tag="in", project=['id'], input_of='main')
builder.append(Node, tag="in", project=['id'], with_outgoing='main')
total_no_of_incomings = builder.count()

builder = QueryBuilder()
builder.append(Node, tag="main", project=['id'], filters=self._id_filter)
builder.append(Node, tag="out", project=['id'], output_of='main')
builder.append(Node, tag="out", project=['id'], with_incoming='main')
total_no_of_outgoings = builder.count()

return {
Expand Down

0 comments on commit dadc5d1

Please sign in to comment.