Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions datajunction-server/datajunction_server/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
MaterializationConfigOutput,
)
from datajunction_server.models.query import ColumnMetadata, QueryWithResults
from datajunction_server.naming import LOOKUP_CHARS
from datajunction_server.naming import LOOKUP_CHARS, from_amenable_name
from datajunction_server.service_clients import QueryServiceClient
from datajunction_server.sql.parsing import ast
from datajunction_server.typing import END_JOB_STATES
Expand Down Expand Up @@ -743,7 +743,7 @@ async def build_sql_for_multiple_metrics(
query_parameters=query_parameters,
)
columns = [
assemble_column_metadata(col) # type: ignore
assemble_column_metadata(col, use_semantic_metadata=True) # type: ignore
for col in query_ast.select.projection
]
upstream_tables = [tbl for tbl in query_ast.find_all(ast.Table) if tbl.dj_node]
Expand Down Expand Up @@ -892,24 +892,29 @@ async def build_sql_for_dj_query( # pragma: no cover

def assemble_column_metadata(
column: ast.Column,
# node_name: Union[List[str], str],
use_semantic_metadata: bool = False,
) -> ColumnMetadata:
"""
Extract column metadata from AST
"""
has_semantic_entity = hasattr(column, "semantic_entity") and column.semantic_entity

if use_semantic_metadata and has_semantic_entity:
column_name = column.semantic_entity.split(SEPARATOR)[-1] # type: ignore
node_name = SEPARATOR.join(column.semantic_entity.split(SEPARATOR)[:-1]) # type: ignore
else:
column_name = getattr(column.name, "name", None)
node_name = (
from_amenable_name(column.table.alias_or_name.name) # type: ignore
if hasattr(column, "table") and column.table
else None
)

metadata = ColumnMetadata(
name=column.alias_or_name.name,
type=str(column.type),
column=(
column.semantic_entity.split(SEPARATOR)[-1]
if hasattr(column, "semantic_entity") and column.semantic_entity
else None
),
node=(
SEPARATOR.join(column.semantic_entity.split(SEPARATOR)[:-1])
if hasattr(column, "semantic_entity") and column.semantic_entity
else None
),
column=column_name,
node=node_name,
semantic_entity=column.semantic_entity
if hasattr(column, "semantic_entity")
else None,
Expand Down
2 changes: 1 addition & 1 deletion datajunction-server/datajunction_server/api/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ async def build_and_save_node_sql(
ignore_errors=ignore_errors,
)
columns = [
assemble_column_metadata(col) # type: ignore
assemble_column_metadata(col, use_semantic_metadata=True) # type: ignore
for col in query_ast.select.projection
]
query = str(query_ast)
Expand Down
6 changes: 6 additions & 0 deletions datajunction-server/datajunction_server/construction/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def rename_columns(
projection = []
node_columns = {col.name for col in node.columns}
for expression in built_ast.select.projection:
if hasattr(expression, "semantic_entity") and expression.semantic_entity: # type: ignore
# If the expression already has a semantic entity, we assume it is already
# fully qualified and skip renaming.
projection.append(expression)
expression.set_alias(ast.Name(amenable_name(expression.semantic_entity))) # type: ignore
continue
if (
not isinstance(expression, ast.Alias)
and not isinstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ async def get_measures_query(
from_amenable_name(identifier).split(SEPARATOR)[-1]
in parent_columns[parent_node.name]
or identifier in parent_columns[parent_node.name]
or expr.semantic_entity in dimensions_without_roles
or from_amenable_name(identifier) in dimensions_without_roles
)
)
Expand Down Expand Up @@ -253,6 +254,7 @@ async def get_measures_query(
columns_metadata = [
assemble_column_metadata( # pragma: no cover
cast(ast.Column, col),
preaggregate,
)
for col in final_query.select.projection
]
Expand Down Expand Up @@ -863,6 +865,8 @@ def set_dimension_aliases(self):
new_alias = amenable_name(dim_name)
if node_col and new_alias not in self.final_ast.select.column_mapping:
node_col.set_alias(ast.Name(amenable_name(dim_name)))
node_col.set_semantic_entity(dim_name)
node_col.set_semantic_type(SemanticType.DIMENSION)

async def add_request_by_node_name(self, node_name):
"""Add a node request to the access control validator."""
Expand Down
4 changes: 2 additions & 2 deletions datajunction-server/tests/api/dimension_links_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,8 +996,8 @@ async def test_measures_sql_with_reference_dimension_links(
{
"name": "default_DOT_users_DOT_registration_country",
"type": "string",
"column": "registration_country",
"node": "default.users",
"column": "user_registration_country",
"node": "default.events",
"semantic_entity": "default.users.registration_country",
"semantic_type": "dimension",
},
Expand Down
2 changes: 1 addition & 1 deletion datajunction-server/tests/api/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ async def test_transform_sql_filter_dimension_pk_col(
"name": "default_DOT_hard_hat_DOT_hard_hat_id",
"node": "default.hard_hat",
"semantic_entity": "default.hard_hat.hard_hat_id",
"semantic_type": None,
"semantic_type": "dimension",
"type": "int",
},
{
Expand Down
4 changes: 2 additions & 2 deletions datajunction-server/tests/api/sql_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def fix_dimension_links(module__client_with_roads: AsyncClient):
{
"column": "dispatcher_id",
"name": "default_DOT_dispatcher_DOT_dispatcher_id",
"node": "default.dispatcher",
"node": "default.repair_orders_fact",
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

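The reason the expected `node` value changes here is that `dispatcher_id` actually originates from `default.repair_orders_fact.dispatcher_id`, so `node` and `column` now report where the column physically comes from. The `semantic_entity` reflects something separate: the dimension attribute the column represents (`default.dispatcher.dispatcher_id`), which is why it stays unchanged.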

"semantic_entity": "default.dispatcher.dispatcher_id",
"semantic_type": "dimension",
"type": "int",
Expand Down Expand Up @@ -182,7 +182,7 @@ async def fix_dimension_links(module__client_with_roads: AsyncClient):
{
"column": "dispatcher_id",
"name": "default_DOT_dispatcher_DOT_dispatcher_id",
"node": "default.dispatcher",
"node": "default.repair_orders_fact",
"semantic_entity": "default.dispatcher.dispatcher_id",
"semantic_type": "dimension",
"type": "int",
Expand Down