From 4a607fc90a9eceaf6df3b470f856ead9ccfb708d Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Tue, 25 Nov 2025 17:01:26 +0100 Subject: [PATCH 1/2] Implemented query stats and budgets in run results + added this to cache demo --- docs/Cost_Monitoring.md | 319 ++++++++++++++ docs/examples/Cache_Demo.md | 23 +- examples/api_demo/profiles.yml | 4 +- examples/api_demo/project.yml | 6 +- examples/basic_demo/Makefile | 2 +- examples/basic_demo/profiles.yml | 4 +- examples/cache_demo/Makefile | 5 +- examples/cache_demo/README.md | 88 +++- examples/cache_demo/budgets.yml | 62 +++ .../models/marts/mart_user_orders.ff.sql | 11 +- .../models/seeds_consumers/stg_orders.ff.sql | 11 +- examples/cache_demo/profiles.yml | 5 +- examples/cache_demo/project.yml | 7 +- examples/ci_demo/profiles.yml | 4 +- examples/dq_demo/profiles.yml | 4 +- examples/dq_demo/project.yml | 9 +- .../.env.dev_bigquery_bigframes | 7 + .../incremental_demo/.env.dev_bigquery_pandas | 7 + examples/incremental_demo/profiles.yml | 4 +- examples/macros_demo/profiles.yml | 4 +- examples/macros_demo/project.yml | 9 +- examples/materializations_demo/profiles.yml | 4 +- examples/materializations_demo/project.yml | 9 +- examples/snapshot_demo/Makefile | 2 +- examples/snapshot_demo/profiles.yml | 4 +- .../building_locally_demo/README.md | 7 + .../building_locally_demo/docs/README.md | 4 + .../building_locally_demo/models/README.md | 4 + .../models/staging/stg_users.ff.sql | 9 + .../building_locally_demo/profiles.yml | 14 + .../building_locally_demo/project.yml | 18 + .../building_locally_demo/seeds/README.md | 4 + .../building_locally_demo/seeds/users.csv | 3 + .../building_locally_demo/sources.yml | 9 + .../building_locally_demo/tests/dq/README.md | 4 + .../tests/unit/README.md | 4 + mkdocs.yml | 1 + pyproject.toml | 2 +- pytest.ini | 1 + src/fastflowtransform/artifacts.py | 46 ++ src/fastflowtransform/cli/bootstrap.py | 31 +- src/fastflowtransform/cli/options.py | 34 +- src/fastflowtransform/cli/run.py | 416 
+++++++++++++++++- src/fastflowtransform/cli/snapshot_cmd.py | 10 +- src/fastflowtransform/config/budgets.py | 247 +++++++++++ src/fastflowtransform/executors/base.py | 222 +++++++++- .../executors/bigquery/base.py | 82 ++-- .../executors/bigquery/pandas.py | 17 + src/fastflowtransform/executors/budget.py | 152 +++++++ .../executors/databricks_spark.py | 301 +++++++++++-- src/fastflowtransform/executors/duckdb.py | 409 +++++++++++++++-- src/fastflowtransform/executors/postgres.py | 322 +++++++++++--- .../executors/query_stats.py | 52 +++ .../executors/snowflake_snowpark.py | 226 +++++++++- src/fastflowtransform/incremental.py | 8 +- src/fastflowtransform/run_executor.py | 204 ++++++++- src/fastflowtransform/seeding.py | 4 + .../table_formats/__init__.py | 14 +- src/fastflowtransform/table_formats/base.py | 12 +- .../table_formats/spark_default.py | 9 +- .../table_formats/spark_delta.py | 19 +- .../table_formats/spark_hudi.py | 13 +- .../table_formats/spark_iceberg.py | 19 +- src/fastflowtransform/utils/timefmt.py | 16 + tests/common/fixtures.py | 173 ++++---- tests/conftest.py | 25 ++ tests/integration/examples/config.py | 56 +++ .../examples/test_examples_matrix.py | 42 +- .../test_snapshots_postgres_integration.py | 31 +- tests/unit/cache/test_cache_policy_cli.py | 1 + .../test_databricks_spark_exec_unit.py | 1 - .../unit/executors/test_postgres_exec_unit.py | 5 +- .../executors/test_snowflake_snowpark_exec.py | 12 +- uv.lock | 2 +- 74 files changed, 3449 insertions(+), 482 deletions(-) create mode 100644 docs/Cost_Monitoring.md create mode 100644 examples/cache_demo/budgets.yml create mode 100644 examples/incremental_demo/.env.dev_bigquery_bigframes create mode 100644 examples/incremental_demo/.env.dev_bigquery_pandas create mode 100644 examples_article/building_locally_demo/README.md create mode 100644 examples_article/building_locally_demo/docs/README.md create mode 100644 examples_article/building_locally_demo/models/README.md create mode 100644 
examples_article/building_locally_demo/models/staging/stg_users.ff.sql create mode 100644 examples_article/building_locally_demo/profiles.yml create mode 100644 examples_article/building_locally_demo/project.yml create mode 100644 examples_article/building_locally_demo/seeds/README.md create mode 100644 examples_article/building_locally_demo/seeds/users.csv create mode 100644 examples_article/building_locally_demo/sources.yml create mode 100644 examples_article/building_locally_demo/tests/dq/README.md create mode 100644 examples_article/building_locally_demo/tests/unit/README.md create mode 100644 src/fastflowtransform/config/budgets.py create mode 100644 src/fastflowtransform/executors/budget.py create mode 100644 src/fastflowtransform/executors/query_stats.py diff --git a/docs/Cost_Monitoring.md b/docs/Cost_Monitoring.md new file mode 100644 index 0000000..68cde22 --- /dev/null +++ b/docs/Cost_Monitoring.md @@ -0,0 +1,319 @@ +# Cost monitoring, query stats & budgets + +This document describes the new **cost & observability** features in +`fastflowtransform`: + +- Per-query **cost guards** via environment variables +- Per-model **query statistics** in `run_results.json` +- Optional **project-level budgets** via `budgets.yml` + +All of these are *opt-in* and backwards-compatible: if you do nothing, existing +projects keep working exactly as before. + +--- + +## 1. Per-query cost guards (budgets.yml or FF_*_MAX_BYTES) + +For engines that can estimate query size, FFT can now abort a query **before** +running it if it would scan more data than you allow. 
+ +### Supported engines & env vars + +| Engine | Env var | Applies to | +|------------------------|-----------------------|---------------------------------------| +| BigQuery | `FF_BQ_MAX_BYTES` | All model-driven SQL queries | +| DuckDB | `FF_DUCKDB_MAX_BYTES` | All model-driven SQL queries | +| Postgres | `FF_PG_MAX_BYTES` | All model-driven SQL queries | +| Databricks / Spark | `FF_SPK_MAX_BYTES` | All `spark.sql(...)` queries | +| Snowflake (Snowpark) | `FF_SF_MAX_BYTES` | All `session.sql(...)` queries | + +If the env var is **unset or ≤ 0**, the guard is **disabled** for that engine. + +### Value format + +All `FF_*_MAX_BYTES` use the same parser as `fastflowtransform.executors.budget.parse_max_bytes_env`: + +- Plain integers: `5000000000` +- With unit suffix (case-insensitive, `_` and `,` ignored): + + - `10GB`, `5G` + - `750MB`, `100M` + - `42KB`, `10K` + - Units are powers of 1024 + +Examples: + +```bash +# Limit BigQuery queries to ~1 GB +export FF_BQ_MAX_BYTES=1GB + +# Limit DuckDB queries to ~2 GB +export FF_DUCKDB_MAX_BYTES=2GB + +# Limit Snowflake queries to ~100 MB +export FF_SF_MAX_BYTES=100mb + +# Limit Postgres queries to ~500 MB +export FF_PG_MAX_BYTES=500M + +# Limit Spark queries to ~10 GB +export FF_SPK_MAX_BYTES=10g +```` + +### Behaviour + +1. **Before running a query**, the executor asks the engine for a best-effort + size estimate: + + * BigQuery: dry-run load/query job reports logical bytes. + * DuckDB: `EXPLAIN (FORMAT JSON)` stats (cardinality × row width), best-effort. + * Postgres: `EXPLAIN (FORMAT JSON)` → `Plan Rows × Plan Width`. + * Spark/Databricks: `DataFrame._jdf.queryExecution().optimizedPlan().stats().sizeInBytes`. + * Snowflake Snowpark: `EXPLAIN USING JSON` and `GlobalStats.bytesAssigned`. + +2. If the estimate is **`None`** or **≤ 0** (cannot estimate), the query runs + as usual and is **not blocked**. + +3. 
If the estimate is **> limit**, FFT raises a `RuntimeError` *before* the
+   query is submitted:
+
+   > Aborting query: estimated scanned bytes 123456789 (117 MB)
+   > exceed FF_XY_MAX_BYTES=100000000 (95.4 MB).
+   > Increase FF_XY_MAX_BYTES or unset it to allow this query.
+
+The CLI surfaces this error as a **failed model** like any other engine error.
+
+### Configuring limits in budgets.yml
+
+You can define the same per-query limits directly in `budgets.yml` so that
+every engineer sees a single source of truth:
+
+```yaml
+version: 1
+
+query_limits:
+  duckdb:
+    max_bytes: 2_000_000_000
+  postgres:
+    max_bytes: 500_000_000
+```
+
+When `query_limits.<engine>.max_bytes` is set, FFT automatically enforces it for
+that engine. Environment variables still work and explicitly override the YAML
+value, which is useful for quick experiments in CI or ad-hoc runs.
+
+---
+
+## 2. Query statistics per model (run_results.json)
+
+FFT now collects **per-model query statistics** and writes them into
+`.fastflowtransform/target/run_results.json` alongside the usual run
+information.
+
+### What is recorded?
+
+For each **model** in a run, FFT now aggregates (best effort):
+
+* `bytes_scanned`: total bytes processed by all SQL queries in that model
+* `rows`: total rows affected/returned (where the engine reports it)
+* `query_duration_ms`: total time spent *inside the engine* for SQL queries
+  (excluding Python overhead)
+
+These are stored on each `RunNodeResult`:
+
+```jsonc
+{
+  "results": [
+    {
+      "name": "fct_events",
+      "status": "success",
+      "duration_ms": 1234,
+      "bytes_scanned": 987654321,
+      "rows": 123456,
+      "query_duration_ms": 1100
+    },
+    {
+      "name": "dim_users",
+      "status": "success",
+      "duration_ms": 250,
+      "bytes_scanned": 1048576,
+      "rows": 1000,
+      "query_duration_ms": 200
+    }
+  ]
+}
+```
+
+Engines that don’t support a metric simply leave it as `null`/`0`; all
+statistics are **best effort** and must never break a run.
+
+### Where does it come from?
+ +* Each executor implements a private `_execute_sql(...)` method that is used for + all model-driven SQL. +* After each query, the executor updates its internal “current node stats”. +* At the end of a model run, `_RunEngine` calls `executor.get_node_stats()` and + stores the aggregated numbers in `run_results.json`. + +This design means: + +* **All SQL** executed by FFT for a model contributes to these stats. +* Python models now contribute metrics on engines that can inspect their DataFrames: + * Postgres (pandas): uses in-memory dataframe size / row count. + * BigQuery (pandas): `load_table_from_dataframe` duration + pandas memory usage. + * Databricks/Spark: optimized plan stats via Snowpark/Spark query execution. + * Snowflake Snowpark: `EXPLAIN USING JSON` for the Snowpark plan. + Engines without support fall back to `None` for bytes/rows. + +### How to use it + +You can post-process `run_results.json` to: + +* Identify the **most expensive models** (by bytes scanned or query time) +* Build dashboards / cost reports in your own tooling +* Feed data into a documentation UI (e.g. show last run cost per model) + +Example (pseudo-Python): + +```python +import json +from pathlib import Path + +path = Path("my_project/.fastflowtransform/target/run_results.json") +data = json.loads(path.read_text()) + +for node in data["results"]: + print( + node["name"], + node.get("bytes_scanned"), + node.get("query_duration_ms"), + ) +``` + +--- + +## 3. Project-level budgets (budgets.yml) + +In addition to per-query env guards, you can define **project-level budgets** in +a `budgets.yml` file at the project root: + +* Env vars (`FF_*_MAX_BYTES`) protect **individual queries** +* `budgets.yml` protects the **entire run** (and optional per-engine / per-model + slices) based on the aggregated stats in `run_results.json` + +> ⚠️ If you don’t create `budgets.yml`, nothing changes – budgets are fully +> opt-in. 
+ +### File location + +Place the file next to `project.yml`: + +```text +my_project/ + project.yml + budgets.yml # ← new + models/ + ... +``` + +### Basic example + +```yaml +version: 1 + +# Default behaviour when a budget is exceeded. +# One of: "warn", "error" +defaults: + on_exceed: "warn" + +budgets: + # Global cap for total bytes scanned by all models. + total_bytes_scanned: + warn_after: "10GB" # log a warning above this + error_after: "50GB" # fail the run above this + + # Optional: total query time across all models. + total_query_duration_ms: + warn_after: 600000 # 10 minutes + error_after: 3600000 # 60 minutes + + # Per-engine budgets + by_engine: + bigquery: + bytes_scanned: + warn_after: "5GB" + error_after: "20GB" + + snowflake_snowpark: + bytes_scanned: + warn_after: "2GB" + error_after: "10GB" + + # Optional per-model budgets (key = model name) + per_model: + fct_events: + bytes_scanned: + warn_after: "500MB" + error_after: "2GB" + + dim_users: + bytes_scanned: + warn_after: "100MB" +``` + +#### Value syntax + +* `warn_after` / `error_after` for **bytes** use the same notation as + `FF_*_MAX_BYTES`: + + ```yaml + warn_after: "5GB" + error_after: "500_000_000" + ``` + +* Durations are currently in **milliseconds** (plain integers). + +### When are budgets evaluated? + +1. You run `fft run` as usual. +2. FFT executes models and writes `run_results.json` with the new stats. +3. After the run completes, FFT: + + * Aggregates `bytes_scanned` / `query_duration_ms` across all `RunNodeResult` + * Compares the totals to the budgets defined in `budgets.yml` + * Emits **warnings** and/or fails the run depending on your config + +### What happens on exceed? + +* If a **warn** budget is exceeded, FFT prints a warning but the CLI exit code + stays `0` (unless there were other errors). +* If an **error** budget is exceeded, FFT treats it like a failed run and exits + with `1`. 
+ +If both `warn_after` and `error_after` are defined and exceeded, the **error** +behaviour wins. + +### Interaction with env-level guards + +* Env vars (`FF_*_MAX_BYTES`) are **hard per-query gates**. If an estimate is + above the limit, the query is never sent to the engine. +* Budgets (`budgets.yml`) are **soft run-level gates**. They look at what was + actually executed and may warn or fail after the fact. + +You can happily combine both: + +* Env vars to avoid accidentally launching a single “monster query” +* Budgets to catch an overall run that’s becoming too expensive over time + +--- + +## 4. Backwards compatibility + +* All new features are **opt-in**: + + * No env var → no per-query guard + * No `budgets.yml` → no run-level budgets +* Existing `run_results.json` readers remain valid; they simply gain more fields + (`bytes_scanned`, `rows`, `query_duration_ms`). +* Engines that can’t provide certain metrics just leave them empty/`null`. No + behaviour changes for those engines. 
diff --git a/docs/examples/Cache_Demo.md b/docs/examples/Cache_Demo.md index 6cfecf5..4c0a508 100644 --- a/docs/examples/Cache_Demo.md +++ b/docs/examples/Cache_Demo.md @@ -44,6 +44,7 @@ This demo showcases several FastFlowTransform features: | Environment invalidation | Any `FF_*` change triggers rebuild | | Python model caching | Fingerprints derived from function source | | HTTP response caching | Persistent API result cache with offline mode | +| Cost guards + budgets | `budgets.yml → query_limits` + run-level budgets | --- @@ -58,6 +59,7 @@ make change_seed # use patches/seed_users_patch.csv -> rebuilds staging + make change_env # set FF_* env -> invalidates cache globally make change_py # edit py_constants.ff.py -> rebuilds that model make run_parallel # runs entire DAG with 4 workers per level +make cost_guard_example ENGINE=duckdb # demonstrate per-query guard ``` > Engines: set `ENGINE=` and copy the matching `.env.dev_*` file (`.env.dev_snowflake` for Snowflake; install `fastflowtransform[snowflake]`). @@ -67,7 +69,7 @@ Seeds stay immutable: `change_seed` assembles a temporary combined copy in `.loc Inspect results: -* `.fastflowtransform/target/run_results.json` – fingerprints, results, timings, HTTP stats +* `.fastflowtransform/target/run_results.json` – per-model stats (bytes, rows, durations, HTTP) * `site/dag/index.html` – DAG visualization * `.local/http-cache/` – persisted API responses @@ -154,15 +156,16 @@ fft run . 
--env dev_duckdb --jobs 4 ## 🧪 Example Experiments -| Scenario | Command | Expected behavior | -| ------------------------- | -------------------------------------- | ------------------------------- | -| First full run | `make cache_first` | All models build, cache written | -| No-op run | `make cache_second` | All skipped (no rebuilds) | -| Modify SQL | `make change_sql` | Downstream mart rebuilds | -| Add seed row | `make change_seed` | Staging + mart rebuild (temp combined seed from patches/) | -| Change env | `make change_env` | All nodes rebuild | -| Edit Python constant | `make change_py` | Only that Python model rebuilds | -| Warm & offline HTTP cache | `make http_first && make http_offline` | HTTP cache reused, no network | +| Scenario | Command | Expected behavior | +| ------------------------- | ---------------------------------------- | ------------------------------------- | +| First full run | `make cache_first` | All models build, cache written | +| No-op run | `make cache_second` | All skipped (no rebuilds) | +| Modify SQL | `make change_sql` | Downstream mart rebuilds | +| Add seed row | `make change_seed` | Staging + mart rebuild using patches/ | +| Change env | `make change_env` | All nodes rebuild | +| Edit Python constant | `make change_py` | Only that Python model rebuilds | +| Warm & offline HTTP cache | `make http_first && make http_offline` | HTTP cache reused, no network | +| Per-query guard demo | `make cost_guard_example ENGINE=duckdb` | Query aborted by bytes limit | --- diff --git a/examples/api_demo/profiles.yml b/examples/api_demo/profiles.yml index 16fc055..68a08c0 100644 --- a/examples/api_demo/profiles.yml +++ b/examples/api_demo/profiles.yml @@ -23,7 +23,9 @@ dev_databricks: spark.hadoop.datanucleus.rdbms.datastoreAdapterClassName: "org.datanucleus.store.rdbms.adapter.DerbyAdapter" spark.hadoop.datanucleus.schema.autoCreateAll: "true" spark.hadoop.javax.jdo.option.ConnectionDriverName: "org.apache.derby.jdbc.EmbeddedDriver" - 
spark.driver.extraJavaOptions: "-Dderby.stream.error.file={{ project_dir() }}/.local/derby.log" + spark.driver.extraJavaOptions: > + -Dderby.stream.error.file={{ project_dir() }}/.local/derby.log + -Dderby.system.home={{ project_dir() }}/.local/derby_home dev_bigquery_bigframes: engine: bigquery diff --git a/examples/api_demo/project.yml b/examples/api_demo/project.yml index 13f8276..3f8dd8f 100644 --- a/examples/api_demo/project.yml +++ b/examples/api_demo/project.yml @@ -5,11 +5,7 @@ vars: {} models: {} -seeds: - storage: - seed_users: - path: ".local/spark/seed_users" - format: parquet +seeds: {} tests: - type: not_null diff --git a/examples/basic_demo/Makefile b/examples/basic_demo/Makefile index 22f0d63..558bc20 100644 --- a/examples/basic_demo/Makefile +++ b/examples/basic_demo/Makefile @@ -64,7 +64,7 @@ else ifeq ($(ENGINE),bigquery) else ifeq ($(ENGINE),snowflake_snowpark) CLEAN_CMD = env $(BASE_ENV) $(UV) run python $(CLEAN_SCRIPT) --engine snowflake_snowpark --env "$(PROFILE_ENV)" --project "$(PROJECT)" else - $(error Unsupported ENGINE=$(ENGINE) - pick duckdb|postgres|databricks_spark) + $(error Unsupported ENGINE=$(ENGINE) - pick duckdb|postgres|databricks_spark|bigquery|snowflake_snowpark) endif # --- Targets ---------------------------------------------------------------- diff --git a/examples/basic_demo/profiles.yml b/examples/basic_demo/profiles.yml index 361d0c5..94e0330 100644 --- a/examples/basic_demo/profiles.yml +++ b/examples/basic_demo/profiles.yml @@ -24,7 +24,9 @@ dev_databricks: spark.hadoop.datanucleus.rdbms.datastoreAdapterClassName: "org.datanucleus.store.rdbms.adapter.DerbyAdapter" spark.hadoop.datanucleus.schema.autoCreateAll: "true" spark.hadoop.javax.jdo.option.ConnectionDriverName: "org.apache.derby.jdbc.EmbeddedDriver" - spark.driver.extraJavaOptions: "-Dderby.stream.error.file={{ project_dir() }}/.local/derby.log" + spark.driver.extraJavaOptions: > + -Dderby.stream.error.file={{ project_dir() }}/.local/derby.log + 
-Dderby.system.home={{ project_dir() }}/.local/derby_home dev_bigquery_bigframes: engine: bigquery diff --git a/examples/cache_demo/Makefile b/examples/cache_demo/Makefile index 3511d0d..42029bf 100644 --- a/examples/cache_demo/Makefile +++ b/examples/cache_demo/Makefile @@ -1,7 +1,7 @@ .PHONY: seed run run_parallel cache_first cache_second \ change_sql change_seed change_env change_py \ http_first http_offline http_cache_clear artifacts dag clean \ - demo + demo cost_guard_example ENGINE ?= duckdb # duckdb | postgres | databricks_spark | bigquery | snowflake_snowpark # BigQuery frame selector (pandas | bigframes) @@ -73,7 +73,7 @@ run: run_parallel: env $(RUN_ENV) $(UV) run fft run "$(PROJECT)" --env $(PROFILE_ENV) $(SELECT_ALL) --cache=rw --jobs 4 -cache_first: seed run +cache_first: run cache_second: run change_sql: @@ -124,6 +124,7 @@ clean: demo: clean @echo "== 🚀 Cache Demo ($(ENGINE)) ==" @echo "== 1) First full build (writes cache) ==" + +$(MAKE) seed +$(MAKE) cache_first @echo @echo "== 2) No-op run (should skip everything) ==" diff --git a/examples/cache_demo/README.md b/examples/cache_demo/README.md index 62ca284..0855dcb 100644 --- a/examples/cache_demo/README.md +++ b/examples/cache_demo/README.md @@ -6,6 +6,7 @@ This demo shows: - Environment-driven invalidation (only `FF_*`) - Parallelism within levels (`--jobs`) - HTTP response cache + offline mode +- Bytes-based cost monitoring (per-query guards + run-level budgets) ## Quickstart @@ -37,8 +38,91 @@ site/dag/index.html .fastflowtransform/target/run_results.json (HTTP stats, results) -markdown -Code kopieren +## Cost & budgets (bytes-based) + +This example also showcases the **bytes-based cost features**: + +1. **Per-query cost guard** (`budgets.yml → query_limits`, overridable with env vars) +2. **Per-run budgets** (via `budgets.yml`) +3. **Per-model stats** in `run_results.json` + +### 1. 
Per-query guard demo
+
+Each engine now reads its default per-query limit from `budgets.yml` (`query_limits.<engine>.max_bytes`),
+but you can override it ad-hoc via the legacy env vars:
+
+| Engine              | Env var              |
+|---------------------|----------------------|
+| DuckDB              | `FF_DUCKDB_MAX_BYTES`|
+| Postgres            | `FF_PG_MAX_BYTES`    |
+| Databricks / Spark  | `FF_SPK_MAX_BYTES`   |
+| BigQuery            | `FF_BQ_MAX_BYTES`    |
+| Snowflake Snowpark  | `FF_SF_MAX_BYTES`    |
+
+The Makefile wraps a tiny demo that intentionally sets the env-var override so low that the query is blocked:
+
+```bash
+# Pick an engine (defaults to duckdb)
+make cost_guard_example ENGINE=duckdb
+# or:
+make cost_guard_example ENGINE=postgres
+make cost_guard_example ENGINE=bigquery BQ_FRAME=bigframes
+```
+
+Under the hood this runs:
+
+- `FF_*_MAX_BYTES=1` → the executor estimates bytes before running SQL
+- If the estimate is > 1 byte, the query is **aborted** with a clear RuntimeError
+- The `make` target uses `|| true` so you can still see the error without the overall demo failing
+
+Unset the env var (or set it to a large value) to disable the guard again.
+
+### 2. Run-level budgets (`budgets.yml`)
+
+`examples/cache_demo/budgets.yml` defines the **project-level budgets** and the per-query limits:
+
+- `query_limits.<engine>.max_bytes` – aborts a single query before execution
+- `total.bytes_scanned` – aggregate over all models (driven by run stats)
+- `total.query_duration_ms` – aggregate query time
+- `models.mart_user_orders.ff.bytes_scanned` – per-model example
+
+By default all budgets use `on_exceed: "warn"`, so they **never break** the main demo. If a run exceeds any threshold, FFT will print a warning after the run finishes (based on `run_results.json`).
+
+If you want to see a *failing* run:
+
+```yaml
+# In budgets.yml, temporarily tighten one budget, e.g.:
+models:
+  mart_user_orders.ff:
+    bytes_scanned:
+      warn: "1KB"
+      error: "2KB"  # will likely fail even this tiny demo
+```
+
+Then run:
+
+```bash
+make run ENGINE=duckdb
+```
+
+and the CLI will exit with code `1` once the budget is exceeded.
+
+### 3. Inspecting bytes & timing stats
+
+After any run, `run_results.json` includes per-model stats (best effort per engine):
+
+- `bytes_scanned`
+- `rows`
+- `query_duration_ms`
+
+Example inspection:
+
+```bash
+jq '.results[] | {name, status, bytes_scanned, query_duration_ms}' \
+  .fastflowtransform/target/run_results.json
+```
+
+You can use this to build your own cost dashboards or to tune `budgets.yml` for your real project.
 
 ---
 
