diff --git a/sqlmesh/core/plan/definition.py b/sqlmesh/core/plan/definition.py index d883f4754e..b2ac9eaeb0 100644 --- a/sqlmesh/core/plan/definition.py +++ b/sqlmesh/core/plan/definition.py @@ -255,6 +255,11 @@ def to_evaluatable(self) -> EvaluatablePlan: models_to_backfill=self.models_to_backfill, interval_end_per_model=self.interval_end_per_model, execution_time=self.execution_time, + disabled_restatement_models={ + s.name + for s in self.snapshots.values() + if s.is_model and s.model.disable_restatement + }, ) @cached_property @@ -285,6 +290,7 @@ class EvaluatablePlan(PydanticModel): models_to_backfill: t.Optional[t.Set[str]] = None interval_end_per_model: t.Optional[t.Dict[str, int]] = None execution_time: t.Optional[TimeLike] = None + disabled_restatement_models: t.Set[str] def is_selected_for_backfill(self, model_fqn: str) -> bool: return self.models_to_backfill is None or model_fqn in self.models_to_backfill diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 246a144805..508a2ab7fb 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -409,7 +409,9 @@ def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapsho # # Without this rule, its possible that promoting a dev table to prod will introduce old data to prod snapshot_intervals_to_restate.update( - self._restatement_intervals_across_all_environments(plan.restatements) + self._restatement_intervals_across_all_environments( + plan.restatements, plan.disabled_restatement_models + ) ) self.state_sync.remove_intervals( @@ -418,12 +420,12 @@ def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapsho ) def _restatement_intervals_across_all_environments( - self, prod_restatements: t.Dict[str, Interval] + self, prod_restatements: t.Dict[str, Interval], disable_restatement_models: t.Set[str] ) -> t.Set[t.Tuple[SnapshotTableInfo, Interval]]: """ Given a map of snapshot names + intervals to restate in prod: - Look up matching snapshots across all environments (match based on name - regardless of version) - - For each match, also match downstream snapshots + - For each match, also match downstream snapshots while filtering out models that have restatement disabled - Return all matches mapped to the intervals of the prod snapshot being restated The goal here is to produce a list of intervals to invalidate across all environments so that a cadence @@ -444,7 +446,11 @@ def _restatement_intervals_across_all_environments( for restatement, intervals in prod_restatements.items(): if restatement not in keyed_snapshots: continue - affected_snapshot_names = [restatement] + env_dag.downstream(restatement) + affected_snapshot_names = [ + x + for x in ([restatement] + env_dag.downstream(restatement)) + if x not in disable_restatement_models + ] snapshots_to_restate.update( {(keyed_snapshots[a], intervals) for a in affected_snapshot_names} ) diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 001f795c63..3066a5cefb 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -2172,6 +2172,118 @@ def _dates_in_table(table_name: str) -> t.List[str]: ], f"Table {tbl} wasnt cleared" +def test_restatement_plan_respects_disable_restatements(tmp_path: Path): + model_a = """ + MODEL ( + name test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01', + cron '@daily' + ); + + select account_id, ts from test.external_table; + """ + + model_b = """ + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts", + disable_restatement true, + ), + start '2024-01-01', + cron '@daily' + ); + + select account_id, ts from test.a; + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + for path, defn in {"a.sql": model_a, "b.sql": model_b}.items(): + with open(models_dir / path, "w") as f: + f.write(defn) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ctx = Context(paths=[tmp_path], config=config) + + engine_adapter = ctx.engine_adapter + engine_adapter.create_schema("test") + + # source data + df = pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004], + "ts": [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], + } + ) + columns_to_types = { + "account_id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + } + external_table = exp.table_(table="external_table", db="test", quoted=True) + engine_adapter.create_table(table_name=external_table, columns_to_types=columns_to_types) + engine_adapter.insert_append( + table_name=external_table, query_or_df=df, columns_to_types=columns_to_types + ) + + # plan + apply + ctx.plan(auto_apply=True, no_prompts=True) + + def _dates_in_table(table_name: str) -> t.List[str]: + return [ + str(r[0]) for r in engine_adapter.fetchall(f"select ts from {table_name} order by ts") + ] + + def get_snapshot_intervals(snapshot_id): + return list(ctx.state_sync.get_snapshots([snapshot_id]).values())[0].intervals + + # verify initial state + for tbl in ["test.a", "test.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # restate A and expect b to be ignored + starting_b_intervals = get_snapshot_intervals(ctx.snapshots['"memory"."test"."b"'].snapshot_id) + engine_adapter.execute("delete from test.external_table where ts = '2024-01-01 01:30:00'") + ctx.plan( + restate_models=["test.a"], + start="2024-01-01", + end="2024-01-02", + auto_apply=True, + no_prompts=True, + ) + + # verify A was changed and not b + assert _dates_in_table("test.a") == [ + "2024-01-01 00:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + assert _dates_in_table("test.b") == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # Verify B intervals were not touched + b_intervals = get_snapshot_intervals(ctx.snapshots['"memory"."test"."b"'].snapshot_id) + assert starting_b_intervals == b_intervals + + def test_restatement_plan_clears_correct_intervals_across_environments(tmp_path: Path): model1 = """ MODEL ( diff --git a/tests/schedulers/airflow/test_client.py b/tests/schedulers/airflow/test_client.py index 0e01d3e029..a53834af15 100644 --- a/tests/schedulers/airflow/test_client.py +++ b/tests/schedulers/airflow/test_client.py @@ -77,6 +77,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot): removed_snapshots=[], requires_backfill=True, models_to_backfill={'"test_model"'}, + disabled_restatement_models=set(), ) client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) @@ -196,6 +197,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot): '"test_model"': [to_timestamp("2024-01-01"), to_timestamp("2024-01-02")] }, "requires_backfill": True, + "disabled_restatement_models": [], }, "notification_targets": [], "backfill_concurrent_tasks": 1, diff --git a/tests/schedulers/airflow/test_integration.py b/tests/schedulers/airflow/test_integration.py index 4e1c9f67ae..2bd6c2cb14 100644 --- a/tests/schedulers/airflow/test_integration.py +++ b/tests/schedulers/airflow/test_integration.py @@ -115,6 +115,7 @@ def _create_evaluatable_plan( indirectly_modified_snapshots={}, removed_snapshots=[], requires_backfill=True, + disabled_restatement_models=set(), ) diff --git a/tests/schedulers/airflow/test_plan.py b/tests/schedulers/airflow/test_plan.py index 6bb9df33df..11b409cc67 100644 --- a/tests/schedulers/airflow/test_plan.py +++ b/tests/schedulers/airflow/test_plan.py @@ -125,6 +125,7 @@ def test_create_plan_dag_spec( interval_end_per_model=None, allow_destructive_models=set(), requires_backfill=True, + disabled_restatement_models=set(), ) plan_request = common.PlanApplicationRequest( @@ -269,6 +270,7 @@ def test_restatement( interval_end_per_model=None, allow_destructive_models=set(), requires_backfill=True, + disabled_restatement_models=set(), ) plan_request = common.PlanApplicationRequest( @@ -390,6 +392,7 @@ def test_select_models_for_backfill(mocker: MockerFixture, random_name, make_sna interval_end_per_model=None, allow_destructive_models=set(), requires_backfill=True, + disabled_restatement_models=set(), ) plan_request = common.PlanApplicationRequest( @@ -475,6 +478,7 @@ def test_create_plan_dag_spec_duplicated_snapshot( interval_end_per_model=None, allow_destructive_models=set(), requires_backfill=True, + disabled_restatement_models=set(), ) plan_request = common.PlanApplicationRequest( @@ -537,6 +541,7 @@ def test_create_plan_dag_spec_unbounded_end( interval_end_per_model=None, allow_destructive_models=set(), requires_backfill=True, + disabled_restatement_models=set(), ) plan_request = common.PlanApplicationRequest(