Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sqlmesh/core/plan/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ def to_evaluatable(self) -> EvaluatablePlan:
models_to_backfill=self.models_to_backfill,
interval_end_per_model=self.interval_end_per_model,
execution_time=self.execution_time,
disabled_restatement_models={
s.name
for s in self.snapshots.values()
if s.is_model and s.model.disable_restatement
},
)

@cached_property
Expand Down Expand Up @@ -285,6 +290,7 @@ class EvaluatablePlan(PydanticModel):
models_to_backfill: t.Optional[t.Set[str]] = None
interval_end_per_model: t.Optional[t.Dict[str, int]] = None
execution_time: t.Optional[TimeLike] = None
disabled_restatement_models: t.Set[str]

def is_selected_for_backfill(self, model_fqn: str) -> bool:
    """Tell whether the given model is part of this plan's backfill selection.

    Args:
        model_fqn: The fully qualified name of the model to look up.

    Returns:
        True when `models_to_backfill` is None (no selection was set), or when
        `model_fqn` is a member of `models_to_backfill`; False otherwise.
    """
    selection = self.models_to_backfill
    if selection is None:
        # No explicit selection: every model counts as selected.
        return True
    return model_fqn in selection
Expand Down
14 changes: 10 additions & 4 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapsho
#
# Without this rule, it's possible that promoting a dev table to prod will introduce old data to prod
snapshot_intervals_to_restate.update(
self._restatement_intervals_across_all_environments(plan.restatements)
self._restatement_intervals_across_all_environments(
plan.restatements, plan.disabled_restatement_models
)
)

self.state_sync.remove_intervals(
Expand All @@ -418,12 +420,12 @@ def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapsho
)

def _restatement_intervals_across_all_environments(
self, prod_restatements: t.Dict[str, Interval]
self, prod_restatements: t.Dict[str, Interval], disable_restatement_models: t.Set[str]
) -> t.Set[t.Tuple[SnapshotTableInfo, Interval]]:
"""
Given a map of snapshot names + intervals to restate in prod:
- Look up matching snapshots across all environments (match based on name - regardless of version)
- For each match, also match downstream snapshots
- For each match, also match downstream snapshots while filtering out models that have restatement disabled
- Return all matches mapped to the intervals of the prod snapshot being restated

The goal here is to produce a list of intervals to invalidate across all environments so that a cadence
Expand All @@ -444,7 +446,11 @@ def _restatement_intervals_across_all_environments(
for restatement, intervals in prod_restatements.items():
if restatement not in keyed_snapshots:
continue
affected_snapshot_names = [restatement] + env_dag.downstream(restatement)
affected_snapshot_names = [
x
for x in ([restatement] + env_dag.downstream(restatement))
if x not in disable_restatement_models
]
snapshots_to_restate.update(
{(keyed_snapshots[a], intervals) for a in affected_snapshot_names}
)
Expand Down
112 changes: 112 additions & 0 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2172,6 +2172,118 @@ def _dates_in_table(table_name: str) -> t.List[str]:
], f"Table {tbl} wasnt cleared"


def test_restatement_plan_respects_disable_restatements(tmp_path: Path):
    """Restating an upstream model must leave a `disable_restatement` downstream model alone.

    `test.a` selects from an external table and `test.b` selects from `test.a`, with
    `disable_restatement true` set on `test.b`. After one source row is deleted and
    `test.a` is restated, `test.a` is expected to reflect the deletion while both the
    rows and the stored intervals of `test.b` are expected to be unchanged.
    """
    upstream_sql = """
MODEL (
name test.a,
kind INCREMENTAL_BY_TIME_RANGE (
time_column "ts"
),
start '2024-01-01',
cron '@daily'
);

select account_id, ts from test.external_table;
"""

    downstream_sql = """
MODEL (
name test.b,
kind INCREMENTAL_BY_TIME_RANGE (
time_column "ts",
disable_restatement true,
),
start '2024-01-01',
cron '@daily'
);

