Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ def _dag(
}
snapshots_to_create = snapshots_to_create or set()
original_snapshots_to_create = snapshots_to_create.copy()
upstream_dependencies_cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]] = {}

snapshot_dag = snapshot_dag or snapshots_to_dag(batches)
dag = DAG[SchedulingUnit]()
Expand All @@ -670,12 +671,15 @@ def _dag(
snapshot = self.snapshots_by_name[snapshot_id.name]
intervals = intervals_per_snapshot.get(snapshot.name, [])

upstream_dependencies: t.List[SchedulingUnit] = []
upstream_dependencies: t.Set[SchedulingUnit] = set()

for p_sid in snapshot.parents:
upstream_dependencies.extend(
upstream_dependencies.update(
self._find_upstream_dependencies(
p_sid, intervals_per_snapshot, original_snapshots_to_create
p_sid,
intervals_per_snapshot,
original_snapshots_to_create,
upstream_dependencies_cache,
)
)

Expand Down Expand Up @@ -726,29 +730,43 @@ def _find_upstream_dependencies(
parent_sid: SnapshotId,
intervals_per_snapshot: t.Dict[str, Intervals],
snapshots_to_create: t.Set[SnapshotId],
) -> t.List[SchedulingUnit]:
cache: t.Optional[t.Dict[SnapshotId, t.Set[SchedulingUnit]]] = None,
) -> t.Set[SchedulingUnit]:
cache = cache or {}
if parent_sid not in self.snapshots:
return []
return set()
if parent_sid in cache:
return cache[parent_sid]

p_intervals = intervals_per_snapshot.get(parent_sid.name, [])

parent_node: t.Optional[SchedulingUnit] = None
if p_intervals:
if len(p_intervals) > 1:
return [DummyNode(snapshot_name=parent_sid.name)]
interval = p_intervals[0]
return [EvaluateNode(snapshot_name=parent_sid.name, interval=interval, batch_index=0)]
if parent_sid in snapshots_to_create:
return [CreateNode(snapshot_name=parent_sid.name)]
parent_node = DummyNode(snapshot_name=parent_sid.name)
else:
interval = p_intervals[0]
parent_node = EvaluateNode(
snapshot_name=parent_sid.name, interval=interval, batch_index=0
)
elif parent_sid in snapshots_to_create:
parent_node = CreateNode(snapshot_name=parent_sid.name)

if parent_node is not None:
cache[parent_sid] = {parent_node}
return {parent_node}

# This snapshot has no intervals and doesn't need creation which means
# that it can be a transitive dependency
transitive_deps: t.List[SchedulingUnit] = []
transitive_deps: t.Set[SchedulingUnit] = set()
parent_snapshot = self.snapshots[parent_sid]
for grandparent_sid in parent_snapshot.parents:
transitive_deps.extend(
transitive_deps.update(
self._find_upstream_dependencies(
grandparent_sid, intervals_per_snapshot, snapshots_to_create
grandparent_sid, intervals_per_snapshot, snapshots_to_create, cache
)
)
cache[parent_sid] = transitive_deps
return transitive_deps

def _run_or_audit(
Expand Down
87 changes: 87 additions & 0 deletions tests/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,3 +1126,90 @@ def test_dag_multiple_chain_transitive_deps(mocker: MockerFixture, make_snapshot
)
},
}


def test_dag_upstream_dependency_caching_with_complex_diamond(mocker: MockerFixture, make_snapshot):
    r"""
    Verify upstream-dependency caching on a diamond-shaped dependency graph.

    Graph under test:
          A (has intervals)
         / \
        B   C      (no intervals - transitive)
       / \ / \
      D   E   F    (no intervals - transitive)
       \ / \ /
        G   H      (has intervals - selected)

    G and H each reach A through several distinct paths. Without the cache,
    A's dependencies would be recomputed once per path; with the cache they
    are resolved a single time and shared. Either way the resulting DAG must
    collapse the transitive layer so G and H depend directly on A.
    """
    # Build one snapshot per node, all categorized as breaking changes.
    snapshots = {}
    for node in ("a", "b", "c", "d", "e", "f", "g", "h"):
        snapshot = make_snapshot(SqlModel(name=node, query=parse_one("SELECT 1 as id")))
        snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
        snapshots[node] = snapshot

    # Wire up the diamond edges; insertion order matters so that each child
    # references the already-rewired copy of its parents.
    edges = {
        "b": ("a",),
        "c": ("a",),
        "d": ("b",),
        "e": ("b", "c"),
        "f": ("c",),
        "g": ("d", "e"),
        "h": ("e", "f"),
    }
    for child, parents in edges.items():
        snapshots[child] = snapshots[child].model_copy(
            update={"parents": tuple(snapshots[p].snapshot_id for p in parents)}
        )

    scheduler = Scheduler(
        snapshots=list(snapshots.values()),
        snapshot_evaluator=mocker.Mock(),
        state_sync=mocker.Mock(),
        default_catalog=None,
    )

    # Only the root (A) and the leaves (G, H) have intervals to evaluate.
    interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))
    batched_intervals = {snapshots[node]: [interval] for node in ("a", "g", "h")}

    dag = scheduler._dag(batched_intervals, snapshot_dag=snapshots_to_dag(snapshots.values()))

    # Expected DAG:
    #   1. A evaluates first (no dependencies).
    #   2. G and H each depend on A via the collapsed transitive chain.
    #   3. B, C, D, E, F never surface as evaluation nodes of their own.
    node_a, node_g, node_h = (
        EvaluateNode(snapshot_name=f'"{node}"', interval=interval, batch_index=0)
        for node in ("a", "g", "h")
    )

    assert dag.graph == {
        node_a: set(),
        node_g: {node_a},
        node_h: {node_a},
    }