Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,9 @@ def merged_missing_intervals(
validate_date_range(start, end)

snapshots: t.Collection[Snapshot] = self.snapshot_per_version.values()
if selected_snapshots is not None:
snapshots = [s for s in snapshots if s.name in selected_snapshots]

self.state_sync.refresh_snapshot_intervals(snapshots)

return compute_interval_params(
snapshots_to_intervals = compute_interval_params(
snapshots,
start=start or earliest_start_date(snapshots),
end=end or now(),
Expand All @@ -132,6 +129,13 @@ def merged_missing_intervals(
ignore_cron=ignore_cron,
end_bounded=end_bounded,
)
# Filtering snapshots after computing missing intervals because we need all snapshots in order
# to correctly infer start dates.
if selected_snapshots is not None:
snapshots_to_intervals = {
s: i for s, i in snapshots_to_intervals.items() if s.name in selected_snapshots
}
return snapshots_to_intervals

def evaluate(
self,
Expand Down
13 changes: 3 additions & 10 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,13 +1340,7 @@ def _visit(node: SnapshotId, deployable: bool = True) -> None:

if deployable and node in snapshots:
snapshot = snapshots[node]
# Capture uncategorized snapshot which represents a forward-only model.
is_uncategorized_forward_only_model = (
snapshot.change_category is None
and snapshot.previous_versions
and snapshot.is_model
and snapshot.model.forward_only
)
is_forward_only_model = snapshot.is_model and snapshot.model.forward_only

is_valid_start = (
snapshot.is_valid_start(
Expand All @@ -1359,7 +1353,7 @@ def _visit(node: SnapshotId, deployable: bool = True) -> None:
if (
snapshot.is_forward_only
or snapshot.is_indirect_non_breaking
or is_uncategorized_forward_only_model
or is_forward_only_model
or not is_valid_start
):
# FORWARD_ONLY and INDIRECT_NON_BREAKING snapshots are not deployable by nature.
Expand All @@ -1372,8 +1366,7 @@ def _visit(node: SnapshotId, deployable: bool = True) -> None:
else:
this_deployable = True
children_deployable = is_valid_start and not (
snapshot.is_paused
and (snapshot.is_forward_only or is_uncategorized_forward_only_model)
snapshot.is_paused and (snapshot.is_forward_only or is_forward_only_model)
)
else:
this_deployable, children_deployable = False, False
Expand Down
74 changes: 70 additions & 4 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1774,6 +1774,72 @@ def test_new_forward_only_model_concurrent_versions(init_and_plan_context: t.Cal
assert df.to_dict() == {"ds": {0: "2023-01-07"}, "b": {0: None}}


@freeze_time("2023-01-08 15:00:00")
def test_new_forward_only_model_same_dev_environment(init_and_plan_context: t.Callable):
context, plan = init_and_plan_context("examples/sushi")
context.apply(plan)

new_model_expr = d.parse(
"""
MODEL (
name memory.sushi.new_model,
kind INCREMENTAL_BY_TIME_RANGE (
time_column ds,
forward_only TRUE,
on_destructive_change 'allow',
),
);

SELECT '2023-01-07' AS ds, 1 AS a;
"""
)
new_model = load_sql_based_model(new_model_expr)

# Add the first version of the model and apply it to dev.
context.upsert_model(new_model)
snapshot_a = context.get_snapshot(new_model.name)
plan_a = context.plan("dev", no_prompts=True)
snapshot_a = plan_a.snapshots[snapshot_a.snapshot_id]

assert snapshot_a.snapshot_id in plan_a.context_diff.new_snapshots
assert snapshot_a.snapshot_id in plan_a.context_diff.added
assert snapshot_a.change_category == SnapshotChangeCategory.BREAKING

context.apply(plan_a)

df = context.fetchdf("SELECT * FROM memory.sushi__dev.new_model")
assert df.to_dict() == {"ds": {0: "2023-01-07"}, "a": {0: 1}}

new_model_alt_expr = d.parse(
"""
MODEL (
name memory.sushi.new_model,
kind INCREMENTAL_BY_TIME_RANGE (
time_column ds,
forward_only TRUE,
on_destructive_change 'allow',
),
);

SELECT '2023-01-07' AS ds, 1 AS b;
"""
)
new_model_alt = load_sql_based_model(new_model_alt_expr)

# Add the second version of the model and apply it to the same environment.
context.upsert_model(new_model_alt)
snapshot_b = context.get_snapshot(new_model_alt.name)

context.invalidate_environment("dev", sync=True)
plan_b = context.plan("dev", no_prompts=True)
snapshot_b = plan_b.snapshots[snapshot_b.snapshot_id]

context.apply(plan_b)

df = context.fetchdf("SELECT * FROM memory.sushi__dev.new_model").replace({np.nan: None})
assert df.to_dict() == {"ds": {0: "2023-01-07"}, "b": {0: 1}}


def test_plan_twice_with_star_macro_yields_no_diff(tmp_path: Path):
init_example_project(tmp_path, dialect="duckdb")

Expand Down Expand Up @@ -2564,7 +2630,7 @@ def get_default_catalog_and_non_tables(
) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
assert len(prod_views) == 13
assert len(dev_views) == 0
assert len(user_default_tables) == 13
assert len(user_default_tables) == 16
assert state_metadata.schemas == ["sqlmesh"]
assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
{
Expand All @@ -2583,7 +2649,7 @@ def get_default_catalog_and_non_tables(
) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
assert len(prod_views) == 13
assert len(dev_views) == 13
assert len(user_default_tables) == 13
assert len(user_default_tables) == 16
assert len(non_default_tables) == 0
assert state_metadata.schemas == ["sqlmesh"]
assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
Expand All @@ -2603,7 +2669,7 @@ def get_default_catalog_and_non_tables(
) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
assert len(prod_views) == 13
assert len(dev_views) == 26
assert len(user_default_tables) == 13
assert len(user_default_tables) == 16
assert len(non_default_tables) == 0
assert state_metadata.schemas == ["sqlmesh"]
assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
Expand All @@ -2624,7 +2690,7 @@ def get_default_catalog_and_non_tables(
) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
assert len(prod_views) == 13
assert len(dev_views) == 13
assert len(user_default_tables) == 13
assert len(user_default_tables) == 16
assert len(non_default_tables) == 0
assert state_metadata.schemas == ["sqlmesh"]
assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
Expand Down
10 changes: 4 additions & 6 deletions tests/core/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,17 +1734,15 @@ def test_deployability_index_categorized_forward_only_model(make_snapshot):
snapshot_b.parents = (snapshot_a.snapshot_id,)
snapshot_b.categorize_as(SnapshotChangeCategory.METADATA)

# The fact that the model is forward only should be ignored if an actual category
# has been assigned.
deployability_index = DeployabilityIndex.create(
{s.snapshot_id: s for s in [snapshot_a, snapshot_b]}
)

assert deployability_index.is_deployable(snapshot_a)
assert deployability_index.is_deployable(snapshot_b)
assert not deployability_index.is_deployable(snapshot_a)
assert not deployability_index.is_deployable(snapshot_b)

assert deployability_index.is_representative(snapshot_a)
assert deployability_index.is_representative(snapshot_b)
assert not deployability_index.is_representative(snapshot_a)
assert not deployability_index.is_representative(snapshot_b)


def test_deployability_index_missing_parent(make_snapshot):
Expand Down
2 changes: 1 addition & 1 deletion tests/integrations/jupyter/test_magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def test_plan(

# TODO: Should this be going to stdout? This is printing the status updates for when each batch finishes for
# the models and how long it took
assert len(output.stdout.strip().split("\n")) == 22
assert len(output.stdout.strip().split("\n")) == 23
assert not output.stderr
assert len(output.outputs) == 4
text_output = convert_all_html_output_to_text(output)
Expand Down