select account_id, ts from test.a;
"""

    # Lay out the project on disk: one SQL file per model.
    models_dir = tmp_path / "models"
    models_dir.mkdir()
    (models_dir / "a.sql").write_text(upstream_sql)
    (models_dir / "b.sql").write_text(downstream_sql)

    ctx = Context(
        paths=[tmp_path],
        config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")),
    )

    adapter = ctx.engine_adapter
    adapter.create_schema("test")

    # Seed the external table that model A reads from.
    source_rows = pd.DataFrame(
        {
            "account_id": [1001, 1002, 1003, 1004],
            "ts": [
                "2024-01-01 00:30:00",
                "2024-01-01 01:30:00",
                "2024-01-01 02:30:00",
                "2024-01-02 00:30:00",
            ],
        }
    )
    schema = {
        "account_id": exp.DataType.build("int"),
        "ts": exp.DataType.build("timestamp"),
    }
    source_table = exp.table_(table="external_table", db="test", quoted=True)
    adapter.create_table(table_name=source_table, columns_to_types=schema)
    adapter.insert_append(table_name=source_table, query_or_df=source_rows, columns_to_types=schema)

    # Initial plan + apply populates both models.
    ctx.plan(auto_apply=True, no_prompts=True)

    def _dates_in_table(table_name: str) -> t.List[str]:
        rows = adapter.fetchall(f"select ts from {table_name} order by ts")
        return [str(row[0]) for row in rows]

    def _stored_intervals(snapshot_id):
        stored = list(ctx.state_sync.get_snapshots([snapshot_id]).values())
        return stored[0].intervals

    every_timestamp = [
        "2024-01-01 00:30:00",
        "2024-01-01 01:30:00",
        "2024-01-01 02:30:00",
        "2024-01-02 00:30:00",
    ]

    # Both models start out with the full set of source rows.
    for tbl in ("test.a", "test.b"):
        assert _dates_in_table(tbl) == every_timestamp

    # Capture B's intervals, drop one source row, then restate only A.
    b_intervals_before = _stored_intervals(ctx.snapshots['"memory"."test"."b"'].snapshot_id)
    adapter.execute("delete from test.external_table where ts = '2024-01-01 01:30:00'")
    ctx.plan(
        restate_models=["test.a"],
        start="2024-01-01",
        end="2024-01-02",
        auto_apply=True,
        no_prompts=True,
    )

    # A picked up the deletion ...
    assert _dates_in_table("test.a") == [
        "2024-01-01 00:30:00",
        "2024-01-01 02:30:00",
        "2024-01-02 00:30:00",
    ]
    # ... while B still holds every original row.
    assert _dates_in_table("test.b") == every_timestamp

    # B's recorded intervals must be exactly what they were before the restatement.
    b_intervals_after = _stored_intervals(ctx.snapshots['"memory"."test"."b"'].snapshot_id)
    assert b_intervals_before == b_intervals_after


def test_restatement_plan_clears_correct_intervals_across_environments(tmp_path: Path):
model1 = """
MODEL (
Expand Down
2 changes: 2 additions & 0 deletions tests/schedulers/airflow/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot):
removed_snapshots=[],
requires_backfill=True,
models_to_backfill={'"test_model"'},
disabled_restatement_models=set(),
)

client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session())
Expand Down Expand Up @@ -196,6 +197,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot):
'"test_model"': [to_timestamp("2024-01-01"), to_timestamp("2024-01-02")]
},
"requires_backfill": True,
"disabled_restatement_models": [],
},
"notification_targets": [],
"backfill_concurrent_tasks": 1,
Expand Down
1 change: 1 addition & 0 deletions tests/schedulers/airflow/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def _create_evaluatable_plan(
indirectly_modified_snapshots={},
removed_snapshots=[],
requires_backfill=True,
disabled_restatement_models=set(),
)


Expand Down
5 changes: 5 additions & 0 deletions tests/schedulers/airflow/test_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def test_create_plan_dag_spec(
interval_end_per_model=None,
allow_destructive_models=set(),
requires_backfill=True,
disabled_restatement_models=set(),
)

plan_request = common.PlanApplicationRequest(
Expand Down Expand Up @@ -269,6 +270,7 @@ def test_restatement(
interval_end_per_model=None,
allow_destructive_models=set(),
requires_backfill=True,
disabled_restatement_models=set(),
)

plan_request = common.PlanApplicationRequest(
Expand Down Expand Up @@ -390,6 +392,7 @@ def test_select_models_for_backfill(mocker: MockerFixture, random_name, make_sna
interval_end_per_model=None,
allow_destructive_models=set(),
requires_backfill=True,
disabled_restatement_models=set(),
)

plan_request = common.PlanApplicationRequest(
Expand Down Expand Up @@ -475,6 +478,7 @@ def test_create_plan_dag_spec_duplicated_snapshot(
interval_end_per_model=None,
allow_destructive_models=set(),
requires_backfill=True,
disabled_restatement_models=set(),
)

plan_request = common.PlanApplicationRequest(
Expand Down Expand Up @@ -537,6 +541,7 @@ def test_create_plan_dag_spec_unbounded_end(
interval_end_per_model=None,
allow_destructive_models=set(),
requires_backfill=True,
disabled_restatement_models=set(),
)

plan_request = common.PlanApplicationRequest(
Expand Down