Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 38 additions & 31 deletions sqlmesh/core/plan/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import re
import sys
import typing as t
from collections import defaultdict
from functools import cached_property
Expand All @@ -27,7 +26,13 @@
from sqlmesh.core.snapshot.definition import Interval, SnapshotId
from sqlmesh.utils import columns_to_types_all_known, random_id
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday_ds, to_timestamp
from sqlmesh.utils.date import (
TimeLike,
now,
to_datetime,
yesterday_ds,
to_timestamp,
)
from sqlmesh.utils.errors import NoChangesPlanError, PlanError, SQLMeshError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -322,56 +327,58 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool:
if not restate_models:
return {}

start = self._start or earliest_interval_start
end = self._end or now()

# Add restate snapshots and their downstream snapshots
dummy_interval = (sys.maxsize, -sys.maxsize)
for model_fqn in restate_models:
snapshot = self._model_fqn_to_snapshot.get(model_fqn)
if not snapshot:
if model_fqn not in self._model_fqn_to_snapshot:
raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.")

# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands its
# restatement range, its downstream dependencies all expand their restatement ranges as well.
for s_id in dag:
snapshot = self._context_diff.snapshots[s_id]

if not forward_only_preview_needed:
if self._is_dev and not snapshot.is_paused:
self._console.log_warning(
f"Cannot restate model '{model_fqn}' because the current version is used in production. "
f"Cannot restate model '{snapshot.name}' because the current version is used in production. "
"Run the restatement against the production environment instead to restate this model."
)
continue
elif (not self._is_dev or not snapshot.is_paused) and snapshot.disable_restatement:
self._console.log_warning(
f"Cannot restate model '{model_fqn}'. "
f"Cannot restate model '{snapshot.name}'. "
"Restatement is disabled for this model to prevent possible data loss."
"If you want to restate this model, change the model's `disable_restatement` setting to `false`."
)
continue
elif snapshot.is_symbolic or snapshot.is_seed:
logger.info("Skipping restatement for model '%s'", model_fqn)
logger.info("Skipping restatement for model '%s'", snapshot.name)
continue

restatements[snapshot.snapshot_id] = dummy_interval
for downstream_s_id in dag.downstream(snapshot.snapshot_id):
if is_restateable_snapshot(self._context_diff.snapshots[downstream_s_id]):
restatements[downstream_s_id] = dummy_interval
# Since we are traversing the graph in topological order and the largest interval range is pushed down
# the graph we just have to check our immediate parents in the graph and not the whole upstream graph.
restating_parents = [
self._context_diff.snapshots[s] for s in snapshot.parents if s in restatements
]

# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands its
# restatement range, its downstream dependencies all expand their restatement ranges as well.
for s_id in dag:
if s_id not in restatements:
if not restating_parents and snapshot.name not in restate_models:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do

removal_interval = snapshot.get_removal_interval(

after this check and not before?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, this check doesn't seem useful, since restating_parents doesn't filter out non-incremental models. So restating_parents might be non-empty, but then possible_intervals will be empty.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

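possible_intervals can be empty when the restating parents are all full (non-incremental) models, but the interval returned by get_removal_interval is always added to it, so with expansion there should always be at least one interval.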

continue
snapshot = self._context_diff.snapshots[s_id]
interval = snapshot.get_removal_interval(
self._start or earliest_interval_start,
self._end or now(),
self._execution_time,
strict=False,
is_preview=is_preview,

possible_intervals = {
restatements[p.snapshot_id] for p in restating_parents if p.is_incremental
}
possible_intervals.add(
snapshot.get_removal_interval(
start,
end,
self._execution_time,
strict=False,
is_preview=is_preview,
)
)
# Since we are traversing the graph in topological order and the largest interval range is pushed down
# the graph we just have to check our immediate parents in the graph and not the whole upstream graph.
snapshot_dependencies = snapshot.parents
possible_intervals = [
restatements.get(s, dummy_interval)
for s in snapshot_dependencies
if self._context_diff.snapshots[s].is_incremental
] + [interval]
snapshot_start = min(i[0] for i in possible_intervals)
snapshot_end = max(i[1] for i in possible_intervals)

Expand Down
61 changes: 33 additions & 28 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,8 @@ def add_interval(self, start: TimeLike, end: TimeLike, is_dev: bool = False) ->
f"Attempted to add an Invalid interval ({start}, {end}) to snapshot {self.snapshot_id}"
)

start_ts, end_ts = self.inclusive_exclusive(start, end, strict=False)
start_ts, end_ts = self.inclusive_exclusive(start, end, strict=False, expand=False)

if start_ts >= end_ts:
# Skipping partial interval.
return
Expand Down Expand Up @@ -744,12 +745,17 @@ def get_removal_interval(

return removal_interval

@property
def allow_partials(self) -> bool:
    """Whether this snapshot may be processed over partial intervals.

    Evaluates to the falsy ``is_model`` value when the snapshot is not a
    model; otherwise defers to the model's own ``allow_partials`` setting.
    """
    model_backed = self.is_model
    if not model_backed:
        # Only model snapshots carry an `allow_partials` setting.
        return model_backed
    return self.model.allow_partials

def inclusive_exclusive(
self,
start: TimeLike,
end: TimeLike,
strict: bool = True,
allow_partial: t.Optional[bool] = None,
expand: bool = True,
) -> Interval:
"""Transform the inclusive start and end into a [start, end) pair.

Expand All @@ -758,19 +764,18 @@ def inclusive_exclusive(
end: The end date/time of the interval (inclusive)
strict: Whether to fail when the inclusive start is the same as the exclusive end.
allow_partial: Whether the interval can be partial or not.
expand: Whether or not partial intervals are expanded outwards.

Returns:
A [start, end) pair.
"""
if allow_partial is None:
allow_partial = self.is_model and self.model.allow_partials
return inclusive_exclusive(
start,
end,
self.node.interval_unit,
model_allow_partials=self.is_model and self.model.allow_partials,
strict=strict,
allow_partial=allow_partial,
allow_partial=self.allow_partials if allow_partial is None else allow_partial,
expand=expand,
)

def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None:
Expand Down Expand Up @@ -847,9 +852,10 @@ def missing_intervals(
# If the amount of time being checked is less than the size of a single interval then we
# know that there can't be any missing intervals within that range, so we return early.
validate_date_range(start, end)

if (
not is_date(end)
and not (self.is_model and self.model.allow_partials)
and not self.allow_partials
and to_timestamp(end) - to_timestamp(start) < self.node.interval_unit.milliseconds
):
return []
Expand All @@ -862,16 +868,7 @@ def missing_intervals(
if not self.evaluatable or (self.is_seed and intervals):
return []

allow_partials = self.is_model and self.model.allow_partials
start_ts, end_ts = (
to_timestamp(ts)
for ts in self.inclusive_exclusive(
start,
end,
strict=False,
allow_partial=allow_partials,
)
)
start_ts, end_ts = (to_timestamp(ts) for ts in self.inclusive_exclusive(start, end))

interval_unit = self.node.interval_unit
execution_time_ts = to_timestamp(execution_time) if execution_time else now_timestamp()
Expand All @@ -882,7 +879,7 @@ def missing_intervals(
)
if end_bounded:
upper_bound_ts = min(upper_bound_ts, end_ts)
if not allow_partials:
if not self.allow_partials:
upper_bound_ts = to_timestamp(interval_unit.cron_floor(upper_bound_ts))

end_ts = min(end_ts, upper_bound_ts)
Expand Down Expand Up @@ -1865,36 +1862,44 @@ def inclusive_exclusive(
start: TimeLike,
end: TimeLike,
interval_unit: IntervalUnit,
model_allow_partials: bool,
strict: bool = True,
allow_partial: bool = False,
expand: bool = True,
) -> Interval:
"""Transform the inclusive start and end into a [start, end) pair.

Args:
start: The start date/time of the interval (inclusive)
end: The end date/time of the interval (inclusive)
interval_unit: The interval unit.
model_allow_partials: Whether or not the model allows partials.
strict: Whether to fail when the inclusive start is the same as the exclusive end.
allow_partial: Whether the interval can be partial or not.
expand: Whether or not partial intervals are expanded outwards.

Returns:
A [start, end) pair.
"""
start_ts = to_timestamp(interval_unit.cron_floor(start))
if start_ts < to_timestamp(start) and not model_allow_partials:
start_ts = to_timestamp(interval_unit.cron_next(start_ts))
start_dt = interval_unit.cron_floor(start)

if not expand and not allow_partial and start_dt < to_datetime(start):
start_dt = interval_unit.cron_next(start_dt)

start_ts = to_timestamp(start_dt)

if is_date(end):
end = to_datetime(end) + timedelta(days=1)
end_ts = to_timestamp(interval_unit.cron_floor(end) if not allow_partial else end)
if end_ts < start_ts and to_timestamp(end) > to_timestamp(start) and not strict:
# This can happen when the interval unit is coarser than the size of the input interval.
# For example, if the interval unit is monthly, but the input interval is only 1 hour long.
return (start_ts, end_ts)

if (strict and start_ts >= end_ts) or (start_ts > end_ts):
if allow_partial:
end_dt = end
else:
end_dt = interval_unit.cron_floor(end)

if expand and end_dt != to_datetime(end):
end_dt = interval_unit.cron_next(end_dt)

end_ts = to_timestamp(end_dt)

if strict and start_ts >= end_ts:
raise ValueError(
f"`end` ({to_datetime(end_ts)}) must be greater than `start` ({to_datetime(start_ts)})"
)
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def add_interval(
end: The end of the interval to add.
is_dev: Indicates whether the given interval is being added while in development mode
"""
start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False)
start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False, expand=False)
if not snapshot.version:
raise SQLMeshError("Snapshot version must be set to add an interval.")
intervals = [(start_ts, end_ts)]
Expand Down
8 changes: 3 additions & 5 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2156,7 +2156,7 @@ def test_restatement_plan_ignores_changes(init_and_plan_context: t.Callable):
assert not plan.new_snapshots
assert plan.requires_backfill
assert plan.restatements == {
restated_snapshot.snapshot_id: (to_timestamp("2023-01-01"), to_timestamp("2023-01-08"))
restated_snapshot.snapshot_id: (to_timestamp("2023-01-01"), to_timestamp("2023-01-09"))
}
assert plan.missing_intervals == [
SnapshotIntervals(
Expand Down Expand Up @@ -4565,16 +4565,14 @@ def test_restatement_of_full_model_with_start(init_and_plan_context: t.Callable)
no_prompts=True,
)

restatement_end = to_timestamp("2023-01-08")

sushi_customer_interval = restatement_plan.restatements[
context.get_snapshot("sushi.customers").snapshot_id
]
assert sushi_customer_interval == (to_timestamp("2023-01-01"), restatement_end)
assert sushi_customer_interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-09"))
waiter_by_day_interval = restatement_plan.restatements[
context.get_snapshot("sushi.waiter_as_customer_by_day").snapshot_id
]
assert waiter_by_day_interval == (to_timestamp("2023-01-07"), restatement_end)
assert waiter_by_day_interval == (to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))


def initial_add(context: Context, environment: str):
Expand Down
Loading