Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,17 @@ def get_temporal_partitions(preagg: PreAggregation) -> list[TemporalPartitionCol

temporal_partitions: list[TemporalPartitionColumn] = []
if preagg.node_revision: # pragma: no branch
# Build reverse mapping: source column name -> dimension attribute
# dimensions_to_columns_map returns {dim_attr (AST Column): source_col (AST Column)}
col_to_dim: dict[str, str] = {}
# Build reverse mapping: source column name -> list of dimension attributes.
# dimensions_to_columns_map returns {dim_attr (AST Column): source_col (AST Column)}.
# A single source column may appear in multiple dimension links (e.g., one
# simple link plus one multi-column link), so this must be multi-valued —
# otherwise only one mapping survives and we may miss the dim_attr that
# the user actually selected in the cube grain.
col_to_dims: dict[str, list[str]] = {}
dim_to_col = preagg.node_revision.dimensions_to_columns_map()
for dim_attr, source_col in dim_to_col.items():
# dim_attr is like "dimensions.date.dateint"
# source_col is an AST Column, get its name (e.g., "utc_date")
source_col_name = source_col.identifier().split(SEPARATOR)[-1]
col_to_dim[source_col_name] = dim_attr
col_to_dims.setdefault(source_col_name, []).append(dim_attr)

for temporal_col in preagg.node_revision.temporal_partition_columns():
source_name = temporal_col.name
Expand All @@ -191,27 +193,31 @@ def get_temporal_partitions(preagg: PreAggregation) -> list[TemporalPartitionCol
if full_source_col in preagg.grain_columns:
output_name = source_name # pragma: no cover

# Strategy 2: Check dimension links via dimensions_to_columns_map
# If temporal column maps to a dimension attribute, find that in grain
elif source_name in col_to_dim:
dim_attr = col_to_dim[source_name]
dim_node = dim_attr.rsplit(SEPARATOR, 1)[0]
# Check if this dimension attribute or its parent node is in grain_columns
for gc in preagg.grain_columns:
if gc == dim_attr or gc.startswith(dim_node + SEPARATOR):
# Parse the dimension ref to handle role syntax properly
# e.g., "v3.date.week[order]" -> column_name="week", role="order"
# -> output_name="week_order"
parsed = parse_dimension_ref(gc)
output_name = parsed.column_name
if parsed.role: # pragma: no branch
output_name = f"{output_name}_{parsed.role}"
logger.info(
"Temporal column %s links to dimension %s -> output %s",
source_name,
dim_attr,
output_name,
)
# Strategy 2: Check dimension links via dimensions_to_columns_map.
# Try every dim_attr that maps back to this source column — the first
# one whose attribute (or parent node) appears in grain_columns wins.
elif source_name in col_to_dims:
for dim_attr in col_to_dims[source_name]:
dim_node = dim_attr.rsplit(SEPARATOR, 1)[0]
matched = False
for gc in preagg.grain_columns:
if gc == dim_attr or gc.startswith(dim_node + SEPARATOR):
# Parse the dimension ref to handle role syntax properly
# e.g., "v3.date.week[order]" -> column_name="week", role="order"
# -> output_name="week_order"
parsed = parse_dimension_ref(gc)
output_name = parsed.column_name
if parsed.role: # pragma: no branch
output_name = f"{output_name}_{parsed.role}"
logger.info(
"Temporal column %s links to dimension %s -> output %s",
source_name,
dim_attr,
output_name,
)
matched = True
break
if matched:
break

# Strategy 3: Check column.dimension reference link
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@

import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, selectinload

from datajunction_server.construction.build_v3.preagg_matcher import (
find_matching_preagg,
get_preagg_measure_column,
get_required_measure_hashes,
get_temporal_partitions,
)
from datajunction_server.construction.build_v3.types import BuildContext, GrainGroup
from datajunction_server.database.availabilitystate import AvailabilityState
from datajunction_server.database.column import Column
from datajunction_server.database.dimensionlink import DimensionLink
from datajunction_server.database.node import Node, NodeRevision
from datajunction_server.database.partition import Partition
from datajunction_server.database.preaggregation import (
PreAggregation,
compute_expression_hash,
Expand All @@ -24,9 +29,11 @@
MetricComponent,
PreAggMeasure,
)
from datajunction_server.models.dimensionlink import JoinType
from datajunction_server.models.node_type import NodeType
from datajunction_server.models.partition import Granularity, PartitionType
from datajunction_server.models.user import OAuthProvider
from datajunction_server.sql.parsing.types import IntegerType
from datajunction_server.sql.parsing.types import IntegerType, StringType


def make_component(
Expand Down Expand Up @@ -805,3 +812,191 @@ async def test_handles_measure_without_expr_hash(
result = get_preagg_measure_column(preagg, component)

assert result is None


class TestGetTemporalPartitionsMultipleDimLinks:
    """Regression coverage for ``get_temporal_partitions`` when one source
    column participates in several dimension links (for example a simple
    link and a multi-column link that reuse the same FK column).
    """

    @pytest_asyncio.fixture
    async def multi_link_fact(
        self,
        session: AsyncSession,
        test_user: User,
    ) -> PreAggregation:
        """
        Persist a fact node whose `region_date` column carries two dimension
        links and return a pre-aggregation built on top of it.

        The two links are:
          - a simple link joining `region_date` to `test.dim_date.dateint`
          - a multi-column link to `test.dim_tcd` on
            (title_id, country_code, date_id), where `date_id` is also fed
            by `region_date`.

        The simple link is created before the multi-column one; with the old
        single-valued col_to_dim mapping this ordering let the multi-column
        link win, which mirrors the failure observed in the wild.

        The returned PreAggregation has empty `grain_columns`; each test is
        expected to assign `preagg.grain_columns = [...]` itself.
        """
        owner_id = test_user.id

        # --- Dimension nodes -------------------------------------------------
        date_dim = Node(
            name="test.dim_date",
            type=NodeType.DIMENSION,
            created_by_id=owner_id,
        )
        tcd_dim = Node(
            name="test.dim_tcd",
            type=NodeType.DIMENSION,
            created_by_id=owner_id,
        )
        session.add_all([date_dim, tcd_dim])
        await session.flush()

        # --- Dimension revisions ---------------------------------------------
        date_dim_columns = [Column(name="dateint", type=IntegerType(), order=0)]
        tcd_dim_columns = [
            Column(name="title_id", type=IntegerType(), order=0),
            Column(name="country_code", type=StringType(), order=1),
            Column(name="date_id", type=IntegerType(), order=2),
        ]
        date_dim_revision = NodeRevision(
            name=date_dim.name,
            node_id=date_dim.id,
            type=NodeType.DIMENSION,
            version="1",
            columns=date_dim_columns,
            created_by_id=owner_id,
        )
        tcd_dim_revision = NodeRevision(
            name=tcd_dim.name,
            node_id=tcd_dim.id,
            type=NodeType.DIMENSION,
            version="1",
            columns=tcd_dim_columns,
            created_by_id=owner_id,
        )
        session.add_all([date_dim_revision, tcd_dim_revision])
        await session.flush()
        date_dim.current_version = "1"
        date_dim.current = date_dim_revision
        tcd_dim.current_version = "1"
        tcd_dim.current = tcd_dim_revision

        # --- Fact node and its revision --------------------------------------
        fact = Node(
            name="test.fact",
            type=NodeType.SOURCE,
            created_by_id=owner_id,
        )
        session.add(fact)
        await session.flush()

        # Keep a handle on this column: it is both the temporal partition
        # column and the FK shared by the two dimension links.
        shared_fk_column = Column(name="region_date", type=IntegerType(), order=0)
        fact_columns = [
            shared_fk_column,
            Column(name="title_id", type=IntegerType(), order=1),
            Column(name="country_code", type=StringType(), order=2),
            Column(name="amt", type=IntegerType(), order=3),
        ]
        fact_revision = NodeRevision(
            name=fact.name,
            node_id=fact.id,
            type=NodeType.SOURCE,
            version="1",
            columns=fact_columns,
            created_by_id=owner_id,
        )
        session.add(fact_revision)
        await session.flush()
        fact.current_version = "1"
        fact.current = fact_revision

        # --- Daily temporal partition on region_date -------------------------
        temporal_partition = Partition(
            column_id=shared_fk_column.id,
            type_=PartitionType.TEMPORAL,
            granularity=Granularity.DAY,
            format="yyyyMMdd",
        )
        session.add(temporal_partition)
        await session.flush()
        shared_fk_column.partition = temporal_partition

        # --- Dimension links --------------------------------------------------
        # Created first: the single-column link fact.region_date -> dim_date.dateint.
        link_to_date = DimensionLink(
            node_revision=fact_revision,
            dimension=date_dim,
            join_sql="test.fact.region_date = test.dim_date.dateint",
            join_type=JoinType.LEFT,
        )
        # Created second: the composite link to dim_tcd, whose date_id side
        # is fed by the very same region_date column.
        link_to_tcd = DimensionLink(
            node_revision=fact_revision,
            dimension=tcd_dim,
            join_sql=(
                "test.fact.title_id = test.dim_tcd.title_id "
                "AND test.fact.country_code = test.dim_tcd.country_code "
                "AND test.fact.region_date = test.dim_tcd.date_id"
            ),
            join_type=JoinType.LEFT,
        )
        session.add_all([link_to_date, link_to_tcd])
        await session.flush()

        # --- Availability + pre-aggregation ----------------------------------
        availability = AvailabilityState(
            catalog="test",
            schema_="test",
            table="fact_preagg",
            valid_through_ts=9999999999,
        )
        session.add(availability)
        await session.flush()

        preagg = PreAggregation(
            node_revision_id=fact_revision.id,
            grain_columns=[],  # filled in by each test
            measures=[make_preagg_measure("sum_amt", "amt")],
            sql="SELECT ...",
            grain_group_hash="hash_tp",
            preagg_hash="tp01",
            availability_id=availability.id,
        )
        session.add(preagg)
        await session.flush()

        # get_temporal_partitions is synchronous and reads
        # node_revision.columns, node_revision.dimension_links and every
        # column's partition. Load all of them eagerly up front so nothing is
        # lazy-loaded later (which would raise MissingGreenlet under an async
        # session).
        revision_loader = joinedload(PreAggregation.node_revision).options(
            selectinload(NodeRevision.columns).joinedload(Column.partition),
            selectinload(NodeRevision.dimension_links).joinedload(
                DimensionLink.dimension,
            ),
        )
        statement = (
            select(PreAggregation)
            .where(PreAggregation.id == preagg.id)
            .options(revision_loader)
        )
        rows = await session.execute(statement)
        return rows.unique().scalar_one()

    @pytest.mark.asyncio
    async def test_resolves_output_name_via_simple_link_when_both_links_share_fk(
        self,
        session: AsyncSession,
        multi_link_fact: PreAggregation,
    ):
        """
        `region_date` is linked to both `dim_date.dateint` (simple link) and
        `dim_tcd.date_id` (multi-column link). When the cube grain contains
        the simple link's dimension attribute, the temporal partition's
        output column has to be `dateint` rather than the raw source name
        `region_date`.

        Prior to the fix col_to_dim held a single value per source column,
        so the multi-column link replaced the simple one, strategy 2 found
        nothing in grain_columns, and output_name stayed 'region_date'.
        """
        multi_link_fact.grain_columns = ["test.dim_date.dateint"]

        resolved = get_temporal_partitions(multi_link_fact)

        # Exactly one temporal partition, named after the dimension attribute.
        assert [partition.column_name for partition in resolved] == ["dateint"]
Loading