diff --git a/examples/cache_demo/budgets.yml b/examples/cache_demo/budgets.yml
new file mode 100644
index 0000000..8767053
--- /dev/null
+++ b/examples/cache_demo/budgets.yml
@@ -0,0 +1,62 @@
+version: 1
+
+# Per-engine query limits (applied before executing individual queries)
+query_limits:
+  duckdb:
+    max_bytes: 5_000_000
+  postgres:
+    max_bytes: 10_000_000
+  bigquery:
+    max_bytes: 50_000_000
+  databricks_spark:
+    max_bytes: 50_000_000
+  snowflake_snowpark:
+    max_bytes: 50_000_000
+
+# Global limits across the entire fft run
+total:
+  bytes_scanned:
+    # 100 bytes – intentionally tiny so this demo can show a warning
+    warn: 100
+    # ~100 MB – adjust down if you want to force an error
+    error: 100_000_000
+
+  # Optional: total query time across all queries in the run
+  query_duration_ms:
+    warn: "30s"  # human-friendly duration, parsed to ms
+    error: "2m"
+
+# Per-model limits (keys must match node names: stg_users.ff, mart_user_orders.ff, http_users, ...)
+models: + stg_users.ff: + bytes_scanned: + # keep this fairly low so you can see a warn if you want + warn: 100 + error: 10_000_000 + + stg_orders.ff: + bytes_scanned: + warn: 1_000_000 + error: 10_000_000 + + mart_user_orders.ff: + bytes_scanned: + warn: 1_000_000 + error: 100_000_000 + + http_users: + # HTTP model → mainly interesting on engines that can report bytes_scanned + bytes_scanned: + warn: 5_000_000 + error: 50_000_000 + + py_constants: + bytes_scanned: + warn: 5_000_000 + error: 50_000_000 + +# Per-tag budgets (aggregated over all models with that tag) +tags: + "example:cache_demo": + bytes_scanned: + warn: 10_000_000 diff --git a/examples/cache_demo/models/marts/mart_user_orders.ff.sql b/examples/cache_demo/models/marts/mart_user_orders.ff.sql index f2bcce5..2d2e7d4 100644 --- a/examples/cache_demo/models/marts/mart_user_orders.ff.sql +++ b/examples/cache_demo/models/marts/mart_user_orders.ff.sql @@ -1,4 +1,13 @@ -{{ config(materialized='table', tags=['example:cache_demo','engine:duckdb','engine:postgres','engine:databricks_spark','engine:bigquery','engine:snowflake_snowpark']) }} +{{ config( + materialized='table', + tags=[ + 'example:cache_demo', + 'engine:duckdb', + 'engine:postgres', + 'engine:databricks_spark', + 'engine:bigquery', + 'engine:snowflake_snowpark' + ]) }} with u as ( select user_id, email from {{ ref('stg_users.ff') }} ), diff --git a/examples/cache_demo/models/seeds_consumers/stg_orders.ff.sql b/examples/cache_demo/models/seeds_consumers/stg_orders.ff.sql index ff87f94..f3b7437 100644 --- a/examples/cache_demo/models/seeds_consumers/stg_orders.ff.sql +++ b/examples/cache_demo/models/seeds_consumers/stg_orders.ff.sql @@ -1,4 +1,13 @@ -{{ config(materialized='view', tags=['example:cache_demo','engine:duckdb','engine:postgres','engine:databricks_spark','engine:bigquery','engine:snowflake_snowpark']) }} +{{ config( + materialized='view', + tags=[ + 'example:cache_demo', + 'engine:duckdb', + 'engine:postgres', + 
'engine:databricks_spark', + 'engine:bigquery', + 'engine:snowflake_snowpark' + ]) }} select cast(order_id as int) as order_id, cast(customer_id as int) as user_id, diff --git a/examples/cache_demo/profiles.yml b/examples/cache_demo/profiles.yml index a1d96ac..3aa3926 100644 --- a/examples/cache_demo/profiles.yml +++ b/examples/cache_demo/profiles.yml @@ -17,8 +17,11 @@ dev_databricks: warehouse_dir: "{{ project_dir() }}/{{ env('FF_DBR_WAREHOUSE_DIR', '.local/spark_warehouse') }}" database: "{{ env('FF_DBR_DATABASE', 'cache_demo') }}" extra_conf: + spark.sql.statistics.fallBackToHdfs: true spark.sql.shuffle.partitions: "{{ env('SPARK_SQL_SHUFFLE_PARTITIONS', '8') }}" - spark.driver.extraJavaOptions: "-Dderby.stream.error.file={{ project_dir() }}/.local/derby.log" + spark.driver.extraJavaOptions: > + -Dderby.stream.error.file={{ project_dir() }}/.local/derby.log + -Dderby.system.home={{ project_dir() }}/.local/derby_home dev_bigquery_bigframes: engine: bigquery diff --git a/examples/cache_demo/project.yml b/examples/cache_demo/project.yml index 77a0e83..7c81893 100644 --- a/examples/cache_demo/project.yml +++ b/examples/cache_demo/project.yml @@ -5,12 +5,7 @@ vars: {} models: {} -seeds: - storage: - seed_users: - path: ".local/spark/seed_users" - seed_orders: - path: ".local/spark/seed_orders" +seeds: {} tests: - type: row_count_between diff --git a/examples/ci_demo/profiles.yml b/examples/ci_demo/profiles.yml index c48eb3c..2994e97 100644 --- a/examples/ci_demo/profiles.yml +++ b/examples/ci_demo/profiles.yml @@ -23,7 +23,9 @@ dev_databricks: spark.hadoop.datanucleus.rdbms.datastoreAdapterClassName: "org.datanucleus.store.rdbms.adapter.DerbyAdapter" spark.hadoop.datanucleus.schema.autoCreateAll: "true" spark.hadoop.javax.jdo.option.ConnectionDriverName: "org.apache.derby.jdbc.EmbeddedDriver" - spark.driver.extraJavaOptions: "-Dderby.stream.error.file={{ project_dir() }}/.local/derby.log" + spark.driver.extraJavaOptions: > + -Dderby.stream.error.file={{ 
project_dir() }}/.local/derby.log + -Dderby.system.home={{ project_dir() }}/.local/derby_home dev_bigquery_bigframes: engine: bigquery diff --git a/examples/dq_demo/profiles.yml b/examples/dq_demo/profiles.yml index be54f0e..8e61f15 100644 --- a/examples/dq_demo/profiles.yml +++ b/examples/dq_demo/profiles.yml @@ -23,7 +23,9 @@ dev_databricks: spark.hadoop.datanucleus.rdbms.datastoreAdapterClassName: "org.datanucleus.store.rdbms.adapter.DerbyAdapter" spark.hadoop.datanucleus.schema.autoCreateAll: "true" spark.hadoop.javax.jdo.option.ConnectionDriverName: "org.apache.derby.jdbc.EmbeddedDriver" - spark.driver.extraJavaOptions: "-Dderby.stream.error.file={{ project_dir() }}/.local/derby.log" + spark.driver.extraJavaOptions: > + -Dderby.stream.error.file={{ project_dir() }}/.local/derby.log + -Dderby.system.home={{ project_dir() }}/.local/derby_home dev_bigquery_bigframes: engine: bigquery diff --git a/examples/dq_demo/project.yml b/examples/dq_demo/project.yml index 63cef31..49b6f85 100644 --- a/examples/dq_demo/project.yml +++ b/examples/dq_demo/project.yml @@ -5,14 +5,7 @@ vars: {} models: {} -seeds: - storage: - seed_customers: - path: ".local/spark/seed_customers" - format: parquet - seed_orders: - path: ".local/spark/seed_orders" - format: parquet +seeds: {} tests: # --- Single-table checks ---------------------------------------------------- diff --git a/examples/incremental_demo/.env.dev_bigquery_bigframes b/examples/incremental_demo/.env.dev_bigquery_bigframes new file mode 100644 index 0000000..5264dc7 --- /dev/null +++ b/examples/incremental_demo/.env.dev_bigquery_bigframes @@ -0,0 +1,7 @@ +# BigQuery profile for the basic demo +FF_BQ_PROJECT=fft-basic-demo +FF_BQ_DATASET=incremental_demo +FF_BQ_LOCATION=EU + +# Path to service account JSON key (or rely on gcloud / workload identity) +GOOGLE_APPLICATION_CREDENTIALS=../secrets/fft-bigquery-demo-key.json diff --git a/examples/incremental_demo/.env.dev_bigquery_pandas 
b/examples/incremental_demo/.env.dev_bigquery_pandas new file mode 100644 index 0000000..5264dc7 --- /dev/null +++ b/examples/incremental_demo/.env.dev_bigquery_pandas @@ -0,0 +1,7 @@ +# BigQuery profile for the basic demo +FF_BQ_PROJECT=fft-basic-demo +FF_BQ_DATASET=incremental_demo +FF_BQ_LOCATION=EU + +# Path to service account JSON key (or rely on gcloud / workload identity) +GOOGLE_APPLICATION_CREDENTIALS=../secrets/fft-bigquery-demo-key.json diff --git a/examples/incremental_demo/profiles.yml b/examples/incremental_demo/profiles.yml index cf6e309..50bff28 100644 --- a/examples/incremental_demo/profiles.yml +++ b/examples/incremental_demo/profiles.yml @@ -21,7 +21,9 @@ dev_databricks_parquet: &incremental_databricks_parquet spark.hadoop.datanucleus.rdbms.datastoreAdapterClassName: "org.datanucleus.store.rdbms.adapter.DerbyAdapter" spark.hadoop.datanucleus.schema.autoCreateAll: "true" spark.hadoop.javax.jdo.option.ConnectionDriverName: "org.apache.derby.jdbc.EmbeddedDriver" - spark.driver.extraJavaOptions: "-Dderby.stream.error.file={{ project_dir() }}/.local/derby.log" + spark.driver.extraJavaOptions: > + -Dderby.stream.error.file={{ project_dir() }}/.local/derby.log + -Dderby.system.home={{ project_dir() }}/.local/derby_home dev_databricks_delta: engine: databricks_spark diff --git a/examples/macros_demo/profiles.yml b/examples/macros_demo/profiles.yml index 4c35722..495c075 100644 --- a/examples/macros_demo/profiles.yml +++ b/examples/macros_demo/profiles.yml @@ -20,7 +20,9 @@ dev_databricks: spark.hadoop.datanucleus.rdbms.datastoreAdapterClassName: "org.datanucleus.store.rdbms.adapter.DerbyAdapter" spark.hadoop.datanucleus.schema.autoCreateAll: "true" spark.hadoop.javax.jdo.option.ConnectionDriverName: "org.apache.derby.jdbc.EmbeddedDriver" - spark.driver.extraJavaOptions: "-Dderby.stream.error.file={{ project_dir() }}/.local/derby.log" + spark.driver.extraJavaOptions: > + -Dderby.stream.error.file={{ project_dir() }}/.local/derby.log + 
-Dderby.system.home={{ project_dir() }}/.local/derby_home dev_bigquery_bigframes: engine: bigquery diff --git a/examples/macros_demo/project.yml b/examples/macros_demo/project.yml index 3d0a371..bf210b5 100644 --- a/examples/macros_demo/project.yml +++ b/examples/macros_demo/project.yml @@ -7,14 +7,7 @@ vars: models: {} -seeds: - storage: - seed_users: - path: ".local/spark/seed_users" - format: parquet - seed_orders: - path: ".local/spark/seed_orders" - format: parquet +seeds: {} tests: - type: not_null diff --git a/examples/materializations_demo/profiles.yml b/examples/materializations_demo/profiles.yml index 5d4e79f..8e51089 100644 --- a/examples/materializations_demo/profiles.yml +++ b/examples/materializations_demo/profiles.yml @@ -22,7 +22,9 @@ dev_databricks: spark.hadoop.datanucleus.rdbms.datastoreAdapterClassName: "org.datanucleus.store.rdbms.adapter.DerbyAdapter" spark.hadoop.datanucleus.schema.autoCreateAll: "true" spark.hadoop.javax.jdo.option.ConnectionDriverName: "org.apache.derby.jdbc.EmbeddedDriver" - spark.driver.extraJavaOptions: "-Dderby.stream.error.file={{ project_dir() }}/.local/derby.log" + spark.driver.extraJavaOptions: > + -Dderby.stream.error.file={{ project_dir() }}/.local/derby.log + -Dderby.system.home={{ project_dir() }}/.local/derby_home dev_bigquery_bigframes: engine: bigquery diff --git a/examples/materializations_demo/project.yml b/examples/materializations_demo/project.yml index 7acb2dd..836d117 100644 --- a/examples/materializations_demo/project.yml +++ b/examples/materializations_demo/project.yml @@ -5,14 +5,7 @@ vars: {} models: {} -seeds: - storage: - seed_customers: - path: ".local/spark/seed_customers" - format: parquet - seed_orders: - path: ".local/spark/seed_orders" - format: parquet +seeds: {} tests: # Simple sanity tests; add more if you like diff --git a/examples/snapshot_demo/Makefile b/examples/snapshot_demo/Makefile index 0ee8b27..3c98f26 100644 --- a/examples/snapshot_demo/Makefile +++ 
b/examples/snapshot_demo/Makefile @@ -76,7 +76,7 @@ else ifeq ($(ENGINE),bigquery) else ifeq ($(ENGINE),snowflake_snowpark) CLEAN_CMD = env $(BASE_ENV) $(UV) run python $(CLEAN_SCRIPT) --engine snowflake_snowpark --env "$(PROFILE_ENV)" --project "$(PROJECT)" else - $(error Unsupported ENGINE=$(ENGINE) - pick duckdb|postgres|databricks_spark|bigquery) + $(error Unsupported ENGINE=$(ENGINE) - pick duckdb|postgres|databricks_spark|bigquery|snowflake_snowpark) endif # --- Targets ---------------------------------------------------------------- diff --git a/examples/snapshot_demo/profiles.yml b/examples/snapshot_demo/profiles.yml index 5fbca70..56fe2e1 100644 --- a/examples/snapshot_demo/profiles.yml +++ b/examples/snapshot_demo/profiles.yml @@ -23,7 +23,9 @@ dev_databricks_parquet: &snapshot_demo_databricks_parquet spark.hadoop.datanucleus.rdbms.datastoreAdapterClassName: "org.datanucleus.store.rdbms.adapter.DerbyAdapter" spark.hadoop.datanucleus.schema.autoCreateAll: "true" spark.hadoop.javax.jdo.option.ConnectionDriverName: "org.apache.derby.jdbc.EmbeddedDriver" - spark.driver.extraJavaOptions: "-Dderby.stream.error.file={{ project_dir() }}/.local/derby.log" + spark.driver.extraJavaOptions: > + -Dderby.stream.error.file={{ project_dir() }}/.local/derby.log + -Dderby.system.home={{ project_dir() }}/.local/derby_home dev_databricks_delta: engine: databricks_spark diff --git a/examples_article/building_locally_demo/README.md b/examples_article/building_locally_demo/README.md new file mode 100644 index 0000000..5e977f7 --- /dev/null +++ b/examples_article/building_locally_demo/README.md @@ -0,0 +1,7 @@ +# FastFlowTransform project scaffold + +This project was created with `fft init`. +Next steps: +1. Update `profiles.yml` with real connection details (docs/Profiles.md). +2. Add sources in `sources.yml` and author models under `models/` (docs/Config_and_Macros.md). +3. Seed sample data with `fft seed` and execute models with `fft run` (docs/Quickstart.md). 
diff --git a/examples_article/building_locally_demo/docs/README.md b/examples_article/building_locally_demo/docs/README.md new file mode 100644 index 0000000..69e73e7 --- /dev/null +++ b/examples_article/building_locally_demo/docs/README.md @@ -0,0 +1,4 @@ +# Project documentation + +Write operator or contributor notes here and keep them in sync with generated docs. +See docs/Technical_Overview.md#auto-docs-and-lineage for `fft dag` / `fft docgen` guidance. diff --git a/examples_article/building_locally_demo/models/README.md b/examples_article/building_locally_demo/models/README.md new file mode 100644 index 0000000..32818bb --- /dev/null +++ b/examples_article/building_locally_demo/models/README.md @@ -0,0 +1,4 @@ +# Models directory + +Place SQL (`*.ff.sql`) and Python (`*.ff.py`) models here. +See docs/Config_and_Macros.md for modeling guidance and config options. diff --git a/examples_article/building_locally_demo/models/staging/stg_users.ff.sql b/examples_article/building_locally_demo/models/staging/stg_users.ff.sql new file mode 100644 index 0000000..412d37d --- /dev/null +++ b/examples_article/building_locally_demo/models/staging/stg_users.ff.sql @@ -0,0 +1,9 @@ +{{ config(materialized='table') }} + +select + id, + -- Standardizing email casing + lower(email) as email, + -- Casting types explicitly + cast(signup_date as date) as signup_date +from {{ source('raw', 'users') }} \ No newline at end of file diff --git a/examples_article/building_locally_demo/profiles.yml b/examples_article/building_locally_demo/profiles.yml new file mode 100644 index 0000000..e0ef1f0 --- /dev/null +++ b/examples_article/building_locally_demo/profiles.yml @@ -0,0 +1,14 @@ +# My Local Playground +dev_duckdb: + engine: duckdb + duckdb: + path: ".local/dev.duckdb" + +# My Production Environment +prod_bigquery: + engine: bigquery + bigquery: + project: "{{ env('FF_BQ_PROJECT') }}" + dataset: "production_marts" + # FFT handles the client (BigFrames or Pandas) automatically + 
use_bigframes: true \ No newline at end of file diff --git a/examples_article/building_locally_demo/project.yml b/examples_article/building_locally_demo/project.yml new file mode 100644 index 0000000..c16dbca --- /dev/null +++ b/examples_article/building_locally_demo/project.yml @@ -0,0 +1,18 @@ +# Project configuration generated by `fft init`. +# Read docs/Project_Config.md for the complete reference. +name: building_locally_demo +version: "0.1" +models_dir: models + +docs: + # Adjust `dag_dir` to change where `fft dag --html` writes documentation (docs/Technical_Overview.md#auto-docs-and-lineage). + dag_dir: site/dag + +# Project-level variables accessible via {{ var('key') }} inside models. +# Example: +# vars: +# run_date: "2024-01-01" +vars: {} + +# Declare project-wide data quality checks under `tests`. See docs/Data_Quality_Tests.md. +tests: [] diff --git a/examples_article/building_locally_demo/seeds/README.md b/examples_article/building_locally_demo/seeds/README.md new file mode 100644 index 0000000..2e553ed --- /dev/null +++ b/examples_article/building_locally_demo/seeds/README.md @@ -0,0 +1,4 @@ +# Seeds directory + +Add CSV or Parquet files for reproducible seeds. +Usage examples are covered in docs/Quickstart.md and docs/Config_and_Macros.md#13-seeds-sources-and-dependencies. diff --git a/examples_article/building_locally_demo/seeds/users.csv b/examples_article/building_locally_demo/seeds/users.csv new file mode 100644 index 0000000..1b417ba --- /dev/null +++ b/examples_article/building_locally_demo/seeds/users.csv @@ -0,0 +1,3 @@ +id,email,signup_date +1,alice@example.com,2023-01-01 +2,bob@example.com,2023-01-02 \ No newline at end of file diff --git a/examples_article/building_locally_demo/sources.yml b/examples_article/building_locally_demo/sources.yml new file mode 100644 index 0000000..83436dc --- /dev/null +++ b/examples_article/building_locally_demo/sources.yml @@ -0,0 +1,9 @@ +# Source declarations describe external tables. 
See docs/Sources.md for details. +version: 2 +# sources: + # Example: + # - name: raw + # schema: staging + # tables: + # - name: users + # identifier: seed_users diff --git a/examples_article/building_locally_demo/tests/dq/README.md b/examples_article/building_locally_demo/tests/dq/README.md new file mode 100644 index 0000000..1acd01d --- /dev/null +++ b/examples_article/building_locally_demo/tests/dq/README.md @@ -0,0 +1,4 @@ +# Data quality tests + +Store custom data-quality tests that run via `fft test` (docs/Data_Quality_Tests.md). +Use this directory for schema-bound tests separate from unit specs. diff --git a/examples_article/building_locally_demo/tests/unit/README.md b/examples_article/building_locally_demo/tests/unit/README.md new file mode 100644 index 0000000..b3c3c8d --- /dev/null +++ b/examples_article/building_locally_demo/tests/unit/README.md @@ -0,0 +1,4 @@ +# Unit tests + +Define YAML unit specs as described in docs/Config_and_Macros.md#73-model-unit-tests-fft-utest. +Invoke them with `fft utest --env <env>`. 
diff --git a/mkdocs.yml b/mkdocs.yml index a39db2b..c3cd004 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -39,6 +39,7 @@ nav: - Unit Tests: Unit_Tests.md - Snapshots: Snapshots.md - CI Checks & Change-Aware Runs: CI_Check.md + - Cost Monitoring: Cost_Monitoring.md - Troubleshooting: Troubleshooting.md - Examples: - Basic Demo: examples/Basic_Demo.md diff --git a/pyproject.toml b/pyproject.toml index 54337f2..1df700f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fastflowtransform" -version = "0.6.7" +version = "0.6.8" description = "Python framework for SQL & Python data transformation, ETL pipelines, and dbt-style data modeling" readme = "README.md" license = { text = "Apache-2.0" } diff --git a/pytest.ini b/pytest.ini index d3d73bc..91121cb 100644 --- a/pytest.ini +++ b/pytest.ini @@ -12,3 +12,4 @@ markers = integration: integration tests example: run the examples as tests flaky: tests that may be rerun on failure + remote_engine: generic marker for all remote/costly engines diff --git a/src/fastflowtransform/artifacts.py b/src/fastflowtransform/artifacts.py index 11381e6..f2bf6c7 100644 --- a/src/fastflowtransform/artifacts.py +++ b/src/fastflowtransform/artifacts.py @@ -123,6 +123,11 @@ class RunNodeResult: message: str | None = None http: dict | None = None + # per-node query stats (aggregated across all SQL queries) + bytes_scanned: int | None = None + rows: int | None = None + query_duration_ms: int | None = None + def write_run_results( project_dir: Path, @@ -130,9 +135,11 @@ def write_run_results( started_at: str, finished_at: str, node_results: list[RunNodeResult], + budgets: dict[str, Any] | None = None, ) -> Path: """ Write run_results.json containing run envelope and per-node results. + Optionally includes a 'budgets' summary block. 
""" project_dir = Path(project_dir) out_dir = _target_dir(project_dir) @@ -144,6 +151,10 @@ def write_run_results( "run_finished_at": finished_at, "results": [asdict(nr) for nr in sorted(node_results, key=lambda r: r.name)], } + + if budgets is not None: + data["budgets"] = budgets + _json_dump(results_path, data) return results_path @@ -302,3 +313,38 @@ def write_catalog(project_dir: Path, executor: Any) -> Path: } _json_dump(catalog_path, data) return catalog_path + + +# ---------- READ DURATIONS ---------- + + +def load_last_run_durations(project_dir: Path) -> dict[str, float]: + """ + Best-effort reader for the last run_results.json. + + Returns: { model_name: duration_in_seconds }. + On any error or missing file: {}. + """ + path = _target_dir(project_dir) + if not path.exists(): + return {} + + try: + raw = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return {} + + # tolerate a few possible shapes + items: list[dict[str, Any]] = ( + raw.get("results") or raw.get("node_results") or raw.get("nodes") or [] + ) + + out: dict[str, float] = {} + for item in items: + if not isinstance(item, dict): + continue + name = item.get("name") + dur_ms = item.get("duration_ms") + if isinstance(name, str) and isinstance(dur_ms, (int, float)): + out[name] = float(dur_ms) / 1000.0 + return out diff --git a/src/fastflowtransform/cli/bootstrap.py b/src/fastflowtransform/cli/bootstrap.py index e303a12..4fc4140 100644 --- a/src/fastflowtransform/cli/bootstrap.py +++ b/src/fastflowtransform/cli/bootstrap.py @@ -13,6 +13,7 @@ from dotenv import dotenv_values from jinja2 import Environment +from fastflowtransform.config.budgets import BudgetsConfig, load_budgets_config from fastflowtransform.core import REGISTRY from fastflowtransform.errors import DependencyNotFoundError from fastflowtransform.executors._shims import BigQueryConnShim, SAConnShim @@ -32,9 +33,23 @@ class CLIContext: jinja_env: Environment env_settings: EnvSettings profile: Profile + budgets_cfg: 
BudgetsConfig | None = None def make_executor(self) -> tuple[Any, Callable, Callable]: - return _make_executor(self.profile, self.jinja_env) + executor, run_sql, run_py = _make_executor(self.profile, self.jinja_env) + self._configure_budget_limit(executor) + return executor, run_sql, run_py + + def _configure_budget_limit(self, executor: Any) -> None: + if executor is None or not hasattr(executor, "configure_query_budget_limit"): + return + limit = None + engine_name = (self.profile.engine or "").lower() + if self.budgets_cfg and engine_name: + entry = self.budgets_cfg.query_limits.get(engine_name) + if entry: + limit = entry.max_bytes + executor.configure_query_budget_limit(limit) def _resolve_project_path(project_arg: str) -> Path: @@ -245,7 +260,19 @@ def _prepare_context( proj = Path(proj_raw) REGISTRY.set_cli_vars(_parse_cli_vars(vars_opt or [])) - return CLIContext(project=proj, jinja_env=jenv, env_settings=env_settings, profile=prof) + + try: + budgets_cfg = load_budgets_config(proj) + except Exception as exc: + raise typer.BadParameter(f"Failed to parse budgets.yml: {exc}") from exc + + return CLIContext( + project=proj, + jinja_env=jenv, + env_settings=env_settings, + profile=prof, + budgets_cfg=budgets_cfg, + ) def _parse_cli_vars(pairs: list[str]) -> dict[str, object]: diff --git a/src/fastflowtransform/cli/options.py b/src/fastflowtransform/cli/options.py index 08d3613..4f09260 100644 --- a/src/fastflowtransform/cli/options.py +++ b/src/fastflowtransform/cli/options.py @@ -3,8 +3,9 @@ from enum import Enum from pathlib import Path -from typing import Annotated +from typing import Annotated, Literal +import click import typer from fastflowtransform.settings import EngineType @@ -52,13 +53,38 @@ typer.Option("--html", help="Generate HTML DAG and mini documentation"), ] +type Jobs = int | Literal["auto"] + + +def _jobs_callback( + ctx: click.Context, + param: click.Parameter, + value: str | None, +) -> Jobs | None: + if value is None: + return None + + if 
value == "auto": + return "auto" + + try: + n = int(value) + except ValueError as e: + raise typer.BadParameter("`--jobs` must be an integer ≥1 or 'auto'.") from e + + if n < 1: + raise typer.BadParameter("`--jobs` must be ≥1 or 'auto'.") + + return n + + JobsOpt = Annotated[ - int, + str, typer.Option( "--jobs", - help="Max parallel executions per level (≥1).", - min=1, + help="Max parallel executions per level (≥1) or 'auto'.", show_default=True, + callback=_jobs_callback, ), ] diff --git a/src/fastflowtransform/cli/run.py b/src/fastflowtransform/cli/run.py index debd7a5..55fe3cd 100644 --- a/src/fastflowtransform/cli/run.py +++ b/src/fastflowtransform/cli/run.py @@ -10,12 +10,13 @@ from dataclasses import dataclass, field from datetime import UTC, datetime from pathlib import Path -from typing import Any +from typing import Any, cast import typer from fastflowtransform.artifacts import ( RunNodeResult, + load_last_run_durations, write_catalog, write_manifest, write_run_results, @@ -50,8 +51,14 @@ _selected_subgraph_names, augment_with_state_modified, ) +from fastflowtransform.config.budgets import ( + BudgetLimit, + BudgetsConfig, + load_budgets_config, +) from fastflowtransform.core import REGISTRY, relation_for from fastflowtransform.dag import levels as dag_levels +from fastflowtransform.executors.budget import format_bytes from fastflowtransform.fingerprint import ( EnvCtx, build_env_ctx, @@ -60,9 +67,10 @@ get_function_source, ) from fastflowtransform.log_queue import LogQueue -from fastflowtransform.logging import bind_context, bound_context, clear_context, echo, warn +from fastflowtransform.logging import bind_context, bound_context, clear_context, echo, error, warn from fastflowtransform.meta import ensure_meta_table from fastflowtransform.run_executor import ScheduleResult, schedule +from fastflowtransform.utils.timefmt import _format_duration_ms @dataclass @@ -79,10 +87,16 @@ class _RunEngine: computed_fps: dict[str, str] = field(default_factory=dict, 
init=False) fps_lock: threading.Lock = field(default_factory=threading.Lock, init=False) http_snaps: dict[str, dict] = field(default_factory=dict, init=False) + budgets_cfg: BudgetsConfig | None = None + + # per-node query stats (aggregated across all queries in that node) + query_stats: dict[str, dict[str, int]] = field(default_factory=dict, init=False) + stats_lock: threading.Lock = field(default_factory=threading.Lock, init=False) def __post_init__(self) -> None: echo(f"Profile: {self.env_name} | Engine: {self.ctx.profile.engine}") self.shared = self.ctx.make_executor() + self._configure_budget_limit(self.shared[0]) with suppress(Exception): ensure_meta_table(self.shared[0]) relevant_env = [k for k in os.environ if k.startswith("FF_")] @@ -107,9 +121,14 @@ def _get_runner(self) -> tuple[Any, Callable, Callable]: clone_needed = not (isinstance(db_path, str) and db_path.strip() == ":memory:") if clone_needed: ex = ex.clone() - run_py_wrapped = ex.run_python + + def run_sql_wrapped(node, _env=self.ctx.jinja_env, _ex=ex): + return _ex.run_sql(node, _env) + + run_py_wrapped = ex.run_python except Exception: pass + self._configure_budget_limit(ex) self.tls.runner = (ex, run_sql_wrapped, run_py_wrapped) return self.tls.runner @@ -173,6 +192,17 @@ def _qualified_target(self, name: str) -> str | None: return None return f"{namespace}.{rel}" + def _configure_budget_limit(self, executor: Any) -> None: + if executor is None or not hasattr(executor, "configure_query_budget_limit"): + return + engine_name = (self.ctx.profile.engine or "").lower() + limit = None + if self.budgets_cfg and engine_name: + entry = self.budgets_cfg.query_limits.get(engine_name) + if entry: + limit = entry.max_bytes + executor.configure_query_budget_limit(limit) + def format_run_label(self, name: str) -> str: """ Build the human-facing label for run logs, e.g.: @@ -230,6 +260,12 @@ def run_node(self, name: str) -> None: node = REGISTRY.nodes[name] ex, run_sql_fn, run_py_fn = self._get_runner() + # 
Reset per-node stats if the executor supports it + with suppress(Exception): + reset = getattr(ex, "reset_node_stats", None) + if callable(reset): + reset() + if name in self.force_rebuild: (run_sql_fn if node.kind == "sql" else run_py_fn)(node) cand_fp = self._maybe_fingerprint(node, ex) @@ -243,6 +279,17 @@ def run_node(self, name: str) -> None: snap = (getattr(node, "meta", {}) or {}).get("_http_snapshot") if snap: self.http_snaps[name] = snap + + # capture per-node stats after successful run + with suppress(Exception): + raw_getter = getattr(ex, "get_node_stats", None) + if callable(raw_getter): + getter = cast(Callable[[], dict[str, int] | None], raw_getter) + stats = getter() + if stats: + with self.stats_lock: + self.query_stats[name] = stats + return cand_fp = self._maybe_fingerprint(node, ex) @@ -258,6 +305,14 @@ def run_node(self, name: str) -> None: ): with self.fps_lock: self.computed_fps[name] = cand_fp + # Even when skipped, we treat stats as zero → nothing to do + with self.stats_lock: + self.query_stats[name] = { + "bytes_scanned": 0, + "rows": 0, + "query_duration_ms": 0, + "cached": True, + } return (run_sql_fn if node.kind == "sql" else run_py_fn)(node) @@ -273,6 +328,16 @@ def run_node(self, name: str) -> None: if snap: self.http_snaps[name] = snap + # capture per-node stats after successful run + with suppress(Exception): + raw_getter = getattr(ex, "get_node_stats", None) + if callable(raw_getter): + getter = cast(Callable[[], dict[str, int] | None], raw_getter) + stats = getter() + if stats: + with self.stats_lock: + self.query_stats[name] = stats + @staticmethod def before(_name: str, lvl_idx: int | None = None) -> None: return @@ -370,6 +435,301 @@ def _is_snapshot_model(node: Any) -> bool: return mat == "snapshot" +def _check_metric_limits( + *, + scope: str, + metric_name: str, + value: int, + limits: BudgetLimit, +) -> tuple[bool, bool]: + """ + Check a single metric against warn/error thresholds. 
+ + Returns (warn_triggered, error_triggered). + """ + if value <= 0: + return False, False + + # Decide how to render the metric & thresholds + if metric_name == "bytes_scanned": + value_str = format_bytes(value) + warn_str = format_bytes(limits.warn) if limits.warn else None + err_str = format_bytes(limits.error) if limits.error else None + unit_label = "bytes_scanned" + elif metric_name == "rows": + value_str = f"{value:,} rows" + warn_str = f"{limits.warn:,} rows" if limits.warn else None + err_str = f"{limits.error:,} rows" if limits.error else None + unit_label = "rows" + else: # "query_duration_ms" + value_str = _format_duration_ms(value) + warn_str = _format_duration_ms(limits.warn) if limits.warn else None + err_str = _format_duration_ms(limits.error) if limits.error else None + unit_label = "query_duration_ms" + + # Prefer error over warn (avoid double-logging) + if limits.error and value > limits.error: + error( + f"[BUDGET] {scope}: {unit_label} {value_str} exceeds " + f"error limit {err_str} (budgets.yml)." + ) + return False, True + + if limits.warn and value > limits.warn: + warn( + f"[BUDGET] {scope}: {unit_label} {value_str} exceeds " + f"warn limit {warn_str} (budgets.yml)." 
+ ) + return True, False + + return False, False + + +def _value_and_limits_str( + metric_name: str, + value: int, + limits: BudgetLimit, +) -> tuple[str, str | None, str | None]: + if metric_name == "bytes_scanned": + v_str = format_bytes(value) + w_str = format_bytes(limits.warn) if limits.warn else None + e_str = format_bytes(limits.error) if limits.error else None + elif metric_name == "rows": + v_str = f"{value:,} rows" + w_str = f"{limits.warn:,} rows" if limits.warn else None + e_str = f"{limits.error:,} rows" if limits.error else None + else: # query_duration_ms + v_str = _format_duration_ms(value) + w_str = _format_duration_ms(limits.warn) if limits.warn else None + e_str = _format_duration_ms(limits.error) if limits.error else None + return v_str, w_str, e_str + + +def _eval_metric(scope: str, metric_name: str, value: int, limits: BudgetLimit) -> str: + """ + Evaluate {warn,error} thresholds for a single metric. + + Returns status: "ok" | "warn" | "error". + Emits log lines for warn/error. 
+ """ + if value <= 0: + return "ok" + if not limits.warn and not limits.error: + return "ok" + + v_str, w_str, e_str = _value_and_limits_str(metric_name, value, limits) + unit_label = metric_name + + # Prefer error over warn + if limits.error and value > limits.error: + error(f"[BUDGET] {scope}: {unit_label} {v_str} exceeds error limit {e_str} (budgets.yml).") + return "error" + + if limits.warn and value > limits.warn: + warn(f"[BUDGET] {scope}: {unit_label} {v_str} exceeds warn limit {w_str} (budgets.yml).") + return "warn" + + return "ok" + + +def _aggregate_totals(stats_by_model: dict[str, dict[str, Any]]) -> dict[str, int]: + totals = {"bytes_scanned": 0, "rows": 0, "query_duration_ms": 0} + for s in stats_by_model.values(): + totals["bytes_scanned"] += int(s.get("bytes_scanned", 0) or 0) + totals["rows"] += int(s.get("rows", 0) or 0) + totals["query_duration_ms"] += int(s.get("query_duration_ms", 0) or 0) + return totals + + +def _evaluate_total_budgets( + cfg: Any, + totals: dict[str, int], + budgets_summary: dict[str, Any], +) -> bool: + had_error = False + if not cfg.total: + return had_error + + for metric_name in ("bytes_scanned", "rows", "query_duration_ms"): + limits: BudgetLimit | None = getattr(cfg.total, metric_name) + if not limits: + continue + value = totals.get(metric_name, 0) + status = _eval_metric("total (all models)", metric_name, value, limits) + if status != "ok" or (limits.warn or limits.error): + budgets_summary["total"][metric_name] = { + "value": value, + "warn": limits.warn, + "error": limits.error, + "status": status, + } + if status == "error": + had_error = True + + return had_error + + +def _evaluate_model_budgets( + cfg: Any, + stats_by_model: dict[str, dict[str, Any]], + budgets_summary: dict[str, Any], +) -> bool: + had_error = False + + for model_name, metrics in (cfg.models or {}).items(): + s = stats_by_model.get(model_name) + if not s: + continue + + model_summary: dict[str, Any] = {} + for metric_name in ("bytes_scanned", 
"rows", "query_duration_ms"): + limits: BudgetLimit | None = getattr(metrics, metric_name) + if not limits: + continue + value = int(s.get(metric_name, 0) or 0) + status = _eval_metric(f"model '{model_name}'", metric_name, value, limits) + if status != "ok" or (limits.warn or limits.error): + model_summary[metric_name] = { + "value": value, + "warn": limits.warn, + "error": limits.error, + "status": status, + } + if status == "error": + had_error = True + + if model_summary: + budgets_summary["models"][model_name] = model_summary + + return had_error + + +def _aggregate_tag_totals( + cfg: Any, + stats_by_model: dict[str, dict[str, Any]], +) -> dict[str, dict[str, int]]: + tag_totals: dict[str, dict[str, int]] = { + tag: {"bytes_scanned": 0, "rows": 0, "query_duration_ms": 0} for tag in (cfg.tags or {}) + } + + for model_name, s in stats_by_model.items(): + try: + node = REGISTRY.get_node(model_name) + except KeyError: + continue + + meta = getattr(node, "meta", {}) or {} + tags = meta.get("tags") or [] + tags_str = [str(t) for t in tags] + + for tag in tags_str: + if tag not in tag_totals: + continue + aggr = tag_totals[tag] + aggr["bytes_scanned"] += int(s.get("bytes_scanned", 0) or 0) + aggr["rows"] += int(s.get("rows", 0) or 0) + aggr["query_duration_ms"] += int(s.get("query_duration_ms", 0) or 0) + + return tag_totals + + +def _evaluate_tag_budgets( + cfg: Any, + tag_totals: dict[str, dict[str, int]], + budgets_summary: dict[str, Any], +) -> bool: + had_error = False + if not cfg.tags: + return had_error + + for tag, metrics in (cfg.tags or {}).items(): + aggr = tag_totals.get(tag) or {} + tag_summary: dict[str, Any] = {} + for metric_name in ("bytes_scanned", "rows", "query_duration_ms"): + limits: BudgetLimit | None = getattr(metrics, metric_name) + if not limits: + continue + value = int(aggr.get(metric_name, 0) or 0) + status = _eval_metric( + f"tag '{tag}' (all models with this tag)", metric_name, value, limits + ) + if status != "ok" or (limits.warn or 
limits.error): + tag_summary[metric_name] = { + "value": value, + "warn": limits.warn, + "error": limits.error, + "status": status, + } + if status == "error": + had_error = True + + if tag_summary: + budgets_summary["tags"][tag] = tag_summary + + return had_error + + +def _resolve_budgets_cfg(project_dir: Path, engine_: _RunEngine) -> Any | None: + cfg = engine_.budgets_cfg + if cfg is not None: + return cfg + + try: + cfg = load_budgets_config(project_dir) + except Exception as exc: # pragma: no cover - CLI error path + # Parsing error is considered fatal: surface a clear message and fail the run. + error(f"Failed to parse budgets.yml: {exc}") + raise typer.Exit(1) from exc + + engine_.budgets_cfg = cfg + return cfg + + +def _evaluate_budgets( + project_dir: Path, + engine_: _RunEngine, +) -> tuple[bool, dict[str, Any] | None]: + """ + Enforce budgets.yml against collected query_stats. + + Returns: + (had_error_budget: bool, budgets_summary: dict | None) + """ + cfg = _resolve_budgets_cfg(project_dir, engine_) + + if cfg is None: + # No budgets.yml → nothing to enforce + return False, None + + stats_by_model: dict[str, dict[str, Any]] = getattr(engine_, "query_stats", {}) or {} + if not stats_by_model: + # No stats collected (e.g. 
purely Python models) → nothing to enforce + return False, None + + totals = _aggregate_totals(stats_by_model) + + # Summary structure for run_results.json + budgets_summary: dict[str, Any] = { + "total": {}, + "models": {}, + "tags": {}, + } + + had_error = False + had_error |= _evaluate_total_budgets(cfg, totals, budgets_summary) + had_error |= _evaluate_model_budgets(cfg, stats_by_model, budgets_summary) + + if cfg.tags: + tag_totals = _aggregate_tag_totals(cfg, stats_by_model) + had_error |= _evaluate_tag_budgets(cfg, tag_totals, budgets_summary) + + # If budgets_summary is entirely empty, return None for clarity + if not any(budgets_summary[section] for section in ("total", "models", "tags")): + return had_error, None + + return had_error, budgets_summary + + # ----------------- helpers (run function) ----------------- @@ -377,7 +737,12 @@ def _build_engine_ctx(project, env_name, engine, vars, cache, no_cache): ctx = _prepare_context(project, env_name, engine, vars) cache_mode = CacheMode.OFF if no_cache else cache engine_ = _RunEngine( - ctx=ctx, pred=None, env_name=env_name, cache_mode=cache_mode, force_rebuild=set() + ctx=ctx, + pred=None, + env_name=env_name, + cache_mode=cache_mode, + force_rebuild=set(), + budgets_cfg=ctx.budgets_cfg, ) return ctx, engine_ @@ -504,6 +869,12 @@ def _run_schedule(engine_, lvls, jobs, keep_going, ctx): bind_context(run_id=started_at) + # Best-effort: use previous run timings to batch small models per worker. 
+ try: + prev_durations_s = load_last_run_durations(ctx.project) + except Exception: + prev_durations_s = {} + def _run_node_with_ctx(name: str) -> None: with bound_context(node=name): engine_.run_node(name) @@ -519,6 +890,7 @@ def _run_node_with_ctx(name: str) -> None: engine_abbr=_abbr(ctx.profile.engine), name_width=100, name_formatter=engine_.format_run_label, + durations_s=prev_durations_s, ) finished_at = datetime.now(UTC).isoformat(timespec="seconds") @@ -526,7 +898,12 @@ def _run_node_with_ctx(name: str) -> None: def _write_artifacts( - ctx: CLIContext, result: ScheduleResult, started_at: str, finished_at: str, engine_: _RunEngine + ctx: CLIContext, + result: ScheduleResult, + started_at: str, + finished_at: str, + engine_: _RunEngine, + budgets: dict[str, Any] | None, ) -> None: write_manifest(ctx.project) @@ -534,10 +911,19 @@ def _write_artifacts( failed = result.failed or {} all_names = set(result.per_node_s.keys()) | set(failed.keys()) + # stats accumulated by the executor per node (if available) + per_node_stats = getattr(engine_, "query_stats", {}) or {} + for name in sorted(all_names): dur_s = float(result.per_node_s.get(name, 0.0)) status = "error" if name in failed else "success" msg = str(failed.get(name)) if name in failed else None + + stats = per_node_stats.get(name) or {} + bytes_scanned = int(stats.get("bytes_scanned", 0)) + rows = int(stats.get("rows", 0)) + q_ms = int(stats.get("query_duration_ms", 0)) + node_results.append( RunNodeResult( name=name, @@ -547,11 +933,18 @@ def _write_artifacts( duration_ms=int(dur_s * 1000), message=msg, http=engine_.http_snaps.get(name), + bytes_scanned=bytes_scanned, + rows=rows, + query_duration_ms=q_ms, ) ) write_run_results( - ctx.project, started_at=started_at, finished_at=finished_at, node_results=node_results + ctx.project, + started_at=started_at, + finished_at=finished_at, + node_results=node_results, + budgets=budgets, ) @@ -581,7 +974,7 @@ def run( vars: VarsOpt = None, select: SelectOpt = None, 
exclude: ExcludeOpt = None, - jobs: JobsOpt = 1, + jobs: JobsOpt = "1", keep_going: KeepOpt = False, cache: CacheOpt = CacheMode.RW, no_cache: NoCacheOpt = False, @@ -591,7 +984,6 @@ def run( http_cache: HttpCacheOpt = None, changed_since: ChangedSinceOpt = None, ) -> None: - # _ensure_logging() # HTTP/API-Flags → ENV, damit fastflowtransform.api.http sie liest if offline: os.environ["FF_HTTP_OFFLINE"] = "1" @@ -622,11 +1014,15 @@ def run( result, logq, started_at, finished_at = _run_schedule(engine_, lvls, jobs, keep_going, ctx) - _write_artifacts(ctx, result, started_at, finished_at, engine_) + # Evaluate budgets.yml based on collected query stats + budget_error, budgets_summary = _evaluate_budgets(ctx.project, engine_) + + _write_artifacts(ctx, result, started_at, finished_at, engine_, budgets_summary) + _attempt_catalog(ctx) _emit_logs_and_errors(logq, result, engine_) - if result.failed: + if result.failed or budget_error: raise typer.Exit(1) engine_.persist_on_success(result) diff --git a/src/fastflowtransform/cli/snapshot_cmd.py b/src/fastflowtransform/cli/snapshot_cmd.py index d09e499..bcdb22f 100644 --- a/src/fastflowtransform/cli/snapshot_cmd.py +++ b/src/fastflowtransform/cli/snapshot_cmd.py @@ -17,6 +17,7 @@ CacheMode, _attempt_catalog, _emit_logs_and_errors, + _evaluate_budgets, _levels_for_run, _run_schedule, _RunEngine, @@ -84,7 +85,7 @@ def snapshot_run( vars: VarsOpt = None, select: SelectOpt = None, exclude: ExcludeOpt = None, - jobs: JobsOpt = 1, + jobs: JobsOpt = "1", keep_going: KeepOpt = False, prune: bool = typer.Option( False, @@ -150,11 +151,14 @@ def snapshot_run( result, logq, started_at, finished_at = _run_schedule(engine_, lvls, jobs, keep_going, ctx) - _write_artifacts(ctx, result, started_at, finished_at, engine_) + # Evaluate budgets.yml based on collected query stats + budget_error, budgets_summary = _evaluate_budgets(ctx.project, engine_) + + _write_artifacts(ctx, result, started_at, finished_at, engine_, budgets_summary) 
_attempt_catalog(ctx) _emit_logs_and_errors(logq, result, engine_) - if result.failed: + if result.failed or budget_error: clear_context() raise typer.Exit(1) diff --git a/src/fastflowtransform/config/budgets.py b/src/fastflowtransform/config/budgets.py new file mode 100644 index 0000000..1df9c02 --- /dev/null +++ b/src/fastflowtransform/config/budgets.py @@ -0,0 +1,247 @@ +# fastflowtransform/config/budgets.py +from __future__ import annotations + +import re as _re +from pathlib import Path +from typing import Any + +import yaml +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +def _parse_duration_to_ms(value: Any) -> Any: + """ + Parse a human-friendly duration into milliseconds. + + Supported examples: + - 5000 → 5000 ms + - "5000" → 5000 ms + - "250ms" → 250 ms + - "10s" → 10_000 ms + - "1.5m" → 90_000 ms + - "2h" → 7_200_000 ms + - "1d" → 86_400_000 ms + + If parsing fails, returns the original value (so Pydantic can raise + a clear error if it expects an int). + """ + if value is None: + return None + + if isinstance(value, (int, float)): + iv = int(value) + return iv if iv > 0 else None + + if not isinstance(value, str): + return value + + text = value.strip().lower().replace("_", "").replace(" ", "") + if not text: + return None + + m = _re.match(r"^([0-9]*\.?[0-9]+)(ms|s|m|h|d)?$", text) + if not m: + # Not a duration pattern, let Pydantic validate it later + return value + + num_str, unit = m.groups() + try: + num = float(num_str) + except ValueError: + return value + + unit = unit or "ms" + factor = { + "ms": 1, + "s": 1000, + "m": 60_000, + "h": 3_600_000, + "d": 86_400_000, + }.get(unit) + + if factor is None: + return value + + ms = int(num * factor) + return ms if ms > 0 else None + + +def _normalize_duration_limits(raw: dict[str, Any]) -> dict[str, Any]: + """ + Walk the deserialized budgets.yml and normalize only + `query_duration_ms.warn` / `.error` to integer milliseconds using + _parse_duration_to_ms. 
+ + Shape (reminder): + + version: 1 + total: + query_duration_ms: { warn: "10m", error: "30m" } + models: + my_model: + query_duration_ms: { ... } + tags: + my_tag: + query_duration_ms: { ... } + """ + if not isinstance(raw, dict): + return raw + + def _fix_metrics(metrics: Any) -> None: + if not isinstance(metrics, dict): + return + qdm = metrics.get("query_duration_ms") + if not isinstance(qdm, dict): + return + for key in ("warn", "error"): + if key in qdm: + qdm[key] = _parse_duration_to_ms(qdm[key]) + + total = raw.get("total") + if isinstance(total, dict): + _fix_metrics(total) + + for section_name in ("models", "tags"): + section = raw.get(section_name) + if not isinstance(section, dict): + continue + for _, metrics in section.items(): + if isinstance(metrics, dict): + _fix_metrics(metrics) + + return raw + + +class BudgetLimit(BaseModel): + """ + Thresholds for a single metric. + + After preprocessing, values are integers (e.g. bytes, rows, ms). You + can use either bare numbers or numeric strings: + + warn: 5000000000 + error: "10_000_000_000" + + For query_duration_ms only, we additionally support: + "10m", "30s", "2h", "1d", "250ms" + (these are converted to ms before validation). + """ + + model_config = ConfigDict(extra="forbid") + + warn: int | None = None + error: int | None = None + + @field_validator("warn", "error", mode="before") + @classmethod + def _normalize_int(cls, v: Any) -> int | None: + if v is None: + return None + if isinstance(v, (int, float)): + iv = int(v) + return iv if iv > 0 else None + if isinstance(v, str): + text = v.strip().replace("_", "").replace(",", "") + if not text: + return None + if not text.isdigit(): + # At this point we've already tried to parse duration strings + # for query_duration_ms; non-numeric leftovers here are an error. 
+ raise ValueError(f"budget limits must be integers or numeric strings, got {v!r}") + iv = int(text) + return iv if iv > 0 else None + raise TypeError("budget limits must be integers or strings") + + +class BudgetMetrics(BaseModel): + """ + Metrics we can budget against; all are optional. + + bytes_scanned → sum of bytes across all SQL queries + rows → sum of rows across all SQL queries + query_duration_ms → sum of query durations (ms), not wall-clock + """ + + model_config = ConfigDict(extra="forbid") + + bytes_scanned: BudgetLimit | None = None + rows: BudgetLimit | None = None + query_duration_ms: BudgetLimit | None = None + + +class QueryLimitConfig(BaseModel): + """ + Per-engine query limit configuration. + + Currently only `max_bytes` is supported. + """ + + model_config = ConfigDict(extra="forbid") + + max_bytes: int | None = None + + @field_validator("max_bytes", mode="before") + @classmethod + def _normalize_int(cls, v: Any) -> int | None: + return BudgetLimit._normalize_int(v) + + +class BudgetsConfig(BaseModel): + """ + Strict representation of budgets.yml. 
+ + Example: + + version: 1 + + # Global (across all models in fft run) + total: + bytes_scanned: + warn: 5000000000 + error: 10000000000 + + # Per model limits + models: + fct_events: + bytes_scanned: + warn: 1000000000 + error: 2000000000 + + # Per tag limits (aggregated over all models with that tag) + tags: + heavy: + bytes_scanned: + warn: 5000000000 + error: 8000000000 + + # Optional per-engine query guard limits + query_limits: + duckdb: + max_bytes: 2000000000 + """ + + model_config = ConfigDict(extra="forbid") + + version: int = 1 + + total: BudgetMetrics | None = None + models: dict[str, BudgetMetrics] = Field(default_factory=dict) + tags: dict[str, BudgetMetrics] = Field(default_factory=dict) + query_limits: dict[str, QueryLimitConfig] = Field(default_factory=dict) + + +def load_budgets_config(project_dir: Path) -> BudgetsConfig | None: + """ + Read budgets.yml under `project_dir` and validate it strictly. + + Missing file → returns None (no budgets enforced). + Invalid file → raises, caller should wrap into a user-friendly error. 
+ """ + project_dir = Path(project_dir) + cfg_path = project_dir / "budgets.yml" + if not cfg_path.exists(): + return None + + raw = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {} + raw = _normalize_duration_limits(raw) + return BudgetsConfig.model_validate(raw) diff --git a/src/fastflowtransform/executors/base.py b/src/fastflowtransform/executors/base.py index 349a6a5..9eccb77 100644 --- a/src/fastflowtransform/executors/base.py +++ b/src/fastflowtransform/executors/base.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Mapping from contextlib import suppress +from datetime import datetime from pathlib import Path from typing import Any, TypeVar, cast @@ -19,6 +20,8 @@ from fastflowtransform.config.sources import resolve_source_entry from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.errors import ModelExecutionError +from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.incremental import _normalize_unique_key from fastflowtransform.logging import echo, echo_debug from fastflowtransform.validation import validate_required_columns @@ -364,31 +367,39 @@ def _strip_leading_sql_comments(self, sql: str) -> tuple[str, int]: def _selectable_body(self, sql: str) -> str: """ - Return a valid SELECT-able body for CREATE … AS: - - - If the statement starts (after comments/blank lines) with a CTE (WITH …), - return from the WITH onward. - - If it starts with SELECT, return from SELECT onward. - - Otherwise, fall back to the first SELECT heuristic. + Normalize a SELECT/CTE body: + + - Strip leading SQL comments/blank lines. + - Find the first WITH or SELECT keyword (as a word) anywhere in the statement. + - Return from that keyword onward, stripping trailing semicolons/whitespace. + + This works for: + * plain SELECT + * WITH ... (CTEs) + * CREATE TABLE/VIEW ... AS WITH ... 
+ * CREATE TABLE/VIEW ... AS SELECT ... + * INSERT INTO ... SELECT ... """ - # Keep original for fallback, but check after stripping comments - s0 = sql - s, offset = self._strip_leading_sql_comments(sql) - s_ws = s.lstrip() # in case comments left some spaces - head = s_ws[:6].lower() + s0 = sql or "" + + # Strip leading comments; s starts at 'offset' in the original string. + s, offset = self._strip_leading_sql_comments(s0) + s_ws = s.lstrip() - if s_ws.startswith(("with ", "with\n", "with\t")): - # Return from the start of this WITH (preserve exactly s_ws form) - # Compute index into original string to retain original casing beyond comments - idx = s.find(s_ws) - return sql[offset + (idx if idx >= 0 else 0) :].lstrip() + # Find first WITH or SELECT as a whole word (case-insensitive) + m = re.search(r"\b(with|select)\b", s_ws, flags=re.IGNORECASE) + if not m: + # No obvious SELECT/CTE - just return the statement minus trailing semicolons. + return s0.strip().rstrip(";\n\t ") - if head.startswith("select"): - idx = s.find(s_ws) - return sql[offset + (idx if idx >= 0 else 0) :].lstrip() + # m.start() is index within s_ws. Need to map back into the original sql. + leading_ws_len = len(s) - len(s_ws) # spaces we lstripped + start_in_s = leading_ws_len + m.start() + start_in_sql = offset + start_in_s - # Fallback: first SELECT anywhere in the statement - return self._first_select_body(s0) + body = s0[start_in_sql:] + # Strip trailing semicolons and whitespace + return body.strip().rstrip(";\n\t ") def _looks_like_direct_ddl(self, sql: str) -> bool: """ @@ -439,6 +450,175 @@ def _render_ephemeral_sql(self, name: str, env: Environment) -> str: body = self._selectable_body(sql).rstrip(" ;\n\t") return f"(\n{body}\n)" + # ---------- Query stats (per-node, aggregated across queries) ---------- + + def _record_query_stats(self, stats: QueryStats) -> None: + """ + Append per-query stats to an internal buffer. 
+ + Executors call this from their engine-specific recording logic. + The run engine can later drain this buffer per node. + """ + buf = getattr(self, "_ff_query_stats_buffer", None) + if buf is None: + buf = [] + self._ff_query_stats_buffer = buf + buf.append(stats) + + def _drain_query_stats(self) -> list[QueryStats]: + """ + Drain and return the buffered stats, resetting the buffer. + + Used by the run engine around per-node execution. + """ + buf = getattr(self, "_ff_query_stats_buffer", None) + if not buf: + self._ff_query_stats_buffer = [] + return [] + self._ff_query_stats_buffer = [] + return list(buf) + + def _record_query_job_stats(self, job: Any) -> None: + """ + Best-effort extraction of stats from a 'job-like' object. + + This is intentionally generic; engines that return job handles + (BigQuery, Snowflake, Spark) can pass them here. Engines can + override this if they want more precise logic. + """ + + def _safe_int(val: Any) -> int | None: + try: + if val is None: + return None + return int(val) + except Exception: + return None + + # Heuristic attribute names - BigQuery and others may expose these. + bytes_processed = _safe_int( + getattr(job, "total_bytes_processed", None) or getattr(job, "bytes_processed", None) + ) + + rows = _safe_int( + getattr(job, "num_dml_affected_rows", None) + or getattr(job, "total_rows", None) + or getattr(job, "rowcount", None) + ) + + duration_ms: int | None = None + try: + started = getattr(job, "started", None) + ended = getattr(job, "ended", None) + if isinstance(started, datetime) and isinstance(ended, datetime): + duration_ms = int((ended - started).total_seconds() * 1000) + except Exception: + pass + + self._record_query_stats( + QueryStats( + bytes_processed=bytes_processed, + rows=rows, + duration_ms=duration_ms, + ) + ) + + def configure_query_budget_limit(self, limit: int | None) -> None: + """ + Inject a configured per-query byte limit (e.g. from budgets.yml). 
+ """ + if limit is None: + self._ff_configured_query_limit = None + return + try: + iv = int(limit) + except Exception: + self._ff_configured_query_limit = None + return + self._ff_configured_query_limit = iv if iv > 0 else None + + def _configured_query_limit(self) -> int | None: + val = getattr(self, "_ff_configured_query_limit", None) + if val is None: + return None + try: + iv = int(val) + except Exception: + return None + return iv if iv > 0 else None + + def _set_query_budget_estimate(self, estimate: int | None) -> None: + self._ff_last_query_budget_estimate = estimate + + def _consume_query_budget_estimate(self) -> int | None: + estimate = getattr(self, "_ff_last_query_budget_estimate", None) + self._ff_last_query_budget_estimate = None + return estimate + + def _apply_budget_guard(self, guard: BudgetGuard | None, sql: str) -> int | None: + if guard is None: + self._set_query_budget_estimate(None) + self._ff_budget_guard_active = False + return None + limit, source = guard.resolve_limit(self._configured_query_limit()) + if not limit: + self._set_query_budget_estimate(None) + self._ff_budget_guard_active = False + return None + self._ff_budget_guard_active = True + estimate = guard.enforce(sql, self, limit=limit, source=source) + self._set_query_budget_estimate(estimate) + return estimate + + def _is_budget_guard_active(self) -> bool: + return bool(getattr(self, "_ff_budget_guard_active", False)) + + # ---------- Per-node stats API (used by run engine) ---------- + + def reset_node_stats(self) -> None: + """ + Reset per-node statistics buffer. + + The run engine calls this before executing a model so that all + stats recorded via `_record_query_stats(...)` belong to that node. 
+ """ + # just clear the buffer; next recording will re-create it + self._ff_query_stats_buffer = [] + + def get_node_stats(self) -> dict[str, int]: + """ + Aggregate buffered QueryStats into a simple dict: + + { + "bytes_scanned": , + "rows": , + "query_duration_ms": , + } + + Called by the run engine after a node finishes. + """ + stats_list = self._drain_query_stats() + if not stats_list: + return {} + + total_bytes = 0 + total_rows = 0 + total_duration = 0 + + for s in stats_list: + if s.bytes_processed is not None: + total_bytes += int(s.bytes_processed) + if s.rows is not None: + total_rows += int(s.rows) + if s.duration_ms is not None: + total_duration += int(s.duration_ms) + + return { + "bytes_scanned": total_bytes, + "rows": total_rows, + "query_duration_ms": total_duration, + } + # ---------- Python models ---------- def run_python(self, node: Node) -> None: """Execute the Python model for a given node and materialize its result.""" diff --git a/src/fastflowtransform/executors/bigquery/base.py b/src/fastflowtransform/executors/bigquery/base.py index 394dab2..b7ceb8d 100644 --- a/src/fastflowtransform/executors/bigquery/base.py +++ b/src/fastflowtransform/executors/bigquery/base.py @@ -7,6 +7,8 @@ from fastflowtransform.executors._shims import BigQueryConnShim from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.bigquery._bigquery_mixin import BigQueryIdentifierMixin +from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.query_stats import _TrackedQueryJob from fastflowtransform.logging import echo from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots import resolve_snapshot_config @@ -29,6 +31,12 @@ class BigQueryBaseExecutor(BigQueryIdentifierMixin, BaseExecutor[TFrame]): # Subclasses override ENGINE_NAME ("bigquery", "bigquery_batch", ...) 
ENGINE_NAME = "bigquery_base" + _BUDGET_GUARD = BudgetGuard( + env_var="FF_BQ_MAX_BYTES", + estimator_attr="_estimate_query_bytes", + engine_label="BigQuery", + what="query", + ) def __init__( self, @@ -54,6 +62,38 @@ def __init__( dataset=self.dataset, ) + def _execute_sql(self, sql: str) -> _TrackedQueryJob: + """ + Central BigQuery query runner. + + - All 'real' SQL statements in this executor should go through here. + - Returns the QueryJob so callers can call .result(). + """ + self._apply_budget_guard(self._BUDGET_GUARD, sql) + job = self.client.query(sql, location=self.location) + return _TrackedQueryJob(job, on_complete=self._record_query_job_stats) + + # --- Cost estimation for the shared BudgetGuard ----------------- + + def _estimate_query_bytes(self, sql: str) -> int | None: + """ + Estimate bytes for a BigQuery SQL statement using a dry-run. + + Returns the estimated bytes, or None if estimation is not possible. + """ + cfg = bigquery.QueryJobConfig( + dry_run=True, + use_query_cache=False, + ) + job = self.client.query( + sql, + job_config=cfg, + location=self.location, + ) + # Dry-run is free; we just need the job metadata + job.result() + return int(getattr(job, "total_bytes_processed", 0) or 0) + # ---- DQ test table formatting (fft test) ---- def _format_test_table(self, table: str | None) -> str | None: """ @@ -109,16 +149,10 @@ def _apply_sql_materialization( ) from e def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: - self.client.query( - f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}", - location=self.location, - ).result() + self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}").result() def _create_or_replace_table(self, target_sql: str, select_body: str, node: Node) -> None: - self.client.query( - f"CREATE OR REPLACE TABLE {target_sql} AS {select_body}", - location=self.location, - ).result() + self._execute_sql(f"CREATE OR REPLACE TABLE {target_sql} AS 
{select_body}").result() def _create_or_replace_view_from_table( self, @@ -129,10 +163,7 @@ def _create_or_replace_view_from_table( view_id = self._qualified_identifier(view_name) back_id = self._qualified_identifier(backing_table) self._ensure_dataset() - self.client.query( - f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}", - location=self.location, - ).result() + self._execute_sql(f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}").result() # ---- Meta hook ---- def on_node_built(self, node: Node, relation: str, fingerprint: str) -> None: @@ -181,10 +212,7 @@ def create_table_as(self, relation: str, select_sql: str) -> None: project=self.project, dataset=self.dataset, ) - self.client.query( - f"CREATE TABLE {target} AS {body}", - location=self.location, - ).result() + self._execute_sql(f"CREATE TABLE {target} AS {body}").result() def incremental_insert(self, relation: str, select_sql: str) -> None: """ @@ -197,10 +225,7 @@ def incremental_insert(self, relation: str, select_sql: str) -> None: project=self.project, dataset=self.dataset, ) - self.client.query( - f"INSERT INTO {target} {body}", - location=self.location, - ).result() + self._execute_sql(f"INSERT INTO {target} {body}").result() def incremental_merge(self, relation: str, select_sql: str, unique_key: list[str]) -> None: """ @@ -221,10 +246,10 @@ def incremental_merge(self, relation: str, select_sql: str, unique_key: list[str DELETE FROM {target} t WHERE EXISTS (SELECT 1 FROM ({body}) s WHERE {pred}) """ - self.client.query(delete_sql, location=self.location).result() + self._execute_sql(delete_sql).result() insert_sql = f"INSERT INTO {target} SELECT * FROM ({body})" - self.client.query(insert_sql, location=self.location).result() + self._execute_sql(insert_sql).result() def alter_table_sync_schema( self, @@ -273,10 +298,7 @@ def alter_table_sync_schema( for col in to_add: f = out_fields[col] typ = str(f.field_type) if hasattr(f, "field_type") else "STRING" - self.client.query( 
- f"ALTER TABLE {target} ADD COLUMN {col} {typ}", - location=self.location, - ).result() + self._execute_sql(f"ALTER TABLE {target} ADD COLUMN {col} {typ}").result() # ── Snapshots API (shared for pandas + BigFrames) ───────────────────── def run_snapshot_sql(self, node: Node, env: Any) -> None: @@ -367,7 +389,7 @@ def run_snapshot_sql(self, node: Node, env: Any) -> None: {hash_expr} AS {hash_col} FROM ({body}) AS s """ - self.client.query(create_sql, location=self.location).result() + self._execute_sql(create_sql).result() return # ---- Incremental snapshot update ---- @@ -403,7 +425,7 @@ def run_snapshot_sql(self, node: Node, env: Any) -> None: AND t.{is_cur} = TRUE AND {change_condition} """ - self.client.query(close_sql, location=self.location).result() + self._execute_sql(close_sql).result() # 2) Insert new current versions (new keys or changed rows) first_key = unique_key[0] @@ -424,7 +446,7 @@ def run_snapshot_sql(self, node: Node, env: Any) -> None: t.{first_key} IS NULL OR {change_condition} """ - self.client.query(insert_sql, location=self.location).result() + self._execute_sql(insert_sql).result() def snapshot_prune( self, @@ -499,4 +521,4 @@ def snapshot_prune( AND t.{vf} = r.{vf} ) """ - self.client.query(delete_sql, location=self.location).result() + self._execute_sql(delete_sql).result() diff --git a/src/fastflowtransform/executors/bigquery/pandas.py b/src/fastflowtransform/executors/bigquery/pandas.py index 83b0d49..32b2e7d 100644 --- a/src/fastflowtransform/executors/bigquery/pandas.py +++ b/src/fastflowtransform/executors/bigquery/pandas.py @@ -2,11 +2,13 @@ from __future__ import annotations from collections.abc import Iterable +from time import perf_counter import pandas as pd from fastflowtransform.core import Node from fastflowtransform.executors.bigquery.base import BigQueryBaseExecutor +from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.typing import BadRequest, Client, LoadJobConfig, NotFound, bigquery @@ 
-57,6 +59,7 @@ def _materialize_relation(self, relation: str, df: pd.DataFrame, node: Node) -> table_id = f"{self.project}.{self.dataset}.{relation}" job_config = LoadJobConfig(write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE) # Optionally extend dtype mapping here (NUMERIC/STRING etc.) + start = perf_counter() try: job = self.client.load_table_from_dataframe( df, @@ -67,6 +70,9 @@ def _materialize_relation(self, relation: str, df: pd.DataFrame, node: Node) -> job.result() except BadRequest as e: raise RuntimeError(f"BigQuery write failed: {table_id}\n{e}") from e + else: + duration_ms = int((perf_counter() - start) * 1000) + self._record_dataframe_stats(df, duration_ms) def _create_view_over_table(self, view_name: str, backing_table: str, node: Node) -> None: """ @@ -77,3 +83,14 @@ def _create_view_over_table(self, view_name: str, backing_table: str, node: Node def _frame_name(self) -> str: return "pandas" + + def _record_dataframe_stats(self, df: pd.DataFrame, duration_ms: int) -> None: + rows = len(df) + bytes_val = int(df.memory_usage(deep=True).sum()) if rows > 0 else 0 + self._record_query_stats( + QueryStats( + bytes_processed=bytes_val if bytes_val > 0 else None, + rows=rows if rows > 0 else None, + duration_ms=duration_ms, + ) + ) diff --git a/src/fastflowtransform/executors/budget.py b/src/fastflowtransform/executors/budget.py new file mode 100644 index 0000000..87cd281 --- /dev/null +++ b/src/fastflowtransform/executors/budget.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import os +from collections.abc import Callable +from typing import Any, cast + +from fastflowtransform.logging import echo + + +def parse_max_bytes_env(env_var: str) -> int | None: + """ + Parse an env var like FF_BQ_MAX_BYTES into an integer byte count. 
+ """ + raw = os.getenv(env_var) + if not raw: + return None + + text = raw.strip().lower().replace("_", "").replace(",", "") + if not text: + return None + + multiplier = 1 + suffixes = [ + ("tb", 1024**4), + ("t", 1024**4), + ("gb", 1024**3), + ("g", 1024**3), + ("mb", 1024**2), + ("m", 1024**2), + ("kb", 1024), + ("k", 1024), + ] + for suf, factor in suffixes: + if text.endswith(suf): + text = text[: -len(suf)].strip() + multiplier = factor + break + + try: + value = float(text) + except ValueError: + echo( + f"Warning: invalid {env_var}={raw!r}; expected integer bytes or " + "a number with unit suffix like '10GB'. Ignoring limit." + ) + return None + + bytes_val = int(value * multiplier) + if bytes_val <= 0: + return None + return bytes_val + + +def format_bytes(num: int) -> str: + """Human-readable byte formatting for error messages.""" + units = ["B", "KB", "MB", "GB", "TB", "PB"] + size = float(num) + for unit in units: + if size < 1024.0 or unit == units[-1]: + if size >= 10: + return f"{size:.1f} {unit}" + return f"{size:.2g} {unit}" + size /= 1024.0 + return f"{num} B" + + +EstimatorFn = Callable[[str], int | None] + + +class BudgetGuard: + """ + Shared implementation for per-query budget enforcement. 
+ """ + + def __init__( + self, + *, + env_var: str | None, + estimator_attr: str, + engine_label: str, + what: str = "query", + ): + self.env_var = env_var + self.estimator_attr = estimator_attr + self.engine_label = engine_label + self.what = what + self._env_limit: int | None = None + self._env_limit_populated = False + + def _env_limit_value(self) -> int | None: + if not self.env_var: + return None + if not self._env_limit_populated: + self._env_limit = parse_max_bytes_env(self.env_var) + self._env_limit_populated = True + return self._env_limit + + def resolve_limit(self, override_limit: int | None = None) -> tuple[int | None, str | None]: + env_limit = self._env_limit_value() + if env_limit is not None and env_limit > 0: + source = f"{self.env_var}" if self.env_var else "environment" + return env_limit, source + if override_limit is None: + return None, None + try: + limit = int(override_limit) + except Exception: + return None, None + if limit <= 0: + return None, None + return limit, "budgets.yml (query_limits)" + + def enforce(self, sql: str, executor: Any, *, limit: int, source: str | None) -> int | None: + estimator_obj = getattr(executor, self.estimator_attr, None) + + if not callable(estimator_obj): + echo( + f"{self.engine_label} cost guard misconfigured: " + f"missing estimator '{self.estimator_attr}'. Guard ignored." + ) + return None + + estimator = cast(EstimatorFn, estimator_obj) + + try: + estimated = estimator(sql) + except Exception as exc: + echo( + f"{self.engine_label} cost estimation failed " + f"(limit ignored for this {self.what}): {exc}" + ) + return None + + if estimated is None: + return None + + value = estimated + + if value <= 0: + return None + + if value > limit: + label = source or "configured limit" + msg = ( + f"Aborting {self.engine_label} {self.what}: estimated scanned bytes " + f"{value} ({format_bytes(value)}) exceed " + f"{label}={limit} ({format_bytes(limit)}).\n" + f"Adjust {label} to allow this {self.what}." 
+ ) + raise RuntimeError(msg) + + return value diff --git a/src/fastflowtransform/executors/databricks_spark.py b/src/fastflowtransform/executors/databricks_spark.py index fc06218..5a297de 100644 --- a/src/fastflowtransform/executors/databricks_spark.py +++ b/src/fastflowtransform/executors/databricks_spark.py @@ -5,7 +5,8 @@ from contextlib import suppress from functools import reduce from pathlib import Path -from typing import Any +from time import perf_counter +from typing import Any, cast from urllib.parse import unquote, urlparse from jinja2 import Environment @@ -18,6 +19,8 @@ get_spark_window, ) from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.logging import echo, echo_debug from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots import resolve_snapshot_config @@ -25,6 +28,10 @@ from fastflowtransform.table_formats.base import SparkFormatHandler from fastflowtransform.typing import SDF, DataType, SparkSession +# --------------------------------------------------------------------------- +# Delta integration +# --------------------------------------------------------------------------- + # Enable Delta Lake via delta-spark when available configure_spark_with_delta_pip: Callable[..., Any] | None try: @@ -36,7 +43,34 @@ _DELTA_EXTENSION = "io.delta.sql.DeltaSparkSessionExtension" _DELTA_CATALOG = "org.apache.spark.sql.delta.catalog.DeltaCatalog" -# _SPARK_DEFAULT_CATALOG = "org.apache.spark.sql.internal.CatalogImpl" # Spark's built-in + + +def _csv_tokens(value: str | None) -> list[str]: + if not value: + return [] + return [part.strip() for part in value.split(",") if part and part.strip()] + + +def _ensure_csv_token(value: str | None, token: str) -> tuple[str | None, bool]: + tokens = _csv_tokens(value) + if token in tokens: + return value, False + 
tokens.append(token) + return ",".join(tokens), True + + +def _as_nonempty_str(value: Any | None) -> str | None: + if value is None: + return None + text = str(value).strip() + return text or None + + +def _safe_conf(spark: SparkSession, key: str, default: str = "") -> str: + try: + return str(spark.conf.get(key, default)) + except Exception as exc: + return f"" def _has_delta(spark: SparkSession) -> bool: @@ -109,34 +143,6 @@ def _try_for_name(jvm: Any) -> bool: return any(_try_for_name(handle) for handle in _handles()) -def _csv_tokens(value: str | None) -> list[str]: - if not value: - return [] - return [part.strip() for part in value.split(",") if part and part.strip()] - - -def _ensure_csv_token(value: str | None, token: str) -> tuple[str | None, bool]: - tokens = _csv_tokens(value) - if token in tokens: - return value, False - tokens.append(token) - return ",".join(tokens), True - - -def _as_nonempty_str(value: Any | None) -> str | None: - if value is None: - return None - text = str(value).strip() - return text or None - - -def _safe_conf(spark: SparkSession, key: str, default: str = "") -> str: - try: - return str(spark.conf.get(key, default)) - except Exception as exc: - return f"" - - def _log_delta_capabilities( spark: SparkSession, *, @@ -174,9 +180,16 @@ def _log_delta_capabilities( class DatabricksSparkExecutor(BaseExecutor[SDF]): - ENGINE_NAME = "databricks_spark" """Spark/Databricks executor without pandas: Python models operate on Spark DataFrames.""" + ENGINE_NAME = "databricks_spark" + _BUDGET_GUARD = BudgetGuard( + env_var="FF_SPK_MAX_BYTES", + estimator_attr="_estimate_query_bytes", + engine_label="Databricks/Spark", + what="query", + ) + def __init__( self, master: str = "local[*]", @@ -193,10 +206,12 @@ def __init__( ): extra_conf = dict(extra_conf or {}) self._user_spark = spark + builder = SparkSession.builder.master(master).appName(app_name) catalog_key = "spark.sql.catalog.spark_catalog" ext_key = "spark.sql.extensions" + # Warehouse 
directory warehouse_path: Path | None = None if warehouse_dir: warehouse_path = Path(warehouse_dir).expanduser() @@ -209,6 +224,7 @@ def __init__( if catalog_value: builder = builder.config(catalog_key, catalog_value) + # Extra config if extra_conf: for key, value in extra_conf.items(): if value is not None: @@ -221,6 +237,7 @@ def __init__( fmt_requested = (table_format or "").strip().lower() wants_delta = fmt_requested == "delta" + # Apply Delta configuration last, after all Spark configs are set. if not wants_delta and self._user_spark is None: catalog_overridden = bool(catalog_value) if not catalog_overridden: @@ -259,12 +276,13 @@ def __init__( self.schema = database if database: - self.spark.sql(f"CREATE DATABASE IF NOT EXISTS `{database}`") + self._execute_sql(f"CREATE DATABASE IF NOT EXISTS `{database}`") with suppress(Exception): self.spark.catalog.setCurrentDatabase(database) self.spark_table_format: str | None = fmt_requested or None self.spark_table_options = {str(k): str(v) for k, v in (table_options or {}).items()} + # ---- Delta availability check ---- self._delta_ok = _has_delta(self.spark) @@ -289,8 +307,155 @@ def __init__( self.spark_table_format, self.spark, table_options=self.spark_table_options, + sql_runner=self._execute_sql, + ) + + self._spark_default_size = self._detect_default_size() + + # ---------- Cost estimation & central execution ---------- + + def _detect_default_size(self) -> int: + """ + Detect Spark's defaultSizeInBytes sentinel. + + - Prefer spark.sql.defaultSizeInBytes if available. + - Fall back to Long.MaxValue (2^63 - 1) otherwise. 
+ """ + try: + conf_val = self.spark.conf.get("spark.sql.defaultSizeInBytes") + if conf_val is not None: + return int(conf_val) + except Exception: + # config not set / older Spark / weird environment + pass + + # Fallback: Spark uses Long.MaxValue by default + return 2**63 - 1 # 9223372036854775807 + + def _parse_spark_stats_size(self, size_val: Any) -> int | None: + if size_val is None: + return None + try: + size_int = int(str(size_val)) + except Exception: + return None + return size_int if size_int > 0 else None + + def _jplan_uses_default_size(self, jplan: Any) -> bool: + """ + Recursively walk a JVM LogicalPlan and return True if any node's + stats.sizeInBytes equals spark.sql.defaultSizeInBytes. + """ + if self._spark_default_size is None: + return False + + try: + stats = jplan.stats() + size_val = stats.sizeInBytes() + size_int = int(str(size_val)) + if size_int == self._spark_default_size: + return True + except Exception: + # ignore stats errors and keep walking + pass + + # children() is a Scala Seq[LogicalPlan]; iterate via .size() / .apply(i) + try: + children = jplan.children() + n = children.size() + for idx in range(n): + child = children.apply(idx) + if self._jplan_uses_default_size(child): + return True + except Exception: + # if we can't inspect children, stop here + pass + + return False + + def _spark_plan_bytes(self, sql: str) -> int | None: + """ + Inspect the optimized logical plan via the JVM and return sizeInBytes + as an integer, or None if not available. + + This does *not* execute the query; it only goes through analysis/planning. + """ + try: + normalized = self._selectable_body(sql).rstrip(";\n\t ") + if not normalized: + normalized = sql + except Exception: + normalized = sql + + stmt = normalized.lstrip().lower() + if not stmt.startswith(("select", "with")): + # DDL/DML statements (ALTER/INSERT/etc.) should not be executed twice. 
+ return None + + try: + df = self.spark.sql(normalized) + + jdf = cast(Any, getattr(df, "_jdf", None)) + if jdf is None: + return None + + qe = jdf.queryExecution() + jplan = qe.optimizedPlan() + + # If any node relies on defaultSizeInBytes, we don't trust the stats + if self._jplan_uses_default_size(jplan): + return None + + stats = jplan.stats() + + size_attr = getattr(stats, "sizeInBytes", None) + size_val = size_attr() if callable(size_attr) else size_attr + + return self._parse_spark_stats_size(size_val) + except Exception: + return None + + def _estimate_query_bytes(self, sql: str) -> int | None: + """ + Best-effort logical-plan size estimate using Spark's stats. + + It inspects the optimized plan's sizeInBytes via the JVM API without + executing the query. If unavailable or unsupported, returns None and + the guard is effectively disabled. + """ + return self._spark_plan_bytes(sql) + + def _execute_sql(self, sql: str) -> SDF: + """ + Central Spark SQL runner. + + - Guarded by FF_SPK_MAX_BYTES via the cost guard. + - Returns a Spark DataFrame (same as spark.sql). + - Records best-effort query stats for run_results.json. 
+ """ + estimated_bytes = self._apply_budget_guard(self._BUDGET_GUARD, sql) + if estimated_bytes is None and not self._is_budget_guard_active(): + with suppress(Exception): + estimated_bytes = self._spark_plan_bytes(sql) + t0 = perf_counter() + df = self.spark.sql(sql) + dt_ms = int((perf_counter() - t0) * 1000) + + # Best-effort logical estimate + bytes_processed = ( + estimated_bytes if estimated_bytes is not None else self._spark_plan_bytes(sql) ) + # For Spark we don't attempt row counts without executing the job + self._record_query_stats( + QueryStats( + bytes_processed=bytes_processed, + rows=None, + duration_ms=dt_ms, + ) + ) + return df + # ---------- Frame hooks (required) ---------- def _read_relation(self, relation: str, node: Node, deps: Iterable[str]) -> SDF: # relation may optionally be "db.table" (via source()/ref()) @@ -303,12 +468,16 @@ def _materialize_relation(self, relation: str, df: SDF, node: Node) -> None: storage_meta = self._storage_meta(node, relation) # Delegate managed/unmanaged handling to _save_df_as_table so Iceberg # (or other handlers) can consistently enforce managed tables. 
+ start = perf_counter() self._save_df_as_table(relation, df, storage=storage_meta) + duration_ms = int((perf_counter() - start) * 1000) + self._record_spark_dataframe_stats(df, duration_ms) def _create_view_over_table(self, view_name: str, backing_table: str, node: Node) -> None: + """Compatibility hook: create a simple SELECT * view over an existing table.""" view_sql = self._sql_identifier(view_name) backing_sql = self._sql_identifier(backing_table) - self.spark.sql(f"CREATE OR REPLACE VIEW {view_sql} AS SELECT * FROM {backing_sql}") + self._execute_sql(f"CREATE OR REPLACE VIEW {view_sql} AS SELECT * FROM {backing_sql}") def _validate_required( self, node_name: str, inputs: Any, requires: dict[str, set[str]] @@ -365,6 +534,8 @@ def _storage_meta(self, node: Node | None, relation: str) -> dict[str, Any]: Retrieve configured storage overrides for the logical node backing `relation`. """ rel_clean = self._strip_quotes(relation) + + # 1) Direct node meta / storage config if node is not None: meta = dict((node.meta or {}).get("storage") or {}) if meta: @@ -372,6 +543,8 @@ def _storage_meta(self, node: Node | None, relation: str) -> dict[str, Any]: lookup = storage.get_model_storage(node.name) if lookup: return lookup + + # 2) Search REGISTRY nodes by relation_for(name) for cand in getattr(REGISTRY, "nodes", {}).values(): try: if self._strip_quotes(relation_for(cand.name)) == rel_clean: @@ -383,6 +556,8 @@ def _storage_meta(self, node: Node | None, relation: str) -> dict[str, Any]: return lookup except Exception: continue + + # 3) Direct storage override by relation name return storage.get_model_storage(rel_clean) def _write_to_storage_path(self, relation: str, df: SDF, storage_meta: dict[str, Any]) -> None: @@ -403,6 +578,35 @@ def _write_to_storage_path(self, relation: str, df: SDF, storage_meta: dict[str, with suppress(Exception): self.spark.catalog.refreshByPath(path) + def _record_spark_dataframe_stats(self, df: SDF, duration_ms: int) -> None: + bytes_est = 
self._spark_dataframe_bytes(df) + self._record_query_stats( + QueryStats( + bytes_processed=bytes_est, + rows=None, + duration_ms=duration_ms, + ) + ) + + def _spark_dataframe_bytes(self, df: SDF) -> int | None: + try: + jdf = cast(Any, getattr(df, "_jdf", None)) + if jdf is None: + return None + + qe = jdf.queryExecution() + jplan = qe.optimizedPlan() + + if self._jplan_uses_default_size(jplan): + return None + + stats = jplan.stats() + size_attr = getattr(stats, "sizeInBytes", None) + size_val = size_attr() if callable(size_attr) else size_attr + return self._parse_spark_stats_size(size_val) + except Exception: + return None + # ---- SQL hooks ---- def _format_relation_for_ref(self, name: str) -> str: """ @@ -567,13 +771,18 @@ def _save_df_as_table( # Managed tables: delegate to the format handler (Delta, Parquet, Iceberg, ...) self._format_handler.save_df_as_table(table_name, df) + with suppress(Exception): + self._execute_sql( + f"ANALYZE TABLE {self._sql_identifier(table_name)} COMPUTE STATISTICS" + ) + def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: - self.spark.sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}") + self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}") def _create_or_replace_table(self, target_sql: str, select_body: str, node: Node) -> None: preview = f"-- target={target_sql}\n{select_body}" try: - df = self.spark.sql(select_body) + df = self._execute_sql(select_body) storage_meta = self._storage_meta(node, target_sql) self._save_df_as_table(target_sql, df, storage=storage_meta) except Exception as exc: @@ -584,7 +793,7 @@ def _create_or_replace_view_from_table( ) -> None: view_sql = self._sql_identifier(view_name) backing_sql = self._sql_identifier(backing_table) - self.spark.sql(f"CREATE OR REPLACE VIEW {view_sql} AS SELECT * FROM {backing_sql}") + self._execute_sql(f"CREATE OR REPLACE VIEW {view_sql} AS SELECT * FROM {backing_sql}") # ---- Meta hook ---- def 
on_node_built(self, node: Node, relation: str, fingerprint: str) -> None: @@ -600,7 +809,7 @@ def exists_relation(self, relation: str) -> bool: def create_table_as(self, relation: str, select_sql: str) -> None: """CREATE TABLE AS with cleaned SELECT body.""" body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") - df = self.spark.sql(body) + df = self._execute_sql(body) self._save_df_as_table(relation, df) def full_refresh_table(self, relation: str, select_sql: str) -> None: @@ -610,7 +819,7 @@ def full_refresh_table(self, relation: str, select_sql: str) -> None: """ body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") # Delegate to format handler via _save_df_as_table for managed, or storage for unmanaged - df = self.spark.sql(body) + df = self._execute_sql(body) self._save_df_as_table(relation, df) def incremental_insert(self, relation: str, select_sql: str) -> None: @@ -702,14 +911,13 @@ def alter_table_sync_schema( existing = {f.name for f in target_df.schema.fields} # Output schema from the SELECT body = self._first_select_body(select_sql).strip().rstrip(";\n\t ") - probe = self.spark.sql(f"SELECT * FROM ({body}) q LIMIT 0") + probe = self._execute_sql(f"SELECT * FROM ({body}) q LIMIT 0") to_add = [f for f in probe.schema.fields if f.name not in existing] if not to_add: return - # Map types best-effort (Spark SQL types); default STRING def _spark_sql_type(dt: DataType) -> str: - # Use simple, portable mapping for documentation UIs & broad compatibility + """Simple, portable mapping for Spark SQL types.""" return ( getattr(dt, "simpleString", lambda: "string")().upper() if hasattr(dt, "simpleString") @@ -718,7 +926,7 @@ def _spark_sql_type(dt: DataType) -> str: cols_sql = ", ".join([f"`{f.name}` {_spark_sql_type(f.dataType)}" for f in to_add]) table_sql = self._sql_identifier(relation) - self.spark.sql(f"ALTER TABLE {table_sql} ADD COLUMNS ({cols_sql})") + self._execute_sql(f"ALTER TABLE {table_sql} ADD COLUMNS ({cols_sql})") # ── 
Snapshot API ───────────────────────────────────────────────────── @@ -824,7 +1032,7 @@ def _snapshot_first_run( hash_col: str, upd_meta: str, ) -> None: - src_df = self.spark.sql(body) + src_df = self._execute_sql(body) echo_debug(f"[snapshot] first run for {rel_name} (strategy={strategy})") @@ -877,7 +1085,7 @@ def _snapshot_incremental_run( echo_debug(f"[snapshot] incremental run for {rel_name} (strategy={strategy})") existing = self.spark.table(physical) - src_df = self.spark.sql(body) + src_df = self._execute_sql(body) missing_keys_src = [k for k in unique_key if k not in src_df.columns] missing_keys_snap = [k for k in unique_key if k not in existing.columns] @@ -1066,7 +1274,10 @@ def execute(self, sql: str, params: Any | None = None) -> _SparkResult: def _split_db_table(qualified: str) -> tuple[str | None, str]: - """Split 'db.table' → (db, table); backticks allowed. Returns (None, name) if unqualified.""" + """ + Split "db.table" → (db, table); backticks allowed. + Returns (None, name) if unqualified. 
+ """ s = qualified.strip("`") parts = s.split(".") part_len = 2 diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index 44121c0..a84397b 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -1,10 +1,13 @@ # fastflowtransform/executors/duckdb.py from __future__ import annotations +import json +import re from collections.abc import Iterable from contextlib import suppress from pathlib import Path -from typing import Any +from time import perf_counter +from typing import Any, ClassVar import duckdb import pandas as pd @@ -13,6 +16,8 @@ from fastflowtransform.core import Node, relation_for from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.logging import echo from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots import resolve_snapshot_config @@ -25,6 +30,41 @@ def _q(ident: str) -> str: class DuckExecutor(BaseExecutor[pd.DataFrame]): ENGINE_NAME = "duckdb" + _FIXED_TYPE_SIZES: ClassVar[dict[str, int]] = { + "boolean": 1, + "bool": 1, + "tinyint": 1, + "smallint": 2, + "integer": 4, + "int": 4, + "bigint": 8, + "float": 4, + "real": 4, + "double": 8, + "double precision": 8, + "decimal": 16, + "numeric": 16, + "uuid": 16, + "json": 64, + "jsonb": 64, + "timestamp": 8, + "timestamp_ntz": 8, + "timestamp_ltz": 8, + "timestamptz": 8, + "date": 4, + "time": 4, + "interval": 16, + } + _VARCHAR_DEFAULT_WIDTH = 64 + _VARCHAR_MAX_WIDTH = 1024 + _DEFAULT_ROW_WIDTH = 128 + _BUDGET_GUARD = BudgetGuard( + env_var="FF_DUCKDB_MAX_BYTES", + estimator_attr="_estimate_query_bytes", + engine_label="DuckDB", + what="query", + ) + def __init__( self, db_path: str = ":memory:", schema: str | None = None, catalog: str | None = None ): @@ -36,6 +76,7 @@ def __init__( self.schema = schema.strip() 
if isinstance(schema, str) and schema.strip() else None catalog_override = catalog.strip() if isinstance(catalog, str) and catalog.strip() else None self.catalog = self._detect_catalog() + self._table_row_width_cache: dict[tuple[str | None, str], int] = {} if catalog_override: if self._apply_catalog_override(catalog_override): self.catalog = catalog_override @@ -43,12 +84,314 @@ def __init__( self.catalog = self._detect_catalog() if self.schema: safe_schema = _q(self.schema) - self.con.execute(f"create schema if not exists {safe_schema}") - self.con.execute(f"set schema '{self.schema}'") + self._execute_sql(f"create schema if not exists {safe_schema}") + self._execute_sql(f"set schema '{self.schema}'") + + def _execute_sql(self, sql: str, *args: Any, **kwargs: Any) -> duckdb.DuckDBPyConnection: + """ + Central DuckDB SQL runner. + + All model-driven SQL in this executor should go through here. + The cost guard may call _estimate_query_bytes(sql) before executing. + This wrapper also records simple per-query stats for run_results.json. + """ + estimated_bytes = self._apply_budget_guard(self._BUDGET_GUARD, sql) + if estimated_bytes is None and not self._is_budget_guard_active(): + with suppress(Exception): + estimated_bytes = self._estimate_query_bytes(sql) + t0 = perf_counter() + cursor = self.con.execute(sql, *args, **kwargs) + dt_ms = int((perf_counter() - t0) * 1000) + + rows: int | None = None + try: + rc = getattr(cursor, "rowcount", None) + if isinstance(rc, int) and rc >= 0: + rows = rc + except Exception: + rows = None + + # DuckDB doesn't expose bytes-scanned in a simple way yet → rely on the + # estimate we already collected or the best-effort fallback. 
+ self._record_query_stats( + QueryStats( + bytes_processed=estimated_bytes, + rows=rows, + duration_ms=dt_ms, + ) + ) + return cursor + + # --- Cost estimation for the shared BudgetGuard ----------------- + + def _estimate_query_bytes(self, sql: str) -> int | None: + """ + Estimate query size via DuckDB's EXPLAIN (FORMAT JSON). + + The JSON plan exposes an \"Estimated Cardinality\" per node. + We walk the parsed tree, take the highest non-zero estimate and + return it as a byte-estimate surrogate (row count ≈ bytes) so the + cost guard can still make a meaningful decision without executing + the query. + """ + try: + body = self._selectable_body(sql).strip().rstrip(";\n\t ") + except AttributeError: + body = sql.strip().rstrip(";\n\t ") + + lower = body.lower() + if not lower.startswith(("select", "with")): + return None + + explain_sql = f"EXPLAIN (FORMAT JSON) {body}" + try: + rows = self.con.execute(explain_sql).fetchall() + except Exception: + return None + + if not rows: + return None + + fragments: list[str] = [] + for row in rows: + for cell in row: + if cell is None: + continue + fragments.append(str(cell)) + + if not fragments: + return None + + plan_text = "\n".join(fragments).strip() + start = plan_text.find("[") + end = plan_text.rfind("]") + if start == -1 or end == -1 or end <= start: + return None + + try: + plan_data = json.loads(plan_text[start : end + 1]) + except Exception: + return None + + def _to_int(value: Any) -> int | None: + if value is None: + return None + if isinstance(value, (int, float)): + try: + converted = int(value) + except Exception: + return None + return converted + text = str(value) + match = re.search(r"(\d+(?:\.\d+)?)", text) + if not match: + return None + try: + return int(float(match.group(1))) + except ValueError: + return None + + def _walk_node(node: dict[str, Any]) -> int: + best = 0 + extra = node.get("extra_info") or {} + for key in ( + "Estimated Cardinality", + "estimated_cardinality", + "Cardinality", + 
"cardinality", + ): + candidate = _to_int(extra.get(key)) + if candidate is not None: + best = max(best, candidate) + candidate = _to_int(node.get("cardinality")) + if candidate is not None: + best = max(best, candidate) + for child in node.get("children") or []: + if isinstance(child, dict): + best = max(best, _walk_node(child)) + return best + + nodes: list[Any] + nodes = plan_data if isinstance(plan_data, list) else [plan_data] + + estimate = 0 + for entry in nodes: + if isinstance(entry, dict): + estimate = max(estimate, _walk_node(entry)) + + if estimate <= 0: + return None + + tables = self._collect_tables_from_plan(nodes) + row_width = self._row_width_for_tables(tables) + if row_width <= 0: + row_width = self._DEFAULT_ROW_WIDTH + + bytes_estimate = int(estimate * row_width) + return bytes_estimate if bytes_estimate > 0 else None + + def _collect_tables_from_plan(self, nodes: list[dict[str, Any]]) -> set[tuple[str | None, str]]: + tables: set[tuple[str | None, str]] = set() + + def _walk(entry: dict[str, Any]) -> None: + extra = entry.get("extra_info") or {} + table_val = extra.get("Table") + schema_val = extra.get("Schema") or extra.get("Database") or extra.get("Catalog") + if isinstance(table_val, str) and table_val.strip(): + schema, table = self._split_identifier(table_val, schema_val) + if table: + tables.add((schema, table)) + for child in entry.get("children") or []: + if isinstance(child, dict): + _walk(child) + + for node in nodes: + if isinstance(node, dict): + _walk(node) + return tables + + def _split_identifier( + self, identifier: str, explicit_schema: str | None + ) -> tuple[str | None, str]: + parts = [part.strip() for part in identifier.split(".") if part.strip()] + if not parts: + return explicit_schema, identifier + if len(parts) >= 2: + schema_candidate = self._strip_quotes(parts[-2]) + table_candidate = self._strip_quotes(parts[-1]) + return schema_candidate or explicit_schema, table_candidate + return explicit_schema, 
self._strip_quotes(parts[-1]) + + def _strip_quotes(self, value: str) -> str: + if value.startswith('"') and value.endswith('"'): + return value[1:-1] + return value + + def _row_width_for_tables(self, tables: Iterable[tuple[str | None, str]]) -> int: + widths: list[int] = [] + for schema, table in tables: + width = self._row_width_for_table(schema, table) + if width > 0: + widths.append(width) + return max(widths) if widths else 0 + + def _row_width_for_table(self, schema: str | None, table: str) -> int: + key = (schema or "", table.lower()) + cached = self._table_row_width_cache.get(key) + if cached: + return cached + + columns = self._columns_for_table(table, schema) + width = sum(self._estimate_column_width(col) for col in columns) + if width <= 0: + width = self._DEFAULT_ROW_WIDTH + self._table_row_width_cache[key] = width + return width + + def _columns_for_table( + self, table: str, schema: str | None + ) -> list[tuple[str | None, int | None, int | None, int | None]]: + table_lower = table.lower() + columns: list[tuple[str | None, int | None, int | None, int | None]] = [] + seen_schemas: set[str | None] = set() + for candidate in self._schema_candidates(schema): + if candidate in seen_schemas: + continue + seen_schemas.add(candidate) + if candidate is not None: + try: + rows = self.con.execute( + """ + select lower(data_type) as dtype, + character_maximum_length, + numeric_precision, + numeric_scale + from information_schema.columns + where lower(table_name)=lower(?) + and lower(table_schema)=lower(?) + order by ordinal_position + """, + [table_lower, candidate.lower()], + ).fetchall() + except Exception: + continue + else: + try: + rows = self.con.execute( + """ + select lower(data_type) as dtype, + character_maximum_length, + numeric_precision, + numeric_scale + from information_schema.columns + where lower(table_name)=lower(?) 
+ order by lower(table_schema), ordinal_position + """, + [table_lower], + ).fetchall() + except Exception: + continue + if rows: + return rows + return columns + + def _schema_candidates(self, schema: str | None) -> list[str | None]: + candidates: list[str | None] = [] + + def _add(value: str | None) -> None: + normalized = self._normalize_schema(value) + if normalized not in candidates: + candidates.append(normalized) + + _add(schema) + _add(self.schema) + for alt in ("main", "temp"): + _add(alt) + _add(None) + return candidates + + def _normalize_schema(self, schema: str | None) -> str | None: + if not schema: + return None + stripped = schema.strip() + return stripped or None + + def _estimate_column_width( + self, column_info: tuple[str | None, int | None, int | None, int | None] + ) -> int: + dtype_raw, char_max, numeric_precision, _ = column_info + dtype = self._normalize_data_type(dtype_raw) + if dtype and dtype in self._FIXED_TYPE_SIZES: + return self._FIXED_TYPE_SIZES[dtype] + + if dtype in {"character", "varchar", "char", "text", "string"}: + if char_max and char_max > 0: + return min(char_max, self._VARCHAR_MAX_WIDTH) + return self._VARCHAR_DEFAULT_WIDTH + + if dtype in {"varbinary", "blob", "binary"}: + if char_max and char_max > 0: + return min(char_max, self._VARCHAR_MAX_WIDTH) + return self._VARCHAR_DEFAULT_WIDTH + + if dtype in {"numeric", "decimal"} and numeric_precision and numeric_precision > 0: + return min(max(int(numeric_precision), 16), 128) + + return 16 + + def _normalize_data_type(self, dtype: str | None) -> str | None: + if not dtype: + return None + stripped = dtype.strip().lower() + if "(" in stripped: + stripped = stripped.split("(", 1)[0].strip() + if stripped.endswith("[]"): + stripped = stripped[:-2] + return stripped or None def _detect_catalog(self) -> str | None: try: - rows = self.con.execute("PRAGMA database_list").fetchall() + rows = self._execute_sql("PRAGMA database_list").fetchall() if rows: return str(rows[0][1]) except 
Exception: @@ -63,9 +406,9 @@ def _apply_catalog_override(self, name: str) -> bool: if self.db_path != ":memory:": resolved = str(Path(self.db_path).resolve()) with suppress(Exception): - self.con.execute(f"detach database {_q(alias)}") - self.con.execute(f"attach database '{resolved}' as {_q(alias)} (READ_ONLY FALSE)") - self.con.execute(f"set catalog '{alias}'") + self._execute_sql(f"detach database {_q(alias)}") + self._execute_sql(f"attach database '{resolved}' as {_q(alias)} (READ_ONLY FALSE)") + self._execute_sql(f"set catalog '{alias}'") return True except Exception: return False @@ -85,7 +428,7 @@ def _exec_many(self, sql: str) -> None: for stmt in (part.strip() for part in sql.split(";")): if not stmt: continue - self.con.execute(stmt) + self._execute_sql(stmt) # ---- Frame hooks ---- def _qualified(self, relation: str, *, quoted: bool = True) -> str: @@ -119,7 +462,7 @@ def _read_relation(self, relation: str, node: Node, deps: Iterable[str]) -> pd.D except CatalogException as e: existing = [ r[0] - for r in self.con.execute( + for r in self._execute_sql( "select table_name from information_schema.tables " "where table_schema in ('main','temp')" ).fetchall() @@ -135,19 +478,20 @@ def _materialize_relation(self, relation: str, df: pd.DataFrame, node: Node) -> try: self.con.register(tmp, df) target = self._qualified(relation) - self.con.execute(f'create or replace table {target} as select * from "{tmp}"') + self._execute_sql(f'create or replace table {target} as select * from "{tmp}"') finally: try: self.con.unregister(tmp) except Exception: - self.con.execute(f'drop view if exists "{tmp}"') + # housekeeping only; stats here are not important but harmless if recorded + self._execute_sql(f'drop view if exists "{tmp}"') def _create_or_replace_view_from_table( self, view_name: str, backing_table: str, node: Node ) -> None: view_target = self._qualified(view_name) backing = self._qualified(backing_table) - self.con.execute(f"create or replace view {view_target} 
as select * from {backing}") + self._execute_sql(f"create or replace view {view_target} as select * from {backing}") def _frame_name(self) -> str: return "pandas" @@ -193,10 +537,10 @@ def _format_source_reference( return ".".join(_q(str(part)) for part in parts if part) def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: - self.con.execute(f"create or replace view {target_sql} as {select_body}") + self._execute_sql(f"create or replace view {target_sql} as {select_body}") def _create_or_replace_table(self, target_sql: str, select_body: str, node: Node) -> None: - self.con.execute(f"create or replace table {target_sql} as {select_body}") + self._execute_sql(f"create or replace table {target_sql} as {select_body}") # ---- Meta hook ---- def on_node_built(self, node: Node, relation: str, fingerprint: str) -> None: @@ -220,20 +564,20 @@ def exists_relation(self, relation: str) -> bool: where_tables.append("table_schema in ('main','temp')") where = " AND ".join(where_tables) sql_tables = f"select 1 from information_schema.tables where {where} limit 1" - if self.con.execute(sql_tables, params).fetchone(): + if self._execute_sql(sql_tables, params).fetchone(): return True sql_views = f"select 1 from information_schema.views where {where} limit 1" - return bool(self.con.execute(sql_views, params).fetchone()) + return bool(self._execute_sql(sql_views, params).fetchone()) def create_table_as(self, relation: str, select_sql: str) -> None: # Use only the SELECT body and strip trailing semicolons for safety. body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") - self.con.execute(f"create table {self._qualified(relation)} as {body}") + self._execute_sql(f"create table {self._qualified(relation)} as {body}") def incremental_insert(self, relation: str, select_sql: str) -> None: # Ensure the inner SELECT is clean (no trailing semicolon; SELECT body only). 
body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") - self.con.execute(f"insert into {self._qualified(relation)} {body}") + self._execute_sql(f"insert into {self._qualified(relation)} {body}") def incremental_merge(self, relation: str, select_sql: str, unique_key: list[str]) -> None: """ @@ -251,11 +595,11 @@ def incremental_merge(self, relation: str, select_sql: str, unique_key: list[str # 3) first: delete collisions delete_sql = f"delete from {self._qualified(relation)} t using ({body}) s where {keys_pred}" - self.con.execute(delete_sql) + self._execute_sql(delete_sql) # 4) then: insert fresh rows insert_sql = f"insert into {self._qualified(relation)} select * from ({body}) src" - self.con.execute(insert_sql) + self._execute_sql(insert_sql) def alter_table_sync_schema( self, relation: str, select_sql: str, *, mode: str = "append_new_columns" @@ -265,12 +609,11 @@ def alter_table_sync_schema( """ # Probe: empty projection from the SELECT (cleaned to avoid parser issues). body = self._first_select_body(select_sql).strip().rstrip(";\n\t ") - probe = self.con.execute(f"select * from ({body}) as q limit 0") + probe = self._execute_sql(f"select * from ({body}) as q limit 0") cols = [c[0] for c in probe.description or []] - # vorhandene Spalten existing = { r[0] - for r in self.con.execute( + for r in self._execute_sql( "select column_name from information_schema.columns " + "where lower(table_name)=lower(?)" + (" and lower(table_schema)=lower(?)" if self.schema else ""), @@ -279,14 +622,12 @@ def alter_table_sync_schema( } add = [c for c in cols if c not in existing] for c in add: - # Typ heuristisch: typeof aus einer CAST-Probe; fallback VARCHAR col = _q(c) target = self._qualified(relation) try: - # Versuche Typ aus Expression abzuleiten (best effort) - self.con.execute(f"alter table {target} add column {col} varchar") + self._execute_sql(f"alter table {target} add column {col} varchar") except Exception: - self.con.execute(f"alter table {target} add 
column {col} varchar") + self._execute_sql(f"alter table {target} add column {col} varchar") def run_snapshot_sql(self, node: Node, env: Environment) -> None: """ @@ -371,7 +712,7 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: {hash_expr} as {hash_col} from ({body}) as s """ - self.con.execute(create_sql) + self._execute_sql(create_sql) return # ---- Incremental snapshot update ---- @@ -379,7 +720,7 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: # Stage current source rows in a temp view for reuse src_view_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") src_quoted = _q(src_view_name) - self.con.execute(f"create or replace temp view {src_quoted} as {body}") + self._execute_sql(f"create or replace temp view {src_quoted} as {body}") try: # Join predicate on unique keys @@ -413,7 +754,7 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: and t.{is_cur} = true and {change_condition}; """ - self.con.execute(close_sql) + self._execute_sql(close_sql) # 2) Insert new current versions (new keys or changed rows) first_key = unique_key[0] @@ -434,10 +775,10 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: t.{first_key} is null or {change_condition}; """ - self.con.execute(insert_sql) + self._execute_sql(insert_sql) finally: with suppress(Exception): - self.con.execute(f"drop view if exists {src_quoted}") + self._execute_sql(f"drop view if exists {src_quoted}") def snapshot_prune( self, @@ -484,7 +825,7 @@ def snapshot_prune( from ranked where rn > {int(keep_last)} """ - res = self.con.execute(sql).fetchone() + res = self._execute_sql(sql).fetchone() rows = int(res[0]) if res else 0 echo( @@ -503,4 +844,4 @@ def snapshot_prune( and {" AND ".join([f"t.{k} = r.{k}" for k in keys])} and t.{vf} = r.{vf}; """ - self.con.execute(delete_sql) + self._execute_sql(delete_sql) diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index de57fda..2800d80 
100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -1,5 +1,8 @@ # fastflowtransform/executors/postgres.py +import json from collections.abc import Iterable +from contextlib import suppress +from time import perf_counter from typing import Any import pandas as pd @@ -12,6 +15,8 @@ from fastflowtransform.errors import ModelExecutionError, ProfileConfigError from fastflowtransform.executors._shims import SAConnShim from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.logging import echo from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots import resolve_snapshot_config @@ -19,6 +24,13 @@ class PostgresExecutor(BaseExecutor[pd.DataFrame]): ENGINE_NAME = "postgres" + _DEFAULT_PG_ROW_WIDTH = 128 + _BUDGET_GUARD = BudgetGuard( + env_var="FF_PG_MAX_BYTES", + estimator_attr="_estimate_query_bytes", + engine_label="Postgres", + what="query", + ) def __init__(self, dsn: str, schema: str | None = None): """ @@ -46,6 +58,181 @@ def __init__(self, dsn: str, schema: str | None = None): # ⇣ fastflowtransform.testing expects executor.con.execute("SQL") self.con = SAConnShim(self.engine, schema=self.schema) + def _execute_sql( + self, + sql: str, + *args: Any, + conn: Connection | None = None, + **kwargs: Any, + ) -> Any: + """ + Central Postgres SQL runner. + + All model-driven SQL in this executor should go through here. + + If `conn` is provided, reuse that connection (important for temp tables / + snapshots). Otherwise, open a fresh transaction via engine.begin(). + + Also records simple per-query stats for run_results.json. 
+ """ + estimated_bytes = self._apply_budget_guard(self._BUDGET_GUARD, sql) + if estimated_bytes is None and not self._is_budget_guard_active(): + with suppress(Exception): + estimated_bytes = self._estimate_query_bytes(sql) + t0 = perf_counter() + + if conn is None: + # Standalone use: open our own transaction + with self.engine.begin() as local_conn: + self._set_search_path(local_conn) + result = local_conn.execute(text(sql), *args, **kwargs) + else: + # Reuse existing connection / transaction (e.g. in run_snapshot_sql) + self._set_search_path(conn) + result = conn.execute(text(sql), *args, **kwargs) + + dt_ms = int((perf_counter() - t0) * 1000) + + # rows: best-effort from Result.rowcount (DML only; SELECT is often -1) + rows: int | None = None + try: + rc = getattr(result, "rowcount", None) + if isinstance(rc, int) and rc >= 0: + rows = rc + except Exception: + rows = None + + self._record_query_stats( + QueryStats( + bytes_processed=estimated_bytes, + rows=rows, + duration_ms=dt_ms, + ) + ) + return result + + def _analyze_relations( + self, + relations: Iterable[str], + conn: Connection | None = None, + ) -> None: + """ + Run ANALYZE on the given relations. + + - Never goes through _execute_sql (avoids the budget guard recursion). + - Uses passed-in conn if given, otherwise opens its own transaction. + - Best-effort: logs and continues on failure. + """ + owns_conn = False + if conn is None: + conn_ctx = self.engine.begin() + conn = conn_ctx.__enter__() + owns_conn = True + try: + self._set_search_path(conn) + for rel in relations: + try: + # If it already looks qualified, leave it; otherwise qualify. + qrel = self._qualified(rel) if "." 
not in rel else rel + conn.execute(text(f"ANALYZE {qrel}")) + except Exception: + pass + finally: + if owns_conn: + conn_ctx.__exit__(None, None, None) + + # --- Cost estimation for the shared BudgetGuard ----------------- + + def _estimate_query_bytes(self, sql: str) -> int | None: + """ + Best-effort bytes estimate for a SELECT-ish query using + EXPLAIN (FORMAT JSON). + + Approximation: estimated_rows * avg_row_width (in bytes). + Returns None if: + - the query is not SELECT/CTE + - EXPLAIN fails + - the JSON structure is not what we expect + """ + body = self._extract_select_like(sql) + lower = body.lstrip().lower() + if not lower.startswith(("select", "with")): + # Only try to estimate for read-like queries + return None + + explain_sql = f"EXPLAIN (FORMAT JSON) {body}" + + try: + with self.engine.begin() as conn: + self._set_search_path(conn) + raw = conn.execute(text(explain_sql)).scalar() + except Exception: + return None + + if raw is None: + return None + + try: + data = json.loads(raw) + except Exception: + data = raw + + # Postgres JSON format: list with a single object + if isinstance(data, list) and data: + root = data[0] + elif isinstance(data, dict): + root = data + else: + return None + + plan = root.get("Plan") + if not isinstance(plan, dict): + if isinstance(root, dict) and "Node Type" in root: + plan = root + else: + return None + + return self._estimate_bytes_from_plan(plan) + + def _estimate_bytes_from_plan(self, plan: dict[str, Any]) -> int | None: + """ + Estimate bytes for the *model output* from the root plan node. + + Approximation: root.Plan Rows * root.Plan Width (or DEFAULT_PG_ROW_WIDTH + if width is missing). 
+ """ + + def _to_int(node: dict[str, Any], keys: tuple[str, ...]) -> int | None: + for key in keys: + val = node.get(key) + if val is None: + continue + try: + return int(val) + except (TypeError, ValueError): + continue + return None + + rows = _to_int(plan, ("Plan Rows", "Plan_Rows", "Rows")) + width = _to_int(plan, ("Plan Width", "Plan_Width", "Width")) + + if rows is None and width is None: + return None + + candidate: int | None + + if rows is not None and width is not None: + candidate = rows * width + elif rows is not None: + candidate = rows * self._DEFAULT_PG_ROW_WIDTH + else: + candidate = width + + if candidate is None or candidate <= 0: + return None + + return int(candidate) + # --- Helpers --------------------------------------------------------- def _q_ident(self, ident: str) -> str: # Simple, safe quoting for identifiers @@ -88,6 +275,10 @@ def _read_relation(self, relation: str, node: Node, deps: Iterable[str]) -> pd.D raise e def _materialize_relation(self, relation: str, df: pd.DataFrame, node: Node) -> None: + self._write_dataframe_with_stats(relation, df, node) + + def _write_dataframe_with_stats(self, relation: str, df: pd.DataFrame, node: Node) -> None: + start = perf_counter() try: df.to_sql( relation, @@ -101,6 +292,21 @@ def _materialize_relation(self, relation: str, df: pd.DataFrame, node: Node) -> raise ModelExecutionError( node_name=node.name, relation=self._qualified(relation), message=str(e) ) from e + else: + self._analyze_relations([relation]) + self._record_dataframe_stats(df, int((perf_counter() - start) * 1000)) + + def _record_dataframe_stats(self, df: pd.DataFrame, duration_ms: int) -> None: + rows = len(df) + bytes_estimate = int(df.memory_usage(deep=True).sum()) if rows > 0 else 0 + bytes_val = bytes_estimate if bytes_estimate > 0 else None + self._record_query_stats( + QueryStats( + bytes_processed=bytes_val, + rows=rows if rows > 0 else None, + duration_ms=duration_ms, + ) + ) # ---------- Python view helper ---------- def 
_create_or_replace_view_from_table( @@ -109,10 +315,12 @@ def _create_or_replace_view_from_table( q_view = self._qualified(view_name) q_back = self._qualified(backing_table) try: - with self.engine.begin() as c: - self._set_search_path(c) - c.execute(text(f"DROP VIEW IF EXISTS {q_view} CASCADE")) - c.execute(text(f"CREATE OR REPLACE VIEW {q_view} AS SELECT * FROM {q_back}")) + with self.engine.begin() as conn: + self._execute_sql(f"DROP VIEW IF EXISTS {q_view} CASCADE", conn=conn) + self._execute_sql( + f"CREATE OR REPLACE VIEW {q_view} AS SELECT * FROM {q_back}", conn=conn + ) + except Exception as e: raise ModelExecutionError(node.name, q_view, str(e)) from e @@ -141,10 +349,8 @@ def _format_source_reference( def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: try: - with self.engine.begin() as conn: - self._set_search_path(conn) - conn.execute(text(f"DROP VIEW IF EXISTS {target_sql} CASCADE")) - conn.execute(text(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}")) + self._execute_sql(f"DROP VIEW IF EXISTS {target_sql} CASCADE") + self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}") except Exception as e: preview = f"-- target={target_sql}\n{select_body}" raise ModelExecutionError(node.name, target_sql, str(e), sql_snippet=preview) from e @@ -155,10 +361,9 @@ def _create_or_replace_table(self, target_sql: str, select_body: str, node: Node Use DROP TABLE IF EXISTS + CREATE TABLE AS, and accept CTE bodies. 
""" try: - with self.engine.begin() as conn: - self._set_search_path(conn) - conn.execute(text(f"DROP TABLE IF EXISTS {target_sql} CASCADE")) - conn.execute(text(f"CREATE TABLE {target_sql} AS {select_body}")) + self._execute_sql(f"DROP TABLE IF EXISTS {target_sql} CASCADE") + self._execute_sql(f"CREATE TABLE {target_sql} AS {select_body}") + self._analyze_relations([target_sql]) except Exception as e: preview = f"-- target={target_sql}\n{select_body}" raise ModelExecutionError(node.name, target_sql, str(e), sql_snippet=preview) from e @@ -176,8 +381,7 @@ def exists_relation(self, relation: str) -> bool: """ Return True if a table OR view exists for 'relation' in current schema. """ - sql = text( - """ + sql = """ select 1 from information_schema.tables where table_schema = current_schema() @@ -189,17 +393,14 @@ def exists_relation(self, relation: str) -> bool: and lower(table_name) = lower(:t) limit 1 """ - ) - with self.engine.begin() as con: - self._set_search_path(con) - return bool(con.execute(sql, {"t": relation}).fetchone()) + + return bool(self._execute_sql(sql, {"t": relation}).fetchone()) def create_table_as(self, relation: str, select_sql: str) -> None: body = self._extract_select_like(select_sql) qrel = self._qualified(relation) - with self.engine.begin() as con: - self._set_search_path(con) - con.execute(text(f"create table {qrel} as {body}")) + self._execute_sql(f"create table {qrel} as {body}") + self._analyze_relations([relation]) def full_refresh_table(self, relation: str, select_sql: str) -> None: """ @@ -208,17 +409,15 @@ def full_refresh_table(self, relation: str, select_sql: str) -> None: """ body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") qrel = self._qualified(relation) - with self.engine.begin() as conn: - self._set_search_path(conn) - conn.execute(text(f"drop table if exists {qrel}")) - conn.execute(text(f"create table {qrel} as {body}")) + self._execute_sql(f"drop table if exists {qrel}") + self._execute_sql(f"create 
table {qrel} as {body}") + self._analyze_relations([relation]) def incremental_insert(self, relation: str, select_sql: str) -> None: body = self._extract_select_like(select_sql) qrel = self._qualified(relation) - with self.engine.begin() as con: - self._set_search_path(con) - con.execute(text(f"insert into {qrel} {body}")) + self._execute_sql(f"insert into {qrel} {body}") + self._analyze_relations([relation]) def incremental_merge(self, relation: str, select_sql: str, unique_key: list[str]) -> None: """ @@ -227,14 +426,13 @@ def incremental_merge(self, relation: str, select_sql: str, unique_key: list[str body = self._extract_select_like(select_sql) qrel = self._qualified(relation) pred = " AND ".join([f"t.{k}=s.{k}" for k in unique_key]) - with self.engine.begin() as con: - self._set_search_path(con) - con.execute(text(f"create temporary table ff_stg as {body}")) - try: - con.execute(text(f"delete from {qrel} t using ff_stg s where {pred}")) - con.execute(text(f"insert into {qrel} select * from ff_stg")) - finally: - con.execute(text("drop table if exists ff_stg")) + self._execute_sql(f"create temporary table ff_stg as {body}") + try: + self._execute_sql(f"delete from {qrel} t using ff_stg s where {pred}") + self._execute_sql(f"insert into {qrel} select * from ff_stg") + self._analyze_relations([relation]) + finally: + self._execute_sql("drop table if exists ff_stg") def alter_table_sync_schema( self, relation: str, select_sql: str, *, mode: str = "append_new_columns" @@ -244,22 +442,28 @@ def alter_table_sync_schema( """ body = self._extract_select_like(select_sql) qrel = self._qualified(relation) - with self.engine.begin() as con: - self._set_search_path(con) - cols = [r[0] for r in con.execute(text(f"select * from ({body}) q limit 0"))] + + with self.engine.begin() as conn: + # Probe output columns + cols = [r[0] for r in self._execute_sql(f"select * from ({body}) q limit 0")] + + # Existing columns in target table existing = { r[0] - for r in con.execute( - 
text( - "select column_name from information_schema.columns " - "where table_schema = current_schema() and lower(table_name)=lower(:t)" - ), + for r in self._execute_sql( + """ + select column_name + from information_schema.columns + where table_schema = current_schema() + and lower(table_name)=lower(:t) + """, {"t": relation}, ).fetchall() } + add = [c for c in cols if c not in existing] for c in add: - con.execute(text(f'alter table {qrel} add column "{c}" text')) + self._execute_sql(f'alter table {qrel} add column "{c}" text', conn=conn) # ── Snapshot API ────────────────────────────────────────────────────── @@ -346,9 +550,7 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: {hash_expr} as {hash_col} from ({body}) as s """ - with self.engine.begin() as conn: - self._set_search_path(conn) - conn.execute(text(create_sql)) + self._execute_sql(create_sql) return # ---- Incremental snapshot update ---- @@ -358,11 +560,9 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: src_q = self._q_ident(src_name) with self.engine.begin() as conn: - self._set_search_path(conn) - # (Re-)create temp staging table - conn.execute(text(f"drop table if exists {src_q}")) - conn.execute(text(f"create temporary table {src_q} as {body}")) + self._execute_sql(f"drop table if exists {src_q}", conn=conn) + self._execute_sql(f"create temporary table {src_q} as {body}", conn=conn) # Join predicate on unique keys keys_pred = " AND ".join([f"t.{k} = s.{k}" for k in unique_key]) @@ -397,7 +597,7 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: and t.{is_cur} = true and {change_condition}; """ - conn.execute(text(close_sql)) + self._execute_sql(close_sql, conn=conn) # 2) Insert new current versions (new keys or changed rows) first_key = unique_key[0] @@ -418,11 +618,12 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: t.{first_key} is null or {change_condition}; """ - conn.execute(text(insert_sql)) + 
self._execute_sql(insert_sql, conn=conn) # Temp table will be dropped automatically at end of session; dropping # explicitly here is harmless and keeps the connection clean for tests. - conn.execute(text(f"drop table if exists {src_q}")) + self._execute_sql(f"drop table if exists {src_q}", conn=conn) + self._analyze_relations([target], conn=conn) def snapshot_prune( self, @@ -468,10 +669,8 @@ def snapshot_prune( from ranked where rn > {int(keep_last)} """ - with self.engine.begin() as conn: - self._set_search_path(conn) - res = conn.execute(text(sql)).fetchone() - rows = int(res[0]) if res else 0 + res = self._execute_sql(sql).fetchone() + rows = int(res[0]) if res else 0 echo( f"[DRY-RUN] snapshot_prune({relation}): would delete {rows} row(s) " f"(keep_last={keep_last})" @@ -488,6 +687,5 @@ def snapshot_prune( and {" AND ".join([f"t.{k} = r.{k}" for k in keys])} and t.{vf} = r.{vf}; """ - with self.engine.begin() as conn: - self._set_search_path(conn) - conn.execute(text(delete_sql)) + self._execute_sql(delete_sql) + self._analyze_relations([relation]) diff --git a/src/fastflowtransform/executors/query_stats.py b/src/fastflowtransform/executors/query_stats.py new file mode 100644 index 0000000..a4d58b7 --- /dev/null +++ b/src/fastflowtransform/executors/query_stats.py @@ -0,0 +1,52 @@ +# fastflowtransform/executors/query_stats.py +from __future__ import annotations + +from collections.abc import Callable +from contextlib import suppress +from dataclasses import dataclass +from typing import Any + + +@dataclass +class QueryStats: + """ + Normalised query metrics. + + Engines can fill whichever fields they support; unsupported ones + can simply stay None. + """ + + bytes_processed: int | None = None + rows: int | None = None + duration_ms: int | None = None + + +class _TrackedQueryJob: + """ + Thin proxy around an engine-specific query job (BigQuery, Snowflake, Spark...). + + - Forwards all unknown attributes to the wrapped job. 
+ - Intercepts `.result(...)` and calls `on_complete(inner_job)` exactly once. + - Never raises from the callback; stats collection is strictly best-effort. + """ + + def __init__(self, inner_job: Any, *, on_complete: Callable[[Any], None]) -> None: + self._inner_job = inner_job + self._on_complete = on_complete + self._done = False + + def result(self, *args: Any, **kwargs: Any) -> Any: + # Delegate to the underlying job + res = self._inner_job.result(*args, **kwargs) + + # Fire the completion callback once + if not self._done: + self._done = True + with suppress(Exception): + self._on_complete(self._inner_job) + + return res + + def __getattr__(self, name: str) -> Any: + # Forward all other attributes to the underlying job + return getattr(self._inner_job, name) diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index d60fd8a..d166bf8 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -1,14 +1,18 @@ # src/fastflowtransform/executors/snowflake_snowpark.py from __future__ import annotations +import json from collections.abc import Iterable from contextlib import suppress +from time import perf_counter from typing import Any, cast from jinja2 import Environment from fastflowtransform.core import Node, relation_for from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.logging import echo from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots import resolve_snapshot_config @@ -18,6 +22,12 @@ class SnowflakeSnowparkExecutor(BaseExecutor[SNDF]): ENGINE_NAME = "snowflake_snowpark" """Snowflake executor operating on Snowpark DataFrames (no pandas).""" + _BUDGET_GUARD = BudgetGuard( + env_var="FF_SF_MAX_BYTES", + 
estimator_attr="_estimate_query_bytes", + engine_label="Snowflake", + what="query", + ) def __init__(self, cfg: dict): # cfg: {account, user, password, warehouse, database, schema, role?} @@ -31,6 +41,107 @@ def __init__(self, cfg: dict): # Provide a tiny testing shim so tests can call executor.con.execute("SQL") self.con = _SFCursorShim(self.session) + # ---------- Cost estimation & central execution ---------- + + def _estimate_query_bytes(self, sql: str) -> int | None: + """ + Best-effort Snowflake bytes estimation. + + Uses `EXPLAIN USING JSON` and tries to extract a bytes-assigned + metric from the JSON plan. If parsing fails or Snowflake doesn't + expose such info, returns None and the guard is effectively disabled. + """ + try: + body = self._selectable_body(sql) + except Exception: + body = sql + + try: + rows = self.session.sql(f"EXPLAIN USING JSON {body}").collect() + if not rows: + return None + + parts: list[str] = [] + for r in rows: + try: + parts.append(str(r[0])) + except Exception: + as_dict: dict[str, Any] = getattr(r, "asDict", lambda: {})() + if as_dict: + parts.extend(str(v) for v in as_dict.values()) + + plan_text = "\n".join(parts).strip() + if not plan_text: + return None + + try: + plan_data = json.loads(plan_text) + except Exception: + return None + + bytes_val = self._extract_bytes_from_plan(plan_data) + if bytes_val is None or bytes_val <= 0: + return None + return bytes_val + except Exception: + # Any parsing / EXPLAIN issues → no estimate, guard skipped + return None + + def _extract_bytes_from_plan(self, plan_data: Any) -> int | None: + def _to_int(value: Any) -> int | None: + if value is None: + return None + try: + return int(value) + except Exception: + return None + + if isinstance(plan_data, dict): + global_stats = plan_data.get("GlobalStats") or plan_data.get("globalStats") + if isinstance(global_stats, dict): + candidate = _to_int( + global_stats.get("bytesAssigned") or global_stats.get("bytes_assigned") + ) + if candidate:
+ return candidate + for val in plan_data.values(): + bytes_val = self._extract_bytes_from_plan(val) + if bytes_val: + return bytes_val + elif isinstance(plan_data, list): + for item in plan_data: + bytes_val = self._extract_bytes_from_plan(item) + if bytes_val: + return bytes_val + return None + + def _execute_sql(self, sql: str) -> SNDF: + """ + Central Snowflake SQL runner. + + - Returns a Snowpark DataFrame (same as session.sql). + - Records best-effort query stats for run_results.json. + """ + estimated_bytes = self._apply_budget_guard(self._BUDGET_GUARD, sql) + if estimated_bytes is None and not self._is_budget_guard_active(): + with suppress(Exception): + estimated_bytes = self._estimate_query_bytes(sql) + t0 = perf_counter() + df = self.session.sql(sql) + dt_ms = int((perf_counter() - t0) * 1000) + + # We *don't* call df.count() here - that would execute the query again. + # For Snowflake we also don't cheaply access bytes/rows here; the cost + # guard already did a dry-run EXPLAIN if FF_SF_MAX_BYTES is set. + self._record_query_stats( + QueryStats( + bytes_processed=estimated_bytes, + rows=None, + duration_ms=dt_ms, + ) + ) + return df + # ---------- Helpers ---------- def _q(self, s: str) -> str: return '"' + s.replace('"', '""') + '"' @@ -79,12 +190,85 @@ def _materialize_relation(self, relation: str, df: SNDF, node: Node) -> None: if cols != upper_cols: df = df.toDF(*upper_cols) + start = perf_counter() df.write.save_as_table(self._qualified(relation), mode="overwrite") + duration_ms = int((perf_counter() - start) * 1000) + bytes_est = self._estimate_frame_bytes(df) + self._record_query_stats( + QueryStats( + bytes_processed=bytes_est, + rows=None, + duration_ms=duration_ms, + ) + ) + + def _estimate_frame_bytes(self, df: SNDF) -> int | None: + """ + Best-effort bytes estimate for a Snowpark DataFrame. + + Strategy: + 1) Use DataFrame.queries["queries"] (public Snowpark API) to get SQL. 
+ 2) Optionally fall back to df._plan.sql() if queries is missing/empty. + 3) Run our existing _estimate_query_bytes(sql_text). + """ + try: + sql_text = self._snowpark_df_sql(df) + if not isinstance(sql_text, str) or not sql_text.strip(): + return None + return self._estimate_query_bytes(sql_text) + except Exception: + return None + + def _snowpark_df_sql(self, df: Any) -> str | None: + """ + Extract the main SQL statement for a Snowpark DataFrame. + + Uses the documented public APIs: + - DataFrame.queries -> {"queries": [sql1, sql2, ...], "post_actions": [...]} + - Optionally falls back to df._plan.sql() if needed. + """ + # 1) Primary source: DataFrame.queries + queries_dict = getattr(df, "queries", None) + + if isinstance(queries_dict, dict): + queries = queries_dict.get("queries") + if isinstance(queries, list) and queries: + # Pick the most likely "main" query. + # Snowflake examples use queries['queries'][0], + # but we can be a bit safer and pick the longest non-empty SQL. + candidates = [q.strip() for q in queries if isinstance(q, str) and q.strip()] + if candidates: + # Heuristic: longest SQL string is usually the main SELECT/CTE. 
+ return max(candidates, key=len) + + # 2) Fallback: internal plan (undocumented but widely used) + plan = getattr(df, "_plan", None) + if plan is not None: + # Prefer simplified plan if available + with suppress(Exception): + simplify = getattr(plan, "simplify", None) + if callable(simplify): + simplified = simplify() + to_sql = getattr(simplified, "sql", None) + if callable(to_sql): + sql = to_sql() + if isinstance(sql, str) and sql.strip(): + return sql.strip() + + # Raw plan.sql() + with suppress(Exception): + to_sql = getattr(plan, "sql", None) + if callable(to_sql): + sql = to_sql() + if isinstance(sql, str) and sql.strip(): + return sql.strip() + + return None def _create_view_over_table(self, view_name: str, backing_table: str, node: Node) -> None: qv = self._qualified(view_name) qb = self._qualified(backing_table) - self.session.sql(f"CREATE OR REPLACE VIEW {qv} AS SELECT * FROM {qb}").collect() + self._execute_sql(f"CREATE OR REPLACE VIEW {qv} AS SELECT * FROM {qb}").collect() def _validate_required( self, node_name: str, inputs: Any, requires: dict[str, set[str]] @@ -168,17 +352,17 @@ def _format_source_reference( return f"{db}.{sch}.{ident}" def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: - self.session.sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}").collect() + self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}").collect() def _create_or_replace_table(self, target_sql: str, select_body: str, node: Node) -> None: - self.session.sql(f"CREATE OR REPLACE TABLE {target_sql} AS {select_body}").collect() + self._execute_sql(f"CREATE OR REPLACE TABLE {target_sql} AS {select_body}").collect() def _create_or_replace_view_from_table( self, view_name: str, backing_table: str, node: Node ) -> None: view_id = self._qualified(view_name) back_id = self._qualified(backing_table) - self.session.sql(f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}").collect() + 
self._execute_sql(f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}").collect() def _format_test_table(self, table: str | None) -> str | None: formatted = super()._format_test_table(table) @@ -213,24 +397,28 @@ def exists_relation(self, relation: str) -> bool: limit 1 """ try: - return bool(self.session.sql(q).collect()) + return bool(self._execute_sql(q).collect()) except Exception: return False def create_table_as(self, relation: str, select_sql: str) -> None: body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") - self.session.sql(f"CREATE OR REPLACE TABLE {self._qualified(relation)} AS {body}").collect() + self._execute_sql( + f"CREATE OR REPLACE TABLE {self._qualified(relation)} AS {body}" + ).collect() def full_refresh_table(self, relation: str, select_sql: str) -> None: """ Engine-specific full refresh for incremental fallbacks. """ body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") - self.session.sql(f"CREATE OR REPLACE TABLE {self._qualified(relation)} AS {body}").collect() + self._execute_sql( + f"CREATE OR REPLACE TABLE {self._qualified(relation)} AS {body}" + ).collect() def incremental_insert(self, relation: str, select_sql: str) -> None: body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") - self.session.sql(f"INSERT INTO {self._qualified(relation)} {body}").collect() + self._execute_sql(f"INSERT INTO {self._qualified(relation)} {body}").collect() def incremental_merge(self, relation: str, select_sql: str, unique_key: list[str]) -> None: body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") @@ -243,11 +431,11 @@ def incremental_merge(self, relation: str, select_sql: str, unique_key: list[str USING ({body}) AS s WHERE {pred} """ - self.session.sql(delete_sql).collect() + self._execute_sql(delete_sql).collect() # 2) Insert all rows from the delta insert_sql = f"INSERT INTO {qrel} SELECT * FROM ({body})" - self.session.sql(insert_sql).collect() + self._execute_sql(insert_sql).collect() def 
alter_table_sync_schema( self, relation: str, select_sql: str, *, mode: str = "append_new_columns" @@ -270,7 +458,7 @@ def alter_table_sync_schema( try: existing = { r[0] - for r in self.session.sql( + for r in self._execute_sql( f""" select column_name from {db_ident}.information_schema.columns @@ -293,7 +481,7 @@ def alter_table_sync_schema( # Column names are identifiers → _q is correct here cols_sql = ", ".join(f"{self._q(c)} STRING" for c in to_add) - self.session.sql(f"ALTER TABLE {qrel} ADD COLUMN {cols_sql}").collect() + self._execute_sql(f"ALTER TABLE {qrel} ADD COLUMN {cols_sql}").collect() # ── Snapshot API ───────────────────────────────────────────────────── def run_snapshot_sql(self, node: Node, env: Environment) -> None: @@ -374,14 +562,14 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: {hash_expr} AS {hash_col} FROM ({body}) AS s """ - self.session.sql(create_sql).collect() + self._execute_sql(create_sql).collect() return # ---- Incremental snapshot update ---- src_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") # Use a temporary view for the current source rows - self.session.sql(f"CREATE OR REPLACE TEMPORARY VIEW {src_name} AS {body}").collect() + self._execute_sql(f"CREATE OR REPLACE TEMPORARY VIEW {src_name} AS {body}").collect() try: keys_pred = " AND ".join([f"t.{k} = s.{k}" for k in cfg.unique_key]) or "FALSE" @@ -419,7 +607,7 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: AND t.{is_cur} = TRUE AND {change_condition} """ - self.session.sql(close_sql).collect() + self._execute_sql(close_sql).collect() # 2) Insert new current versions (new keys or changed rows) first_key = cfg.unique_key[0] @@ -440,10 +628,10 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: t.{first_key} IS NULL OR {change_condition} """ - self.session.sql(insert_sql).collect() + self._execute_sql(insert_sql).collect() finally: with suppress(Exception): - self.session.sql(f"DROP VIEW IF EXISTS 
{src_name}").collect() + self._execute_sql(f"DROP VIEW IF EXISTS {src_name}").collect() def snapshot_prune( self, @@ -490,7 +678,7 @@ def snapshot_prune( FROM ranked WHERE rn > {int(keep_last)} """ - res_raw = self.session.sql(sql).collect() + res_raw = self._execute_sql(sql).collect() # Snowflake returns a list of Row objects; treat them as tuples for typing. res = cast("list[tuple[Any, ...]]", res_raw) rows = int(res[0][0]) if res else 0 @@ -511,7 +699,7 @@ def snapshot_prune( AND {" AND ".join([f"t.{k} = r.{k}" for k in keys])} AND t.{vf} = r.{vf} """ - self.session.sql(delete_sql).collect() + self._execute_sql(delete_sql).collect() # ────────────────────────── local testing shim ─────────────────────────── diff --git a/src/fastflowtransform/incremental.py b/src/fastflowtransform/incremental.py index bc00d5a..cc70c65 100644 --- a/src/fastflowtransform/incremental.py +++ b/src/fastflowtransform/incremental.py @@ -73,7 +73,13 @@ def _is_merge_not_supported_error(exc: Exception) -> bool: def _exec_sql(exe: Any, sql: str) -> None: """Best-effort SQL execution across engines (DuckDB/PG/Snowflake/BQ shims).""" - if hasattr(exe, "con") and hasattr(exe.con, "execute"): # DuckDB + # Prefer an engine-provided '_execute_sql' hook if available. + hook = getattr(exe, "_execute_sql", None) + if callable(hook): + hook(sql) + return + + if hasattr(exe, "con") and hasattr(exe.con, "execute"): # DuckDB / BQ shim etc. 
exe.con.execute(sql) return if hasattr(exe, "engine"): # SQLAlchemy Engine diff --git a/src/fastflowtransform/run_executor.py b/src/fastflowtransform/run_executor.py index cd2ba49..75547ed 100644 --- a/src/fastflowtransform/run_executor.py +++ b/src/fastflowtransform/run_executor.py @@ -1,6 +1,7 @@ # src/fastflowtransform/run_executor.py from __future__ import annotations +import os import threading from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor, as_completed @@ -9,7 +10,7 @@ from time import perf_counter from typing import Literal -from .log_queue import LogQueue +from fastflowtransform.log_queue import LogQueue FailPolicy = Literal["fail_fast", "keep_going"] @@ -21,6 +22,15 @@ class ScheduleResult: failed: dict[str, BaseException] +class _NodeFailed(Exception): + """Wrapper to propagate which node inside a batch failed.""" + + def __init__(self, name: str, error: BaseException) -> None: + super().__init__(str(error)) + self.name = name + self.error = error + + # ----------------- Helpers (außerhalb von `schedule`) ----------------- @@ -65,6 +75,103 @@ def _call_before(before_cb: Callable[..., None] | None, name: str, lvl_idx: int) before_cb(name) # Legacy: nur (name) +def _resolve_jobs( + jobs: int | str, + level_size: int, + engine_abbr: str, +) -> int: + """ + Turn the CLI `jobs` param (int or 'auto') into an effective max_workers + for a given level. + + Rules: + - If level is empty → 0 (no executor). + - If jobs is an int: + * <= 0 → treated as 1 + * > 0 → min(jobs, level_size) + - If jobs is 'auto': + * Base concurrency depends on engine + CPU count. + * Never exceed level_size. + - If jobs is an invalid string: + * Fallback to 1. 
+ """ + if level_size <= 0: + return 0 + + # numeric → cap at level_size, enforce >=1 + if isinstance(jobs, int): + if jobs <= 0: + return 1 + return min(jobs, level_size) + + # string → allow 'auto' or numeric fallback + text = jobs.strip().lower() + if text != "auto": + # try to interpret as integer anyway + try: + val = int(text) + except ValueError: + return 1 + return _resolve_jobs(val, level_size, engine_abbr) + + # 'auto' mode + cpu = os.cpu_count() or 4 + + # very rough heuristics per engine: + # - BQ: more conservative; remote service with quotas + # - SPK: allow a bit higher, it parallelizes internally + # - others: moderate parallelism + if engine_abbr == "BQ": + base = min(cpu, 4) + elif engine_abbr == "SPK": + base = min(cpu * 2, 16) + elif engine_abbr in {"DUCK", "PG", "SNOW"}: + base = min(cpu, 8) + else: + base = cpu + + # never exceed level size, always at least 1 + return max(1, min(base, level_size)) + + +def _partition_groups( + names: list[str], + max_workers: int, + durations_s: dict[str, float] | None, +) -> list[list[str]]: + """ + Partition `names` into <= max_workers batches. + + If durations are known, use a greedy bin-packing to balance total + estimated time per batch. Otherwise, fall back to round-robin. 
+ """ + n = len(names) + if n == 0: + return [] + if max_workers <= 1 or n == 1: + return [names] + + k = min(max_workers, n) + + if not durations_s: + groups: list[list[str]] = [[] for _ in range(k)] + for idx, name in enumerate(names): + groups[idx % k].append(name) + return groups + + # Greedy by descending duration + sorted_names = sorted(names, key=lambda nm: durations_s.get(nm, 0.0), reverse=True) + groups = [[] for _ in range(k)] + loads: list[float] = [0.0] * k + + for nm in sorted_names: + j = loads.index(min(loads)) + groups[j].append(nm) + loads[j] += durations_s.get(nm, 0.0) + + return groups + + def _make_task( lvl_idx: int, before_cb: Callable[..., None] | None, @@ -72,7 +179,7 @@ def _make_task( per_node: dict[str, float], per_node_lock: threading.Lock, ) -> Callable[[str], None]: - def _task(name: str, _lvl: int = lvl_idx) -> None: # bindet lvl_idx → kein B023 + def _task(name: str, _lvl: int = lvl_idx) -> None: if before_cb is not None: _call_before(before_cb, name, _lvl) t0 = perf_counter() @@ -89,7 +196,7 @@ def _task(name: str, _lvl: int = lvl_idx) -> None: # bindet lvl_idx → kein B0 def _run_level( lvl_idx: int, names: list[str], - jobs: int, + jobs: int | str, fail_policy: FailPolicy, before_cb: Callable[..., None] | None, run_node: Callable[[str], None], @@ -100,48 +207,97 @@ def _run_level( engine_abbr: str, name_width: int, name_formatter: Callable[[str], str] | None, + durations_s: dict[str, float] | None, ) -> tuple[bool, int, int, int]: """Executes one level and logs. 
Returns: (had_error, ok_count, fail_count, lvl_ms).""" if not names: return False, 0, 0, 0 - task = _make_task(lvl_idx, before_cb, run_node, per_node, per_node_lock) + # Adaptive per-level worker count (supports '--jobs auto') + max_workers = _resolve_jobs(jobs, len(names), engine_abbr) + if max_workers <= 0: + # no work or something odd → treat as no-op + return False, 0, 0, 0 + + groups = _partition_groups(names, max_workers, durations_s) + lvl_t0 = perf_counter() - ok_in_level = 0 - fail_in_level = 0 level_had_error = False - display_names: dict[str, str] = {} + prev_nodes = set(per_node.keys()) + prev_failed = set(failed.keys()) - with ThreadPoolExecutor(max_workers=max(1, int(jobs)), thread_name_prefix="ff-worker") as pool: - futures: dict[Future[None], str] = {} - for nm in names: + def _group_task(group_names: list[str]) -> None: + for nm in group_names: label = name_formatter(nm) if name_formatter else nm - display_names[nm] = label _log_start(logger, lvl_idx, engine_abbr, label, name_width) - futures[pool.submit(task, nm)] = nm + t0_node = perf_counter() + try: + if before_cb is not None: + _call_before(before_cb, nm, lvl_idx) + run_node(nm) + except BaseException as e: + dt = perf_counter() - t0_node + with per_node_lock: + per_node[nm] = dt + _log_end( + logger, + lvl_idx, + engine_abbr, + label, + False, + int(dt * 1000), + name_width, + ) + # propagate which node failed inside the batch + raise _NodeFailed(nm, e) from e + else: + dt = perf_counter() - t0_node + with per_node_lock: + per_node[nm] = dt + _log_end( + logger, + lvl_idx, + engine_abbr, + label, + True, + int(dt * 1000), + name_width, + ) + + with ThreadPoolExecutor( + max_workers=len(groups), + thread_name_prefix="ff-worker", + ) as pool: + futures: dict[Future[None], list[str]] = {} + for grp in groups: + futures[pool.submit(_group_task, grp)] = grp for fut in as_completed(futures): - nm = futures[fut] - label = display_names.get(nm, nm) try: fut.result() - ok_in_level += 1 - _log_end( - 
logger, lvl_idx, engine_abbr, label, True, int(per_node[nm] * 1000), name_width - ) + except _NodeFailed as e: + level_had_error = True + failed[e.name] = e.error + if fail_policy == "fail_fast": + for f in futures: + if not f.done(): + f.cancel() except BaseException as e: level_had_error = True - failed[nm] = e - fail_in_level += 1 - ms = int((per_node.get(nm, perf_counter() - lvl_t0)) * 1000) - _log_end(logger, lvl_idx, engine_abbr, label, False, ms, name_width) + # No specific node known; attach under a synthetic key + failed[f""] = e if fail_policy == "fail_fast": for f in futures: if not f.done(): f.cancel() lvl_ms = int((perf_counter() - lvl_t0) * 1000) + new_nodes = set(per_node.keys()) - prev_nodes + new_failed = set(failed.keys()) - prev_failed + ok_in_level = len(new_nodes - new_failed) + fail_in_level = len(new_failed) + _log_level_summary(logger, lvl_idx, ok_in_level, fail_in_level, lvl_ms) return level_had_error, ok_in_level, fail_in_level, lvl_ms @@ -151,7 +307,7 @@ def _run_level( def schedule( levels: list[list[str]], - jobs: int, + jobs: int | str, fail_policy: FailPolicy, run_node: Callable[[str], None], before: Callable[..., None] | None = None, @@ -160,6 +316,7 @@ def schedule( engine_abbr: str = "", name_width: int = 28, name_formatter: Callable[[str], str] | None = None, + durations_s: dict[str, float] | None = None, ) -> ScheduleResult: """Run levels sequentially; within a level run up to `jobs` nodes in parallel.""" per_node: dict[str, float] = {} @@ -182,6 +339,7 @@ def schedule( engine_abbr=engine_abbr, name_width=name_width, name_formatter=name_formatter, + durations_s=durations_s, ) if had_error: if on_error: diff --git a/src/fastflowtransform/seeding.py b/src/fastflowtransform/seeding.py index 2fa9c46..008d4c0 100644 --- a/src/fastflowtransform/seeding.py +++ b/src/fastflowtransform/seeding.py @@ -511,6 +511,10 @@ def _handle_sqlalchemy(table: str, df: pd.DataFrame, executor: Any, schema: str df.to_sql(table, eng, if_exists="replace", 
index=False, schema=schema, method="multi") dt_ms = int((perf_counter() - t0) * 1000) + if dialect_name.lower() == "postgresql": + with eng.begin() as conn: + conn.exec_driver_sql(f"ANALYZE {full_name}") + dialect = dialect_name or getattr(getattr(eng, "dialect", None), "name", "sqlalchemy") _echo_seed_line( full_name=full_name, diff --git a/src/fastflowtransform/table_formats/__init__.py b/src/fastflowtransform/table_formats/__init__.py index 06dbf2b..110c9bc 100644 --- a/src/fastflowtransform/table_formats/__init__.py +++ b/src/fastflowtransform/table_formats/__init__.py @@ -1,16 +1,16 @@ # fastflowtransform/table_formats/__init__.py from __future__ import annotations +from collections.abc import Callable from typing import Any +from fastflowtransform.table_formats.base import SparkFormatHandler +from fastflowtransform.table_formats.spark_default import DefaultSparkFormatHandler +from fastflowtransform.table_formats.spark_delta import DeltaFormatHandler +from fastflowtransform.table_formats.spark_hudi import HudiFormatHandler +from fastflowtransform.table_formats.spark_iceberg import IcebergFormatHandler from fastflowtransform.typing import SparkSession -from .base import SparkFormatHandler -from .spark_default import DefaultSparkFormatHandler -from .spark_delta import DeltaFormatHandler -from .spark_hudi import HudiFormatHandler -from .spark_iceberg import IcebergFormatHandler - # Mapping: normalized format name -> handler class _SPARK_FORMAT_REGISTRY: dict[str, type[SparkFormatHandler]] = { "delta": DeltaFormatHandler, @@ -37,6 +37,7 @@ def get_spark_format_handler( spark: SparkSession, *, table_options: dict[str, Any] | None = None, + sql_runner: Callable[[str], Any] | None = None, ) -> SparkFormatHandler: """ Factory for SparkFormatHandler based on a logical format name. 
@@ -56,4 +57,5 @@ def get_spark_format_handler( spark, table_format=fmt or None, table_options=table_options or {}, + sql_runner=sql_runner, ) diff --git a/src/fastflowtransform/table_formats/base.py b/src/fastflowtransform/table_formats/base.py index 5b3b7a7..beed55e 100644 --- a/src/fastflowtransform/table_formats/base.py +++ b/src/fastflowtransform/table_formats/base.py @@ -2,6 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Any from fastflowtransform.typing import SDF, SparkSession @@ -28,6 +29,7 @@ def __init__( *, table_format: str | None = None, table_options: dict[str, Any] | None = None, + sql_runner: Callable[[str], Any] | None = None, ) -> None: self.spark = spark self.table_format: str | None = (table_format or "").lower() or None @@ -36,6 +38,14 @@ def __init__( str(k): str(v) for k, v in (table_options or {}).items() } + # central hook for executing SQL (can be engine-guarded) + self._sql_runner: Callable[[str], Any] = sql_runner or spark.sql + + # ---- SQL helper ---- + def run_sql(self, sql: str) -> Any: + """Execute SQL via the injected runner (guardable in the executor).""" + return self._sql_runner(sql) + # ---- Identifier helpers ---- def qualify_identifier(self, table_name: str, *, database: str | None = None) -> str: """Return the physical table identifier for Spark APIs (unquoted).""" @@ -95,7 +105,7 @@ def incremental_insert(self, table_name: str, select_body_sql: str) -> None: if not body.lower().startswith("select"): # This is a guard; DatabricksSparkExecutor uses _selectable_body already. 
raise ValueError(f"incremental_insert expects SELECT body, got: {body[:40]!r}") - self.spark.sql(f"INSERT INTO {table_name} {body}") + self.run_sql(f"INSERT INTO {table_name} {body}") def incremental_merge( self, diff --git a/src/fastflowtransform/table_formats/spark_default.py b/src/fastflowtransform/table_formats/spark_default.py index f1735d4..5fdcef8 100644 --- a/src/fastflowtransform/table_formats/spark_default.py +++ b/src/fastflowtransform/table_formats/spark_default.py @@ -1,6 +1,7 @@ # fastflowtransform/table_formats/spark_default.py from __future__ import annotations +from collections.abc import Callable from typing import Any from fastflowtransform.table_formats.base import SparkFormatHandler @@ -25,8 +26,14 @@ def __init__( *, table_format: str | None = None, table_options: dict[str, Any] | None = None, + sql_runner: Callable[[str], Any] | None = None, ) -> None: - super().__init__(spark, table_format=table_format, table_options=table_options) + super().__init__( + spark, + table_format=table_format, + table_options=table_options, + sql_runner=sql_runner, + ) def save_df_as_table(self, table_name: str, df: SDF) -> None: """ diff --git a/src/fastflowtransform/table_formats/spark_delta.py b/src/fastflowtransform/table_formats/spark_delta.py index b858b8f..1be9325 100644 --- a/src/fastflowtransform/table_formats/spark_delta.py +++ b/src/fastflowtransform/table_formats/spark_delta.py @@ -1,6 +1,7 @@ # fastflowtransform/table_formats/spark_delta.py from __future__ import annotations +from collections.abc import Callable from contextlib import suppress from typing import TYPE_CHECKING, Any @@ -41,8 +42,14 @@ def __init__( spark: SparkSession, *, table_options: dict[str, Any] | None = None, + sql_runner: Callable[[str], Any] | None = None, ) -> None: - super().__init__(spark, table_format="delta", table_options=table_options or {}) + super().__init__( + spark, + table_format="delta", + table_options=table_options or {}, + sql_runner=sql_runner, + ) # 
---------- Core helpers ---------- def _delta_table_for(self, table_name: str) -> DeltaTable: @@ -82,7 +89,7 @@ def _restore_table_metadata( ) -> None: if table_comment: with suppress(Exception): - self.spark.sql( + self.run_sql( f"COMMENT ON TABLE {table_ident} IS {self._sql_literal(table_comment)}" ) @@ -98,14 +105,14 @@ def _restore_table_metadata( if assignments: props_sql = ", ".join(assignments) with suppress(Exception): - self.spark.sql(f"ALTER TABLE {table_ident} SET TBLPROPERTIES ({props_sql})") + self.run_sql(f"ALTER TABLE {table_ident} SET TBLPROPERTIES ({props_sql})") for col, comment in column_comments.items(): if not comment: continue col_ident = f"{table_ident}.{self._quote_identifier(col)}" with suppress(Exception): - self.spark.sql(f"COMMENT ON COLUMN {col_ident} IS {self._sql_literal(comment)}") + self.run_sql(f"COMMENT ON COLUMN {col_ident} IS {self._sql_literal(comment)}") # ---------- Required API ---------- def save_df_as_table(self, table_name: str, df: SDF) -> None: @@ -176,7 +183,7 @@ def _writer() -> Any: if not location: # Fallback: drop and recreate if we can't resolve the location. with suppress(Exception): - self.spark.sql(f"DROP TABLE IF EXISTS {table_ident}") + self.run_sql(f"DROP TABLE IF EXISTS {table_ident}") _writer().saveAsTable(table_name) self._restore_table_metadata( table_ident, @@ -228,7 +235,7 @@ def incremental_merge( return # Materialize the source DataFrame for the merge - source_df = self.spark.sql(body) + source_df = self.run_sql(body) # Build the join predicate: t.k = s.k AND ... 
condition = " AND ".join([f"t.`{k}` = s.`{k}`" for k in unique_key]) diff --git a/src/fastflowtransform/table_formats/spark_hudi.py b/src/fastflowtransform/table_formats/spark_hudi.py index 318b9fd..4dd7598 100644 --- a/src/fastflowtransform/table_formats/spark_hudi.py +++ b/src/fastflowtransform/table_formats/spark_hudi.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from typing import Any from fastflowtransform.table_formats.base import SparkFormatHandler @@ -23,9 +24,15 @@ def __init__( *, default_database: str | None = None, table_options: dict[str, Any] | None = None, + sql_runner: Callable[[str], Any] | None = None, ) -> None: # table_format="hudi" so the base class knows what we're dealing with - super().__init__(spark, table_format="hudi", table_options=table_options or {}) + super().__init__( + spark, + table_format="hudi", + table_options=table_options or {}, + sql_runner=sql_runner, + ) self.default_database = default_database or spark.catalog.currentDatabase() # ---------- Core helpers ---------- @@ -99,7 +106,7 @@ def incremental_insert(self, table_name: str, select_body_sql: str) -> None: raise ValueError(f"incremental_insert expects SELECT body, got: {body[:40]!r}") full_name = self._qualify_table_name(table_name) - self.spark.sql(f"INSERT INTO {full_name} {body}") + self.run_sql(f"INSERT INTO {full_name} {body}") def incremental_merge( self, @@ -127,7 +134,7 @@ def incremental_merge( full_name = self._qualify_table_name(table_name) pred = " AND ".join([f"t.`{k}` = s.`{k}`" for k in unique_key]) - self.spark.sql( + self.run_sql( f""" MERGE INTO {full_name} AS t USING ({body}) AS s diff --git a/src/fastflowtransform/table_formats/spark_iceberg.py b/src/fastflowtransform/table_formats/spark_iceberg.py index 0f8f66f..b4e4b89 100644 --- a/src/fastflowtransform/table_formats/spark_iceberg.py +++ b/src/fastflowtransform/table_formats/spark_iceberg.py @@ -1,5 +1,6 @@ from __future__ import annotations +from 
collections.abc import Callable from contextlib import suppress from typing import Any @@ -24,11 +25,17 @@ def __init__( spark: SparkSession, *, table_options: dict[str, Any] | None = None, + sql_runner: Callable[[str], Any] | None = None, ) -> None: options = dict(table_options or {}) catalog = options.pop("catalog_name", None) or options.pop("__catalog_name__", None) self.catalog_name = str(catalog) if catalog else "iceberg" - super().__init__(spark, table_format="iceberg", table_options=options) + super().__init__( + spark, + table_format="iceberg", + table_options=options, + sql_runner=sql_runner, + ) # ---------- Core helpers ---------- def _qualify_table_name(self, table_name: str, database: str | None = None) -> str: @@ -90,7 +97,7 @@ def _restore_table_metadata( ) -> None: if table_comment: with suppress(Exception): - self.spark.sql( + self.run_sql( f"COMMENT ON TABLE {table_ident} IS {self._sql_literal(table_comment)}" ) @@ -106,14 +113,14 @@ def _restore_table_metadata( if assignments: props = ", ".join(assignments) with suppress(Exception): - self.spark.sql(f"ALTER TABLE {table_ident} SET TBLPROPERTIES ({props})") + self.run_sql(f"ALTER TABLE {table_ident} SET TBLPROPERTIES ({props})") for name, comment in column_comments.items(): if not comment: continue col_ident = f"{table_ident}.{self._quote_part(name)}" with suppress(Exception): - self.spark.sql(f"COMMENT ON COLUMN {col_ident} IS {self._sql_literal(comment)}") + self.run_sql(f"COMMENT ON COLUMN {col_ident} IS {self._sql_literal(comment)}") # ---------- Required API ---------- def save_df_as_table(self, table_name: str, df: SDF) -> None: @@ -177,7 +184,7 @@ def incremental_insert(self, table_name: str, select_body_sql: str) -> None: raise ValueError(f"incremental_insert expects SELECT body, got: {body[:40]!r}") full_name = self._sql_identifier(table_name) - self.spark.sql(f"INSERT INTO {full_name} {body}") + self.run_sql(f"INSERT INTO {full_name} {body}") def incremental_merge( self, @@ -202,7 +209,7 
@@ def incremental_merge( full_name = self._sql_identifier(table_name) pred = " AND ".join([f"t.`{k}` = s.`{k}`" for k in unique_key]) - self.spark.sql( + self.run_sql( f""" MERGE INTO {full_name} AS t USING ({body}) AS s diff --git a/src/fastflowtransform/utils/timefmt.py b/src/fastflowtransform/utils/timefmt.py index 3026602..0e3bcc7 100644 --- a/src/fastflowtransform/utils/timefmt.py +++ b/src/fastflowtransform/utils/timefmt.py @@ -12,3 +12,19 @@ def format_duration_minutes(minutes: float | None) -> str: if mins >= 60: return f"{mins / 60:.1f}h" return f"{mins:.1f}m" + + +def _format_duration_ms(ms: int) -> str: + """ + Human-friendly duration from milliseconds. + """ + if ms < 1000: + return f"{ms} ms" + sec = ms / 1000.0 + if sec < 60: + return f"{sec:.1f} s" + minutes = sec / 60.0 + if minutes < 60: + return f"{minutes:.1f} min" + hours = minutes / 60.0 + return f"{hours:.1f} h" diff --git a/tests/common/fixtures.py b/tests/common/fixtures.py index ccfaef7..47e9eb9 100644 --- a/tests/common/fixtures.py +++ b/tests/common/fixtures.py @@ -1,8 +1,6 @@ -# tests/common/fixtures.py from __future__ import annotations import os -import types from pathlib import Path from types import SimpleNamespace from typing import TYPE_CHECKING, Any @@ -11,55 +9,41 @@ import pandas as pd import pytest import sqlalchemy as sa -from dotenv import load_dotenv from jinja2 import DictLoader, Environment, FileSystemLoader, select_autoescape from sqlalchemy import text import fastflowtransform.executors.bigquery.base as bq_base import fastflowtransform.executors.bigquery.pandas as bq_pandas import fastflowtransform.typing as fft_typing +from fastflowtransform import utest +from fastflowtransform.core import REGISTRY from fastflowtransform.executors.bigquery.pandas import BigQueryExecutor from tests.common.mock.bigquery import install_fake_bigquery from tests.common.mock.snowflake_snowpark import FakeSnowflakeSession +from tests.common.utils import ROOT if TYPE_CHECKING: # pragma: no 
cover - typing only - import psycopg - from psycopg import sql - from fastflowtransform.executors.databricks_spark import ( DatabricksSparkExecutor as DatabricksSparkExecutorType, ) else: - try: - import psycopg # type: ignore - from psycopg import sql # type: ignore - except ModuleNotFoundError: - psycopg = None # type: ignore - sql = None # type: ignore - DatabricksSparkExecutorType = Any -try: # Optional: Spark deps may not be installed in core runs - from fastflowtransform.executors.databricks_spark import DatabricksSparkExecutor -except ModuleNotFoundError: # pragma: no cover - import guard - DatabricksSparkExecutor = None # type: ignore - -from fastflowtransform import utest -from fastflowtransform.core import REGISTRY -from tests.common.utils import ROOT - -try: # Optional: Spark deps may not be installed in core runs - from fastflowtransform.executors.databricks_spark import DatabricksSparkExecutor -except ModuleNotFoundError: # pragma: no cover - import guard - DatabricksSparkExecutor = None # type: ignore +# ---- Optional executors ------------------------------------------------------ +# Postgres +try: + from fastflowtransform.executors.postgres import PostgresExecutor +except ModuleNotFoundError: # pragma: no cover - optional dep + PostgresExecutor = None # type: ignore[assignment] +# Spark (Databricks) try: # Optional: Spark deps may not be installed in core runs from fastflowtransform.executors.databricks_spark import DatabricksSparkExecutor except ModuleNotFoundError: # pragma: no cover - import guard - DatabricksSparkExecutor = None # type: ignore + DatabricksSparkExecutor = None # type: ignore[assignment] -# --- Snowflake ---------------------------------------------------- +# Snowflake try: from fastflowtransform.executors.snowflake_snowpark import ( SnowflakeSnowparkExecutor, @@ -70,26 +54,13 @@ _SFCursorShim = None # type: ignore[assignment] -# ---- Load Env Variables ---- -@pytest.fixture(scope="session", autouse=True) -def load_test_env(): - 
candidates = [ - ROOT / "tests" / ".env", - ROOT / "tests" / ".env.dev_databricks", - ROOT / "tests" / ".env.dev_duckdb", - ROOT / "tests" / ".env.dev_postgres", - ] +# ---- Jinja env ---------------------------------------------------------------- - for env_file in candidates: - if env_file.is_file(): - load_dotenv(env_file, override=False) - -# ---- Jinja Env ---- @pytest.fixture(scope="session") -def jinja_env(): +def jinja_env() -> Environment: """ - Minimal Jinja2 rnvironment for rendering-tests. + Minimal Jinja2 environment for rendering tests. """ env = Environment( loader=FileSystemLoader(["."]), @@ -109,38 +80,53 @@ def jinja_env(): return env -# ---- DuckDB ---- +# ---- DuckDB ------------------------------------------------------------------- + + @pytest.fixture(scope="session") -def duckdb_project(): +def duckdb_project() -> Path: return ROOT / "examples" / "simple_duckdb" @pytest.fixture(scope="session") -def duckdb_db_path(duckdb_project): +def duckdb_db_path(duckdb_project: Path) -> Path: p = duckdb_project / ".local" / "demo.duckdb" p.parent.mkdir(parents=True, exist_ok=True) return p @pytest.fixture(scope="session") -def duckdb_env(duckdb_db_path): +def duckdb_env(duckdb_db_path: Path) -> dict[str, str]: return {"FF_ENGINE": "duckdb", "FF_DUCKDB_PATH": str(duckdb_db_path)} -# ---- Postgres ---- +# ---- Postgres ----------------------------------------------------------------- + + @pytest.fixture(scope="session") -def pg_project(): +def pg_project() -> Path: return ROOT / "examples" / "postgres" @pytest.fixture(scope="session") -def pg_env(): +def pg_env() -> dict[str, str]: dsn = os.environ.get("FF_PG_DSN", "postgresql+psycopg://postgres:postgres@localhost:5432/ffdb") schema = os.environ.get("FF_PG_SCHEMA", "public") return {"FF_ENGINE": "postgres", "FF_PG_DSN": dsn, "FF_PG_SCHEMA": schema} -# ---- Spark ---- +@pytest.fixture(scope="session") +def postgres_exec(): + dsn = os.environ.get("FF_PG_DSN", 
"postgresql+psycopg://postgres:postgres@localhost:5432/ffdb") + schema = os.environ.get("FF_PG_SCHEMA", "public") + if PostgresExecutor is None: + pytest.skip("FF_PG_DSN not set; skipping Postgres snapshot tests") + return PostgresExecutor(dsn=dsn, schema=schema) + + +# ---- Spark -------------------------------------------------------------------- + + @pytest.fixture def exec_minimal(monkeypatch): if DatabricksSparkExecutor is None: @@ -151,6 +137,9 @@ def exec_minimal(monkeypatch): fake_spark = MagicMock() SP.builder.master.return_value.appName.return_value.getOrCreate.return_value = fake_spark ex = DatabricksSparkExecutor() + # JVM plan inspection loops forever on MagicMocks; skip in unit tests. + monkeypatch.setattr(ex, "_spark_plan_bytes", lambda *_, **__: None) + monkeypatch.setattr(ex, "_spark_dataframe_bytes", lambda *_, **__: None) # accept mocks as frames in unit tests monkeypatch.setattr(ex, "_is_frame", lambda obj: True) return ex @@ -161,6 +150,7 @@ def exec_factory(): """ Build a DatabricksSparkExecutor with arbitrary __init__ kwargs, but always with mocked SparkSession (no real JVM). + Returns (executor, fake_builder, fake_spark). 
""" @@ -181,6 +171,8 @@ def _make(**kwargs) -> tuple[DatabricksSparkExecutorType, Any, MagicMock]: fake_builder.getOrCreate.return_value = fake_spark ex = DatabricksSparkExecutor(**kwargs) + ex._spark_plan_bytes = lambda *_, **__: None + ex._spark_dataframe_bytes = lambda *_, **__: None return ex, fake_builder, fake_spark return _make @@ -206,7 +198,7 @@ def spark_exec(spark_tmpdir: Path) -> DatabricksSparkExecutorType: @pytest.fixture(scope="session") -def spark_exec_delta(spark_tmpdir): +def spark_exec_delta(spark_tmpdir: Path) -> DatabricksSparkExecutorType: if DatabricksSparkExecutor is None: pytest.skip( "pyspark/delta not installed; install fastflowtransform[spark] to run Spark tests" @@ -227,9 +219,11 @@ def spark_exec_delta(spark_tmpdir): ) -# ---- utest ---- +# ---- utest helpers ------------------------------------------------------------ + + @pytest.fixture -def fake_registry(tmp_path, monkeypatch): +def fake_registry(tmp_path: Path, monkeypatch) -> SimpleNamespace: node = SimpleNamespace(name="model_a", kind="sql", deps=["src1"]) reg = SimpleNamespace( nodes={"model_a": node}, @@ -282,9 +276,11 @@ def __init__(self, engine): return PgEx(engine) -# ---- Examples ---- +# ---- Example engine envs ------------------------------------------------------ + + @pytest.fixture(scope="session") -def duckdb_engine_env(tmp_path_factory): +def duckdb_engine_env(tmp_path_factory: pytest.TempPathFactory) -> dict[str, str]: """Basic env for DuckDB examples.""" db_dir = tmp_path_factory.mktemp("duckdb") db_path = db_dir / "examples.duckdb" @@ -295,7 +291,7 @@ def duckdb_engine_env(tmp_path_factory): @pytest.fixture(scope="session") -def postgres_engine_env(): +def postgres_engine_env() -> dict[str, str]: """Basic env for Postgres. 
Skipped if DSN is missing or DB not reachable.""" dsn = os.environ.get( "FF_PG_DSN", @@ -303,12 +299,12 @@ def postgres_engine_env(): ) schema = os.environ.get("FF_PG_SCHEMA", "public") - # Optional: Connectivity-Check + # Optional: connectivity check try: engine = sa.create_engine(dsn) with engine.connect() as conn: conn.execute(text("select 1")) - except Exception as exc: + except Exception as exc: # pragma: no cover - integration guard pytest.skip(f"Postgres not reachable at DSN={dsn!r}: {exc}") return { @@ -319,7 +315,7 @@ def postgres_engine_env(): @pytest.fixture(scope="session") -def spark_engine_env(tmp_path_factory): +def spark_engine_env(tmp_path_factory: pytest.TempPathFactory) -> dict[str, str]: """Basic env for Databricks-Spark-Executor. Skipped if JAVA_HOME is missing.""" if not os.environ.get("JAVA_HOME"): pytest.skip("JAVA_HOME not set for Spark tests") @@ -334,36 +330,11 @@ def spark_engine_env(tmp_path_factory): } -@pytest.fixture(scope="session") -def bigquery_engine_env(): - """ - Basic env for BigQuery examples. Skips if required env vars are missing. - """ - project = os.environ.get("FF_BQ_PROJECT") - dataset = os.environ.get("FF_BQ_DATASET") - location = os.environ.get("FF_BQ_LOCATION") - - if not (project and dataset and location): - pytest.skip("FF_BQ_PROJECT/FF_BQ_DATASET/FF_BQ_LOCATION not set for BigQuery tests") - - env = { - "FF_ENGINE": "bigquery", - "FF_ENGINE_VARIANT": os.environ.get("FF_ENGINE_VARIANT", "bigframes"), - "FF_BQ_PROJECT": project, - "FF_BQ_DATASET": dataset, - "FF_BQ_LOCATION": location, - } - - creds = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") - if creds: - env["GOOGLE_APPLICATION_CREDENTIALS"] = creds +# ---- Snowflake Snowpark ------------------------------------------------------- - return env - -# ---- Snowflake Snowpark ---- @pytest.fixture(scope="session") -def snowflake_engine_env(): +def snowflake_engine_env() -> dict[str, str]: """ Basic env for Snowflake Snowpark examples / integration tests. 
@@ -388,14 +359,14 @@ def snowflake_engine_env(): "FF_SF_WAREHOUSE, FF_SF_DATABASE, FF_SF_SCHEMA)." ) - env = { + env: dict[str, str] = { "FF_ENGINE": "snowflake_snowpark", - "FF_SF_ACCOUNT": account, - "FF_SF_USER": user, - "FF_SF_PASSWORD": password, - "FF_SF_WAREHOUSE": warehouse, - "FF_SF_DATABASE": database, - "FF_SF_SCHEMA": schema, + "FF_SF_ACCOUNT": account or "", + "FF_SF_USER": user or "", + "FF_SF_PASSWORD": password or "", + "FF_SF_WAREHOUSE": warehouse or "", + "FF_SF_DATABASE": database or "", + "FF_SF_SCHEMA": schema or "", "FF_SF_ALLOW_CREATE_SCHEMA": allow_create_schema, } if role: @@ -404,7 +375,7 @@ def snowflake_engine_env(): @pytest.fixture(scope="session") -def snowflake_cfg(snowflake_engine_env): +def snowflake_cfg(snowflake_engine_env: dict[str, str]) -> dict[str, Any]: """ Canonical config dict for SnowflakeSnowparkExecutor, derived from env. """ @@ -426,6 +397,7 @@ def snowflake_cfg(snowflake_engine_env): def snowflake_executor_fake() -> Any: """ Fake SnowflakeSnowparkExecutor using an in-memory FakeSnowflakeSession. + This does NOT talk to a real Snowflake account and does not need env vars. It is intended for SQL-shape tests, similar to the BigQuery fake. """ @@ -447,17 +419,19 @@ def snowflake_executor_fake() -> Any: ex.con = _SFCursorShim(session) # type: ignore[arg-type] else: # Cheap fallback if for some reason the shim isn't available - ex.con = types.SimpleNamespace(execute=lambda sql, params=None: None) + ex.con = SimpleNamespace(execute=lambda sql, params=None: None) return ex +# ---- BigQuery ----------------------------------------------------------------- + + @pytest.fixture(scope="session") def jinja_env_bigquery() -> Environment: """ Very small Jinja environment for BigQuery unit tests. - (You can also reuse your global jinja_env from tests/common/fixtures.py - and just import that instead.) + (You can also reuse your global jinja_env from this module and import it.) 
""" env = Environment( loader=DictLoader({}), @@ -475,6 +449,7 @@ def jinja_env_bigquery() -> Environment: def bq_executor_fake(monkeypatch) -> BigQueryExecutor: """ BigQueryExecutor wired against the FakeClient from tests/common/mock/bigquery.py. + No real BigQuery project / dataset / credentials required. """ # Make sure all FFT modules that cache a `bigquery` symbol diff --git a/tests/conftest.py b/tests/conftest.py index fe0f52f..a208f47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,2 +1,27 @@ # tests/conftest.py +import pytest + pytest_plugins = ["tests.common.fixtures"] + + +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: + # Marker expression from CLI: -m "..." + markexpr = getattr(config.option, "markexpr", "") or "" + + # If the user explicitly mentions remote_engine in -m, + # we assume they *want* those tests. Don't deselect them. + if "remote_engine" in markexpr: + return + + deselected: list[pytest.Item] = [] + kept: list[pytest.Item] = [] + + for item in items: + if "remote_engine" in item.keywords: + deselected.append(item) + else: + kept.append(item) + + if deselected: + config.hook.pytest_deselected(items=deselected) + items[:] = kept diff --git a/tests/integration/examples/config.py b/tests/integration/examples/config.py index 0973195..d4c3ba4 100644 --- a/tests/integration/examples/config.py +++ b/tests/integration/examples/config.py @@ -14,6 +14,7 @@ class ExampleConfig: make_target: str env_by_engine: dict[str, str] spark_table_formats: list[str] | None = None + bigquery_env_by_backend: dict[str, str] | None = None EXAMPLES: list[ExampleConfig] = [ @@ -25,6 +26,12 @@ class ExampleConfig: "duckdb": "dev_duckdb", "postgres": "dev_postgres", "databricks_spark": "dev_databricks", + "bigquery": "dev_bigquery_pandas", + "snowflake_snowpark": "dev_snowflake", + }, + bigquery_env_by_backend={ + "pandas": "dev_bigquery_pandas", + "bigframes": "dev_bigquery_bigframes", }, ), ExampleConfig( @@ 
-35,6 +42,12 @@ class ExampleConfig: "duckdb": "dev_duckdb", "postgres": "dev_postgres", "databricks_spark": "dev_databricks", + "bigquery": "dev_bigquery_pandas", + "snowflake_snowpark": "dev_snowflake", + }, + bigquery_env_by_backend={ + "pandas": "dev_bigquery_pandas", + "bigframes": "dev_bigquery_bigframes", }, ), ExampleConfig( @@ -45,6 +58,13 @@ class ExampleConfig: "duckdb": "dev_duckdb", "postgres": "dev_postgres", "databricks_spark": "dev_databricks", + "bigquery": "dev_bigquery_pandas", + "snowflake_snowpark": "dev_snowflake", + }, + bigquery_env_by_backend={ + "pandas": "dev_bigquery_pandas", + "bigframes": "dev_bigquery_bigframes", + "bigquery": "dev_bigquery_pandas", }, ), ExampleConfig( @@ -55,6 +75,12 @@ class ExampleConfig: "duckdb": "dev_duckdb", "postgres": "dev_postgres", "databricks_spark": "dev_databricks", + "bigquery": "dev_bigquery_pandas", + "snowflake_snowpark": "dev_snowflake", + }, + bigquery_env_by_backend={ + "pandas": "dev_bigquery_pandas", + "bigframes": "dev_bigquery_bigframes", }, ), ExampleConfig( @@ -65,6 +91,12 @@ class ExampleConfig: "duckdb": "dev_duckdb", "postgres": "dev_postgres", "databricks_spark": "dev_databricks", + "bigquery": "dev_bigquery_pandas", + "snowflake_snowpark": "dev_snowflake", + }, + bigquery_env_by_backend={ + "pandas": "dev_bigquery_pandas", + "bigframes": "dev_bigquery_bigframes", }, ), ExampleConfig( @@ -75,8 +107,14 @@ class ExampleConfig: "duckdb": "dev_duckdb", "postgres": "dev_postgres", "databricks_spark": "dev_databricks_parquet", + "bigquery": "dev_bigquery_pandas", + "snowflake_snowpark": "dev_snowflake", }, spark_table_formats=["parquet", "delta", "iceberg"], + bigquery_env_by_backend={ + "pandas": "dev_bigquery_pandas", + "bigframes": "dev_bigquery_bigframes", + }, ), ExampleConfig( name="macros_demo", @@ -86,6 +124,12 @@ class ExampleConfig: "duckdb": "dev_duckdb", "postgres": "dev_postgres", "databricks_spark": "dev_databricks", + "bigquery": "dev_bigquery_pandas", + "snowflake_snowpark": 
"dev_snowflake", + }, + bigquery_env_by_backend={ + "pandas": "dev_bigquery_pandas", + "bigframes": "dev_bigquery_bigframes", }, ), ExampleConfig( @@ -96,6 +140,12 @@ class ExampleConfig: "duckdb": "dev_duckdb", "postgres": "dev_postgres", "databricks_spark": "dev_databricks", + "bigquery": "dev_bigquery_pandas", + "snowflake_snowpark": "dev_snowflake", + }, + bigquery_env_by_backend={ + "pandas": "dev_bigquery_pandas", + "bigframes": "dev_bigquery_bigframes", }, ), ExampleConfig( @@ -106,7 +156,13 @@ class ExampleConfig: "duckdb": "dev_duckdb", "postgres": "dev_postgres", "databricks_spark": "dev_databricks_parquet", + "bigquery": "dev_bigquery_pandas", + "snowflake_snowpark": "dev_snowflake", }, spark_table_formats=["parquet", "delta", "iceberg"], + bigquery_env_by_backend={ + "pandas": "dev_bigquery_pandas", + "bigframes": "dev_bigquery_bigframes", + }, ), ] diff --git a/tests/integration/examples/test_examples_matrix.py b/tests/integration/examples/test_examples_matrix.py index d6affe5..2e3872a 100644 --- a/tests/integration/examples/test_examples_matrix.py +++ b/tests/integration/examples/test_examples_matrix.py @@ -33,6 +33,8 @@ def _run_cmd(cmd: list[str], cwd: Path, extra_env: dict[str, str] | None = None) "duckdb": pytest.mark.duckdb, "postgres": pytest.mark.postgres, "databricks_spark": pytest.mark.databricks_spark, + "bigquery": pytest.mark.bigquery, + "snowflake_snowpark": pytest.mark.snowflake_snowpark, } # build only the actually-supported (example, engine) combinations @@ -41,44 +43,58 @@ def _run_cmd(cmd: list[str], cwd: Path, extra_env: dict[str, str] | None = None) example, engine, id=f"{example.name}[{engine}]", - marks=ENGINE_MARKS[engine], + marks=[ + ENGINE_MARKS[engine], + *([pytest.mark.remote_engine] if engine in {"bigquery", "snowflake_snowpark"} else []), + ], ) for example in EXAMPLES for engine in example.env_by_engine ] -ENGINE_ENV_FIXTURE = { - "duckdb": "duckdb_engine_env", - "postgres": "postgres_engine_env", - "databricks_spark": 
"spark_engine_env", -} - @pytest.mark.integration @pytest.mark.example @pytest.mark.flaky(reruns=2, reruns_delay=2) @pytest.mark.parametrize("example,engine", EXAMPLE_ENGINE_PARAMS) def test_examples_with_all_engines(example, engine, request): - fixture_name = ENGINE_ENV_FIXTURE[engine] - engine_env: dict[str, str] = request.getfixturevalue(fixture_name) + # fixture_name = ENGINE_ENV_FIXTURE[engine] + # engine_env: dict[str, str] = request.getfixturevalue(fixture_name) - env = dict(engine_env) - env["FFT_ACTIVE_ENV"] = example.env_by_engine[engine] + # env = dict(engine_env) + # env["FFT_ACTIVE_ENV"] = example.env_by_engine[engine] cmd = ["make", example.make_target, f"ENGINE={engine}"] + base_env: dict[str, str] = {} if engine == "databricks_spark": + base_env["FFT_ACTIVE_ENV"] = example.env_by_engine[engine] formats = example.spark_table_formats or ["parquet"] for fmt in formats: - env_fmt = dict(env) + env_fmt = dict(base_env) env_fmt["DBR_TABLE_FORMAT"] = fmt _run_cmd( [*cmd, f"DBR_TABLE_FORMAT={fmt}"], cwd=example.path, extra_env=env_fmt, ) + + elif engine == "bigquery": + # Loop over backends: pandas / bigframes + backends = example.bigquery_env_by_backend + for backend, env_name in backends.items(): + env_backend = dict(base_env) + env_backend["FFT_ACTIVE_ENV"] = env_name + env_backend["BQ_FRAME"] = backend + _run_cmd( + [*cmd, f"BQ_FRAME={backend}"], + cwd=example.path, + extra_env=env_backend, + ) + else: - _run_cmd(cmd, cwd=example.path, extra_env=env) + base_env["FFT_ACTIVE_ENV"] = example.env_by_engine[engine] + _run_cmd(cmd, cwd=example.path, extra_env=base_env) target_dir = example.path / ".fastflowtransform" / "target" manifest = target_dir / "manifest.json" diff --git a/tests/integration/executors/postgres/test_snapshots_postgres_integration.py b/tests/integration/executors/postgres/test_snapshots_postgres_integration.py index c697b57..3a7031f 100644 --- a/tests/integration/executors/postgres/test_snapshots_postgres_integration.py +++ 
b/tests/integration/executors/postgres/test_snapshots_postgres_integration.py @@ -1,8 +1,6 @@ # tests/unit/test_snapshots_postgres.py from __future__ import annotations -import os - import pandas as pd import pytest from sqlalchemy import text @@ -58,14 +56,6 @@ def _reset_snapshot_table(executor: PostgresExecutor, node_name: str) -> None: conn.execute(text(f"DROP TABLE IF EXISTS {qrel} CASCADE")) -def _make_pg_executor() -> PostgresExecutor: - dsn = os.environ.get("FF_PG_DSN") - schema = os.environ.get("FF_PG_SCHEMA", "public") - if not dsn: - pytest.skip("FF_PG_DSN not set; skipping Postgres snapshot tests") - return PostgresExecutor(dsn=dsn, schema=schema) - - def _read_pg(ex: PostgresExecutor, relation: str) -> pd.DataFrame: qualified = ex._qualified(relation) with ex.engine.begin() as conn: @@ -75,13 +65,12 @@ def _read_pg(ex: PostgresExecutor, relation: str) -> pd.DataFrame: @pytest.mark.postgres @pytest.mark.integration -def test_postgres_snapshot_timestamp_first_and_second_run(jinja_env): - ex = _make_pg_executor() +def test_postgres_snapshot_timestamp_first_and_second_run(jinja_env, postgres_exec): node = make_timestamp_snapshot_node() - _reset_snapshot_table(executor=ex, node_name=node.name) + _reset_snapshot_table(executor=postgres_exec, node_name=node.name) scenario_timestamp_first_and_second_run( - executor=ex, + executor=postgres_exec, node=node, jinja_env=jinja_env, read_fn=_read_pg, @@ -92,13 +81,12 @@ def test_postgres_snapshot_timestamp_first_and_second_run(jinja_env): @pytest.mark.postgres @pytest.mark.integration -def test_postgres_snapshot_prune_keep_last(jinja_env): - ex = _make_pg_executor() +def test_postgres_snapshot_prune_keep_last(jinja_env, postgres_exec): node = make_timestamp_snapshot_node() - _reset_snapshot_table(executor=ex, node_name=node.name) + _reset_snapshot_table(executor=postgres_exec, node_name=node.name) scenario_snapshot_prune_keep_last( - executor=ex, + executor=postgres_exec, node=node, jinja_env=jinja_env, 
read_fn=_read_pg, @@ -110,13 +98,12 @@ def test_postgres_snapshot_prune_keep_last(jinja_env): @pytest.mark.postgres @pytest.mark.integration -def test_postgres_snapshot_check_strategy(jinja_env): - ex = _make_pg_executor() +def test_postgres_snapshot_check_strategy(jinja_env, postgres_exec): node = make_check_snapshot_node() - _reset_snapshot_table(executor=ex, node_name=node.name) + _reset_snapshot_table(executor=postgres_exec, node_name=node.name) scenario_check_strategy_detects_changes( - executor=ex, + executor=postgres_exec, node=node, jinja_env=jinja_env, read_fn=_read_pg, diff --git a/tests/unit/cache/test_cache_policy_cli.py b/tests/unit/cache/test_cache_policy_cli.py index 5b181bb..328b619 100644 --- a/tests/unit/cache/test_cache_policy_cli.py +++ b/tests/unit/cache/test_cache_policy_cli.py @@ -79,6 +79,7 @@ def schedule( engine_abbr="", name_width=28, name_formatter=None, + durations_s=None, ): start = time.perf_counter() per: dict[str, float] = {} diff --git a/tests/unit/executors/test_databricks_spark_exec_unit.py b/tests/unit/executors/test_databricks_spark_exec_unit.py index ed3711d..db83455 100644 --- a/tests/unit/executors/test_databricks_spark_exec_unit.py +++ b/tests/unit/executors/test_databricks_spark_exec_unit.py @@ -287,7 +287,6 @@ def test__create_view_over_table_executes_expected_sql(exec_minimal): "v_users", "t_users", Node(name="n", kind="sql", path=Path(".")) ) - exec_minimal.spark.sql.assert_called_once() sql = exec_minimal.spark.sql.call_args[0][0] assert "CREATE OR REPLACE VIEW `v_users` AS SELECT * FROM `t_users`" in sql diff --git a/tests/unit/executors/test_postgres_exec_unit.py b/tests/unit/executors/test_postgres_exec_unit.py index 909f6d9..46076cc 100644 --- a/tests/unit/executors/test_postgres_exec_unit.py +++ b/tests/unit/executors/test_postgres_exec_unit.py @@ -528,14 +528,13 @@ def test_create_or_replace_view_from_table_happy(fake_engine_and_conn): or "CREATE OR REPLACE VIEW" in str(stmt) ] - # now we should have exactly the 
3 we expect - expected_statement_len = 3 + expected_statement_len = 5 assert len(stmts) == expected_statement_len assert 'SET LOCAL search_path = "public"' in stmts[0][0] assert 'DROP VIEW IF EXISTS "public"."v_out" CASCADE' in stmts[1][0] assert ( - 'CREATE OR REPLACE VIEW "public"."v_out" AS SELECT * FROM "public"."src_tbl"' in stmts[2][0] + 'CREATE OR REPLACE VIEW "public"."v_out" AS SELECT * FROM "public"."src_tbl"' in stmts[4][0] ) diff --git a/tests/unit/executors/test_snowflake_snowpark_exec.py b/tests/unit/executors/test_snowflake_snowpark_exec.py index e018c78..fdd1fd4 100644 --- a/tests/unit/executors/test_snowflake_snowpark_exec.py +++ b/tests/unit/executors/test_snowflake_snowpark_exec.py @@ -502,9 +502,15 @@ def test_incremental_merge_builds_two_statements(sf_exec): sf_exec.incremental_merge("DST", "SELECT 1 AS id", ["id"]) - # We now emit two separate statements: DELETE ... and INSERT ... - assert len(sf_exec.session.sql_calls) == 2 - delete_sql, insert_sql = sf_exec.session.sql_calls + # Each statement incurs an EXPLAIN + actual execution → 4 SQL calls total. + assert len(sf_exec.session.sql_calls) == 4 + dml_calls = [ + sql + for sql in sf_exec.session.sql_calls + if sql.strip().upper().startswith("DELETE") or sql.strip().upper().startswith("INSERT") + ] + assert len(dml_calls) == 2 + delete_sql, insert_sql = dml_calls # First: DELETE FROM ... USING () AS s WHERE ... assert "DELETE FROM" in delete_sql diff --git a/uv.lock b/uv.lock index 80f8d10..25af95b 100644 --- a/uv.lock +++ b/uv.lock @@ -733,7 +733,7 @@ wheels = [ [[package]] name = "fastflowtransform" -version = "0.6.7" +version = "0.6.8" source = { editable = "." 
} dependencies = [ { name = "duckdb" }, From c2f90b39d967e93306edcde2eb2ad27db5becba8 Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Tue, 25 Nov 2025 17:15:12 +0100 Subject: [PATCH 2/2] Added pypi release yml --- .github/workflows/release.yml | 77 +++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..ffccf58 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,77 @@ +name: Release + +on: + push: + branches: + - main + +jobs: + build-and-publish: + name: Build and publish to (Test)PyPI + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install build tooling + run: | + python -m pip install --upgrade pip + python -m pip install build twine + + - name: Build sdist / wheel + run: python -m build + + - name: Read version from pyproject.toml + id: meta + run: | + python - << 'PY' + import os, pathlib, tomllib + + data = tomllib.loads(pathlib.Path("pyproject.toml").read_text()) + version = data.get("project", {}).get("version") + if not version: + raise SystemExit("No [project].version found in pyproject.toml") + print(f"Version: {version}") + with open(os.environ["GITHUB_OUTPUT"], "a", encoding="utf8") as fh: + fh.write(f"version={version}\n") + PY + + - name: Decide repository (TestPyPI vs PyPI) + id: target + shell: bash + run: | + version="${{ steps.meta.outputs.version }}" + echo "Detected version: $version" + + if [[ "$version" == *"rc"* ]]; then + echo "repo=testpypi" >> "$GITHUB_OUTPUT" + else + echo "repo=pypi" >> "$GITHUB_OUTPUT" + fi + + - name: Publish to TestPyPI + if: steps.target.outputs.repo == 'testpypi' + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} + run: | + twine upload \ + --repository-url 
https://test.pypi.org/legacy/ \ + --skip-existing \ + dist/* + + - name: Publish to PyPI + if: steps.target.outputs.repo == 'pypi' + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + twine upload \ + --skip-existing \ + dist/*